Created
March 4, 2025 05:07
-
-
Save RaoHai/495a351d5408b62b7e61ae8d2c23e211 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"attachments": {}, | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This notebook shows sql clustering using embedding and vector database" | |
] | |
}, | |
{ | |
"attachments": {}, | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 1. Breparing Database" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | |
"To disable this warning, you can either:\n", | |
"\t- Avoid using `tokenizers` before the fork if possible\n", | |
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: pandas==2.0.3 in ./.venv/lib/python3.9/site-packages (from -r requirements.txt (line 1)) (2.0.3)\n", | |
"Requirement already satisfied: pyarrow==14.0.1 in ./.venv/lib/python3.9/site-packages (from -r requirements.txt (line 2)) (14.0.1)\n", | |
"Requirement already satisfied: fsspec in ./.venv/lib/python3.9/site-packages (from -r requirements.txt (line 3)) (2024.3.1)\n", | |
"Requirement already satisfied: huggingface_hub in ./.venv/lib/python3.9/site-packages (from -r requirements.txt (line 4)) (0.28.1)\n", | |
"Requirement already satisfied: transformers in ./.venv/lib/python3.9/site-packages (from -r requirements.txt (line 5)) (4.49.0)\n", | |
"Requirement already satisfied: torch in ./.venv/lib/python3.9/site-packages (from -r requirements.txt (line 6)) (2.5.1)\n", | |
"Requirement already satisfied: numpy==1.26.4 in ./.venv/lib/python3.9/site-packages (from -r requirements.txt (line 7)) (1.26.4)\n", | |
"Requirement already satisfied: tqdm in ./.venv/lib/python3.9/site-packages (from -r requirements.txt (line 8)) (4.67.1)\n", | |
"Requirement already satisfied: faiss-cpu in ./.venv/lib/python3.9/site-packages (from -r requirements.txt (line 9)) (1.10.0)\n", | |
"Collecting sentence_transformers (from -r requirements.txt (line 10))\n", | |
" Downloading sentence_transformers-3.4.1-py3-none-any.whl.metadata (10 kB)\n", | |
"Requirement already satisfied: python-dateutil>=2.8.2 in ./.venv/lib/python3.9/site-packages (from pandas==2.0.3->-r requirements.txt (line 1)) (2.9.0.post0)\n", | |
"Requirement already satisfied: pytz>=2020.1 in ./.venv/lib/python3.9/site-packages (from pandas==2.0.3->-r requirements.txt (line 1)) (2025.1)\n", | |
"Requirement already satisfied: tzdata>=2022.1 in ./.venv/lib/python3.9/site-packages (from pandas==2.0.3->-r requirements.txt (line 1)) (2025.1)\n", | |
"Requirement already satisfied: filelock in ./.venv/lib/python3.9/site-packages (from huggingface_hub->-r requirements.txt (line 4)) (3.17.0)\n", | |
"Requirement already satisfied: packaging>=20.9 in ./.venv/lib/python3.9/site-packages (from huggingface_hub->-r requirements.txt (line 4)) (24.2)\n", | |
"Requirement already satisfied: pyyaml>=5.1 in ./.venv/lib/python3.9/site-packages (from huggingface_hub->-r requirements.txt (line 4)) (6.0.2)\n", | |
"Requirement already satisfied: requests in ./.venv/lib/python3.9/site-packages (from huggingface_hub->-r requirements.txt (line 4)) (2.32.3)\n", | |
"Requirement already satisfied: typing-extensions>=3.7.4.3 in ./.venv/lib/python3.9/site-packages (from huggingface_hub->-r requirements.txt (line 4)) (4.12.2)\n", | |
"Requirement already satisfied: regex!=2019.12.17 in ./.venv/lib/python3.9/site-packages (from transformers->-r requirements.txt (line 5)) (2024.11.6)\n", | |
"Requirement already satisfied: tokenizers<0.22,>=0.21 in ./.venv/lib/python3.9/site-packages (from transformers->-r requirements.txt (line 5)) (0.21.0)\n", | |
"Requirement already satisfied: safetensors>=0.4.1 in ./.venv/lib/python3.9/site-packages (from transformers->-r requirements.txt (line 5)) (0.5.2)\n", | |
"Requirement already satisfied: networkx in ./.venv/lib/python3.9/site-packages (from torch->-r requirements.txt (line 6)) (3.2.1)\n", | |
"Requirement already satisfied: jinja2 in ./.venv/lib/python3.9/site-packages (from torch->-r requirements.txt (line 6)) (3.1.5)\n", | |
"Requirement already satisfied: sympy==1.13.1 in ./.venv/lib/python3.9/site-packages (from torch->-r requirements.txt (line 6)) (1.13.1)\n", | |
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in ./.venv/lib/python3.9/site-packages (from sympy==1.13.1->torch->-r requirements.txt (line 6)) (1.3.0)\n", | |
"Collecting scikit-learn (from sentence_transformers->-r requirements.txt (line 10))\n", | |
" Using cached scikit_learn-1.6.1-cp39-cp39-macosx_12_0_arm64.whl.metadata (31 kB)\n", | |
"Collecting scipy (from sentence_transformers->-r requirements.txt (line 10))\n", | |
" Downloading scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl.metadata (60 kB)\n", | |
"Requirement already satisfied: Pillow in ./.venv/lib/python3.9/site-packages (from sentence_transformers->-r requirements.txt (line 10)) (11.1.0)\n", | |
"Requirement already satisfied: six>=1.5 in ./.venv/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas==2.0.3->-r requirements.txt (line 1)) (1.17.0)\n", | |
"Requirement already satisfied: MarkupSafe>=2.0 in ./.venv/lib/python3.9/site-packages (from jinja2->torch->-r requirements.txt (line 6)) (3.0.2)\n", | |
"Requirement already satisfied: charset-normalizer<4,>=2 in ./.venv/lib/python3.9/site-packages (from requests->huggingface_hub->-r requirements.txt (line 4)) (3.4.1)\n", | |
"Requirement already satisfied: idna<4,>=2.5 in ./.venv/lib/python3.9/site-packages (from requests->huggingface_hub->-r requirements.txt (line 4)) (3.10)\n", | |
"Requirement already satisfied: urllib3<3,>=1.21.1 in ./.venv/lib/python3.9/site-packages (from requests->huggingface_hub->-r requirements.txt (line 4)) (2.3.0)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in ./.venv/lib/python3.9/site-packages (from requests->huggingface_hub->-r requirements.txt (line 4)) (2025.1.31)\n", | |
"Collecting joblib>=1.2.0 (from scikit-learn->sentence_transformers->-r requirements.txt (line 10))\n", | |
" Using cached joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)\n", | |
"Collecting threadpoolctl>=3.1.0 (from scikit-learn->sentence_transformers->-r requirements.txt (line 10))\n", | |
" Downloading threadpoolctl-3.5.0-py3-none-any.whl.metadata (13 kB)\n", | |
"Downloading sentence_transformers-3.4.1-py3-none-any.whl (275 kB)\n", | |
"Using cached scikit_learn-1.6.1-cp39-cp39-macosx_12_0_arm64.whl (11.1 MB)\n", | |
"Downloading scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl (30.3 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m30.3/30.3 MB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", | |
"\u001b[?25hUsing cached joblib-1.4.2-py3-none-any.whl (301 kB)\n", | |
"Downloading threadpoolctl-3.5.0-py3-none-any.whl (18 kB)\n", | |
"Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn, sentence_transformers\n", | |
"Successfully installed joblib-1.4.2 scikit-learn-1.6.1 scipy-1.13.1 sentence_transformers-3.4.1 threadpoolctl-3.5.0\n", | |
"Note: you may need to restart the kernel to use updated packages.\n" | |
] | |
} | |
], | |
"source": [ | |
"%pip install -r requirements.txt" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/raohai/projj/github.com/RaoHai/sql_cluster/.venv/lib/python3.9/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", | |
" warnings.warn(\n", | |
"/Users/raohai/projj/github.com/RaoHai/sql_cluster/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
" from .autonotebook import tqdm as notebook_tqdm\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<bound method NDFrame.head of id domain \\\n", | |
"0 5097 forestry \n", | |
"1 5098 defense industry \n", | |
"2 5099 marine biology \n", | |
"3 5100 financial services \n", | |
"4 5101 energy \n", | |
"... ... ... \n", | |
"99995 89651 nonprofit \n", | |
"99996 89652 retail \n", | |
"99997 89653 fitness industry \n", | |
"99998 89654 space exploration \n", | |
"99999 89655 wildlife conservation \n", | |
"\n", | |
" domain_description sql_complexity \\\n", | |
"0 Comprehensive data on sustainable forest manag... single join \n", | |
"1 Defense contract data, military equipment main... aggregation \n", | |
"2 Comprehensive data on marine species, oceanogr... basic SQL \n", | |
"3 Detailed financial data including investment s... aggregation \n", | |
"4 Energy market data covering renewable energy s... window functions \n", | |
"... ... ... \n", | |
"99995 Nonprofit data on charitable giving trends, so... basic SQL \n", | |
"99996 Retail data on circular supply chains, ethical... single join \n", | |
"99997 Workout data, membership demographics, wearabl... single join \n", | |
"99998 Spacecraft manufacturing data, space mission r... single join \n", | |
"99999 Animal population data, habitat preservation e... basic SQL \n", | |
"\n", | |
" sql_complexity_description \\\n", | |
"0 only one join (specify inner, outer, cross) \n", | |
"1 aggregation functions (COUNT, SUM, AVG, MIN, M... \n", | |
"2 basic SQL with a simple select statement \n", | |
"3 aggregation functions (COUNT, SUM, AVG, MIN, M... \n", | |
"4 window functions (e.g., ROW_NUMBER, LEAD, LAG,... \n", | |
"... ... \n", | |
"99995 basic SQL with a simple select statement \n", | |
"99996 only one join (specify inner, outer, cross) \n", | |
"99997 only one join (specify inner, outer, cross) \n", | |
"99998 only one join (specify inner, outer, cross) \n", | |
"99999 basic SQL with a simple select statement \n", | |
"\n", | |
" sql_task_type \\\n", | |
"0 analytics and reporting \n", | |
"1 analytics and reporting \n", | |
"2 analytics and reporting \n", | |
"3 analytics and reporting \n", | |
"4 analytics and reporting \n", | |
"... ... \n", | |
"99995 analytics and reporting \n", | |
"99996 analytics and reporting \n", | |
"99997 analytics and reporting \n", | |
"99998 analytics and reporting \n", | |
"99999 analytics and reporting \n", | |
"\n", | |
" sql_task_type_description \\\n", | |
"0 generating reports, dashboards, and analytical... \n", | |
"1 generating reports, dashboards, and analytical... \n", | |
"2 generating reports, dashboards, and analytical... \n", | |
"3 generating reports, dashboards, and analytical... \n", | |
"4 generating reports, dashboards, and analytical... \n", | |
"... ... \n", | |
"99995 generating reports, dashboards, and analytical... \n", | |
"99996 generating reports, dashboards, and analytical... \n", | |
"99997 generating reports, dashboards, and analytical... \n", | |
"99998 generating reports, dashboards, and analytical... \n", | |
"99999 generating reports, dashboards, and analytical... \n", | |
"\n", | |
" sql_prompt \\\n", | |
"0 What is the total volume of timber sold by eac... \n", | |
"1 List all the unique equipment types and their ... \n", | |
"2 How many marine species are found in the South... \n", | |
"3 What is the total trade value and average pric... \n", | |
"4 Find the energy efficiency upgrades with the h... \n", | |
"... ... \n", | |
"99995 Which programs had the highest volunteer parti... \n", | |
"99996 What is the number of fair-trade certified acc... \n", | |
"99997 Find the user with the longest workout session... \n", | |
"99998 How many space missions were completed by each... \n", | |
"99999 Determine the number of unique animal species ... \n", | |
"\n", | |
" sql_context \\\n", | |
"0 CREATE TABLE salesperson (salesperson_id INT, ... \n", | |
"1 CREATE TABLE equipment_maintenance (equipment_... \n", | |
"2 CREATE TABLE marine_species (name VARCHAR(50),... \n", | |
"3 CREATE TABLE trade_history (id INT, trader_id ... \n", | |
"4 CREATE TABLE upgrades (id INT, cost FLOAT, typ... \n", | |
"... ... \n", | |
"99995 CREATE TABLE programs (program_id INT, num_vol... \n", | |
"99996 CREATE TABLE products (product_id INT, product... \n", | |
"99997 CREATE TABLE workout_sessions (id INT, user_id... \n", | |
"99998 CREATE TABLE SpaceMissions (id INT, astronaut_... \n", | |
"99999 CREATE TABLE animal_population (id INT, animal... \n", | |
"\n", | |
" sql \\\n", | |
"0 SELECT salesperson_id, name, SUM(volume) as to... \n", | |
"1 SELECT equipment_type, SUM(maintenance_frequen... \n", | |
"2 SELECT COUNT(*) FROM marine_species WHERE loca... \n", | |
"3 SELECT trader_id, stock, SUM(price * quantity)... \n", | |
"4 SELECT type, cost FROM (SELECT type, cost, ROW... \n", | |
"... ... \n", | |
"99995 SELECT program_id, (num_volunteers / total_par... \n", | |
"99996 SELECT COUNT(*) FROM products WHERE is_fair_tr... \n", | |
"99997 SELECT u.name, MAX(session_duration) as max_du... \n", | |
"99998 SELECT a.name, COUNT(sm.id) FROM Astronauts a ... \n", | |
"99999 SELECT COUNT(DISTINCT animal_species) AS uniqu... \n", | |
"\n", | |
" sql_explanation \n", | |
"0 Joins timber_sales and salesperson tables, gro... \n", | |
"1 This query groups the equipment_maintenance ta... \n", | |
"2 This query counts the number of marine species... \n", | |
"3 This query calculates the total trade value an... \n", | |
"4 The SQL query uses the ROW_NUMBER function to ... \n", | |
"... ... \n", | |
"99995 This query calculates the participation rate f... \n", | |
"99996 The query counts the number of fair-trade cert... \n", | |
"99997 The query joins the workout_sessions and users... \n", | |
"99998 This query calculates the number of space miss... \n", | |
"99999 This query calculates the number of unique ani... \n", | |
"\n", | |
"[100000 rows x 11 columns]>" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"\n", | |
"import pandas as pd\n", | |
"\n", | |
"splits = {'train': 'synthetic_text_to_sql_train.snappy.parquet', 'test': 'synthetic_text_to_sql_test.snappy.parquet'}\n", | |
"df = pd.read_parquet(\"hf://datasets/gretelai/synthetic_text_to_sql/\" + splits[\"train\"])\n", | |
"\n", | |
"df.head" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from transformers import pipeline\n", | |
"pipe = pipeline(\"feature-extraction\", model=\"microsoft/deberta-v3-base\") " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Embedding size: (1, 384)\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"\n", | |
"sql_query = \"SELECT name, age FROM users WHERE age > 30\"\n", | |
"embeddings = pipe(sql_query)\n", | |
"\n", | |
"embedding_size = np.array(embeddings).shape\n", | |
"print(f\"Embedding size: {embedding_size}\")" | |
] | |
}, | |
{ | |
"attachments": {}, | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Using FAISS as Vector database" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"top_sql_queries=0 SELECT salesperson_id, name, SUM(volume) as to...\n", | |
"1 SELECT equipment_type, SUM(maintenance_frequen...\n", | |
"2 SELECT COUNT(*) FROM marine_species WHERE loca...\n", | |
"3 SELECT trader_id, stock, SUM(price * quantity)...\n", | |
"4 SELECT type, cost FROM (SELECT type, cost, ROW...\n", | |
"5 SELECT SUM(spending) FROM defense.eu_humanitar...\n", | |
"6 SELECT SpeciesName, AVG(WaterTemp) as AvgTemp ...\n", | |
"7 DELETE FROM Program_Outcomes WHERE program_id ...\n", | |
"8 SELECT SUM(fare) FROM bus_routes WHERE route_n...\n", | |
"9 SELECT AVG(Property_Size) FROM Inclusive_Housi...\n", | |
"Name: sql, dtype: object, text_dict={0: 'SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;', 1: 'SELECT equipment_type, SUM(maintenance_frequency) AS total_maintenance_frequency FROM equipment_maintenance GROUP BY equipment_type;', 2: \"SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';\", 3: 'SELECT trader_id, stock, SUM(price * quantity) as total_trade_value, AVG(price) as avg_price FROM trade_history GROUP BY trader_id, stock;', 4: 'SELECT type, cost FROM (SELECT type, cost, ROW_NUMBER() OVER (ORDER BY cost DESC) as rn FROM upgrades) sub WHERE rn = 1;', 5: 'SELECT SUM(spending) FROM defense.eu_humanitarian_assistance WHERE year BETWEEN 2019 AND 2021;', 6: 'SELECT SpeciesName, AVG(WaterTemp) as AvgTemp FROM SpeciesWaterTemp INNER JOIN FishSpecies ON SpeciesWaterTemp.SpeciesID = FishSpecies.SpeciesID WHERE MONTH(Date) = 2 GROUP BY SpeciesName;', 7: 'DELETE FROM Program_Outcomes WHERE program_id = 1002;', 8: \"SELECT SUM(fare) FROM bus_routes WHERE route_name = 'Green Line';\", 9: \"SELECT AVG(Property_Size) FROM Inclusive_Housing WHERE Inclusive = 'Yes';\"}\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 0%| | 0/10 [00:00<?, ?it/s]\n" | |
] | |
}, | |
{ | |
"ename": "ValueError", | |
"evalue": "not enough values to unpack (expected 2, got 1)", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[34], line 22\u001b[0m\n\u001b[1;32m 19\u001b[0m embeddings \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(embeddings[\u001b[38;5;241m0\u001b[39m]) \u001b[38;5;66;03m# 获取第一个元素,即每个 token 的 embedding\u001b[39;00m\n\u001b[1;32m 20\u001b[0m sentence_embedding \u001b[38;5;241m=\u001b[39m embeddings\u001b[38;5;241m.\u001b[39mmean(axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m) \u001b[38;5;66;03m# 对所有 token 的 embedding 取平均\u001b[39;00m\n\u001b[0;32m---> 22\u001b[0m \u001b[43mindex\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexpand_dims\u001b[49m\u001b[43m(\u001b[49m\u001b[43msentence_embedding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mfloat32\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# 添加到 FAISS 索引中\u001b[39;00m\n", | |
"File \u001b[0;32m~/projj/github.com/RaoHai/sql_cluster/.venv/lib/python3.9/site-packages/faiss/class_wrappers.py:227\u001b[0m, in \u001b[0;36mhandle_Index.<locals>.replacement_add\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mreplacement_add\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m 215\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Adds vectors to the index.\u001b[39;00m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;124;03m The index must be trained before vectors can be added to it.\u001b[39;00m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;124;03m The vectors are implicitly numbered in sequence. When `n` vectors are\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[38;5;124;03m `dtype` must be float32.\u001b[39;00m\n\u001b[1;32m 225\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 227\u001b[0m n, d \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 228\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m d \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39md\n\u001b[1;32m 229\u001b[0m x \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mascontiguousarray(x, dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", | |
"\u001b[0;31mValueError\u001b[0m: not enough values to unpack (expected 2, got 1)" | |
] | |
} | |
], | |
"source": [ | |
"import faiss\n", | |
"import numpy as np\n", | |
"from tqdm import tqdm\n", | |
"\n", | |
"\n", | |
"d = 384 # 向量维度\n", | |
"\n", | |
"top_sql_queries = df.head(10)[\"sql\"] # 假设字段名称是 sql_query\n", | |
"\n", | |
"# 使用一个字典来存储原文与其对应的索引\n", | |
"text_dict = {i: top_sql_queries[i] for i in range(len(top_sql_queries))}\n", | |
"print(f\"top_sql_queries={top_sql_queries}, text_dict={text_dict}\")\n", | |
"\n", | |
"index = faiss.IndexFlatL2(d) # 使用 L2 距离\n", | |
"\n", | |
"for sql_query in tqdm(top_sql_queries):\n", | |
" embeddings = model.encode([sql_query])\n", | |
" # 转换嵌入为 NumPy 数组并计算平均值(句子的整体 embedding)\n", | |
" embeddings = np.array(embeddings[0]) # 获取第一个元素,即每个 token 的 embedding\n", | |
" sentence_embedding = embeddings.mean(axis=0) # 对所有 token 的 embedding 取平均\n", | |
" \n", | |
" index.add(np.expand_dims(sentence_embedding, axis=0).astype('float32')) # 添加到 FAISS 索引中" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's try to retrive" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Shape of embeddings: (1, 125, 768)\n", | |
"Shape of sentence embedding: (1, 768)\n", | |
"Shape of query vector for FAISS: (1, 768)\n", | |
"query=\n", | |
" SELECT SUM(volume) AS total_volume,\n", | |
" salesperson_id,\n", | |
" name\n", | |
" FROM timber_sales\n", | |
" JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id\n", | |
" GROUP BY salesperson_id,\n", | |
" name\n", | |
" ORDER BY total_volume DESC;\n", | |
"\n", | |
"SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;\n", | |
"SELECT trader_id, stock, SUM(price * quantity) as total_trade_value, AVG(price) as avg_price FROM trade_history GROUP BY trader_id, stock;\n", | |
"SELECT SpeciesName, AVG(WaterTemp) as AvgTemp FROM SpeciesWaterTemp INNER JOIN FishSpecies ON SpeciesWaterTemp.SpeciesID = FishSpecies.SpeciesID WHERE MONTH(Date) = 2 GROUP BY SpeciesName;\n", | |
"SELECT equipment_type, SUM(maintenance_frequency) AS total_maintenance_frequency FROM equipment_maintenance GROUP BY equipment_type;\n", | |
"SELECT SUM(spending) FROM defense.eu_humanitarian_assistance WHERE year BETWEEN 2019 AND 2021;\n" | |
] | |
} | |
], | |
"source": [ | |
"# 假设你已经加载了数据并取得了前 1 条 SQL 查询\n", | |
"xq = \"\"\"\n", | |
" SELECT SUM(volume) AS total_volume,\n", | |
" salesperson_id,\n", | |
" name\n", | |
" FROM timber_sales\n", | |
" JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id\n", | |
" GROUP BY salesperson_id,\n", | |
" name\n", | |
" ORDER BY total_volume DESC;\n", | |
"\"\"\"\n", | |
"\n", | |
"# 计算 SQL 查询的 embeddings\n", | |
"xq_embeddings = pipe(xq)\n", | |
"xq_embeddings = np.array(xq_embeddings)\n", | |
"\n", | |
"# 确保 embeddings 的形状为 (1, num_tokens, embedding_dim)\n", | |
"print(f\"Shape of embeddings: {xq_embeddings.shape}\") # 形状应该是 (1, 63, 768)\n", | |
"\n", | |
"# 对所有 token embeddings 取平均,得到句子的 embedding\n", | |
"xq_sentence_embedding = xq_embeddings.mean(axis=1) # 对维度 1 取平均,得到 (1, 768)\n", | |
"\n", | |
"# 确认句子的 embedding 形状,应该是 (768,)\n", | |
"print(f\"Shape of sentence embedding: {xq_sentence_embedding.shape}\") # 应该是 (768,)\n", | |
"\n", | |
"# 将句子的 embedding 转换为 (1, 768) 形式\n", | |
"xq_sentence_vector = xq_sentence_embedding.astype('float32') # 不需要 expand_dims,因为它已经是二维的\n", | |
"\n", | |
"# 确保索引的维度与 embedding 的维度匹配\n", | |
"print(f\"Shape of query vector for FAISS: {xq_sentence_vector.shape}\") # 应该是 (1, 768)\n", | |
"\n", | |
"# 查询索引,返回最相似的 5 个向量\n", | |
"D, I = index.search(xq_sentence_vector, 5)\n", | |
"\n", | |
"print(f\"query={xq}\")\n", | |
"# 打印最相似的向量的原文\n", | |
"for idx in I[0]:\n", | |
" print(f\"{text_dict[idx]}\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Now Try Clustering" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 10000/10000 [03:43<00:00, 44.70it/s]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Clustering 10000 points in 768D to 10 clusters, redo 1 times, 20 iterations\n", | |
" Preprocessing in 0.00 s\n", | |
"centroids shape before reshape: (7680,) objective=57242.6 imbalance=1.146 nsplit=0 \n", | |
"聚类中心:\n", | |
"[[-0.32826853 0.07509142 0.19165924 ... -0.49869946 -0.44887027\n", | |
" 0.5077467 ]\n", | |
" [-0.28716213 0.08168618 0.12497564 ... -0.48048055 -0.38728404\n", | |
" 0.42993832]\n", | |
" [-0.2827714 0.1068728 0.19573389 ... -0.3864083 -0.4861318\n", | |
" 0.46620542]\n", | |
" ...\n", | |
" [-0.3292238 0.0759559 0.17924066 ... -0.49496964 -0.45242766\n", | |
" 0.44715384]\n", | |
" [-0.35065335 0.20159101 0.2598137 ... -0.6274775 -0.46147823\n", | |
" 0.5497009 ]\n", | |
" [-0.3542094 0.14758337 0.25081795 ... -0.59235317 -0.4609816\n", | |
" 0.47140077]]\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"import faiss\n", | |
"import numpy as np\n", | |
"from tqdm import tqdm\n", | |
"\n", | |
"\n", | |
"# 设置聚类的簇数\n", | |
"k = 10 # 假设我们想要10个簇\n", | |
"d = 768 # 向量维度\n", | |
"\n", | |
"top_sql_queries = df.head(10000)[\"sql\"] # 假设字段名称是 sql_query\n", | |
"\n", | |
"clustering = faiss.Clustering(d, k)\n", | |
"\n", | |
"# 设置聚类的参数(例如,最大迭代次数)\n", | |
"clustering.niter = 20 # 设置最大迭代次数\n", | |
"clustering.max_points_per_centroid = 1000 # 每个质心的最大点数\n", | |
"clustering.verbose = True # 输出详细信息\n", | |
"\n", | |
"# 使用一个字典来存储原文与其对应的索引\n", | |
"text_dict = {i: top_sql_queries[i] for i in range(len(top_sql_queries))}\n", | |
"\n", | |
"index = faiss.IndexHNSWFlat(d, 32) # 使用 IndexFlatIP 距离\n", | |
"\n", | |
"vectors = []\n", | |
"for sql_query in tqdm(top_sql_queries):\n", | |
" embeddings = pipe(sql_query)\n", | |
" # 转换嵌入为 NumPy 数组并计算平均值(句子的整体 embedding)\n", | |
" embeddings = np.array(embeddings[0]) # 获取第一个元素,即每个 token 的 embedding\n", | |
" sentence_embedding = embeddings.mean(axis=0) # 对所有 token 的 embedding 取平均\n", | |
" \n", | |
" vector = np.expand_dims(sentence_embedding, axis=0).astype('float32')\n", | |
" vectors.append(vector)\n", | |
"\n", | |
"# 将所有向量堆叠成一个大的数组\n", | |
"vectors = np.vstack(vectors) # (num_samples, d)\n", | |
"\n", | |
"# 执行聚类训练\n", | |
"clustering.train(vectors, index) # 进行聚类训练\n", | |
"\n", | |
"# 获取聚类中心\n", | |
"centroids = faiss.vector_float_to_array(clustering.centroids) # 获取聚类中心\n", | |
"print(f\"centroids shape before reshape: {centroids.shape}\")\n", | |
"\n", | |
"# 检查是否可以 reshape\n", | |
"if centroids.shape[0] == k * d:\n", | |
" centroids = centroids.reshape(k, d) # reshape 成 k 个簇中心,每个簇中心是 d 维\n", | |
" print(f\"聚类中心:\\n{centroids}\")\n", | |
"else:\n", | |
" print(f\"聚类中心的形状不符合预期,实际大小为 {centroids.shape}\")\n", | |
"\n", | |
"# 添加向量到索引\n", | |
"index.add(vectors) # 聚类后添加向量\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"簇 0 的相邻 SQL 查询:\n", | |
"\tSELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC; (距离: 0.0)\n", | |
"\tSELECT (COUNT(CASE WHEN gender = 'male' THEN 1 END) / COUNT(*) * 100) AS male_percentage, (COUNT(CASE WHEN gender = 'female' THEN 1 END) / COUNT(*) * 100) AS female_percentage FROM employee; (距离: 2.289557695388794)\n", | |
"\tSELECT COUNT(*) FROM customer_complaints WHERE complaint_date >= CURDATE() - INTERVAL 60 DAY; (距离: 2.2937240600585938)\n", | |
"\tSELECT COUNT(*) FROM posts WHERE brand_mentioned = 'Apple' AND industry = 'technology' AND country = 'Japan' AND post_time > DATE_SUB(NOW(), INTERVAL 1 WEEK); (距离: 2.325535297393799)\n", | |
"\tSELECT AVG(ts.ticket_price) as avg_ticket_price FROM ticket_sales ts WHERE ts.team_id = 2; (距离: 2.5796945095062256)\n", | |
"\n", | |
"\n", | |
"簇 1 的相邻 SQL 查询:\n", | |
"\tSELECT equipment_type, SUM(maintenance_frequency) AS total_maintenance_frequency FROM equipment_maintenance GROUP BY equipment_type; (距离: 0.0)\n", | |
"\tSELECT Metric, MIN(Value) as MinValue FROM HealthEquityMetrics GROUP BY Metric; (距离: 3.417145013809204)\n", | |
"\tSELECT type, cost FROM (SELECT type, cost, ROW_NUMBER() OVER (ORDER BY cost DESC) as rn FROM upgrades) sub WHERE rn = 1; (距离: 3.496406078338623)\n", | |
"\tSELECT company_name, COUNT(*) FROM veteran_job_applications WHERE application_date BETWEEN '2019-07-01' AND '2019-09-30' GROUP BY company_name; (距离: 3.645045518875122)\n", | |
"\tSELECT cf.country, SUM(cf.amount) FROM climate_finance cf WHERE cf.sector = 'mitigation' AND cf.region = 'Asia-Pacific' GROUP BY cf.country ORDER BY SUM(cf.amount) DESC LIMIT 5; (距离: 3.786212921142578)\n", | |
"\n", | |
"\n", | |
"簇 2 的相邻 SQL 查询:\n", | |
"\tSELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean'; (距离: 0.0)\n", | |
"\tSELECT cost FROM ProjectsNY WHERE state = 'New York' UNION SELECT cost FROM ProjectsNJ WHERE state = 'New Jersey' (距离: 1.843540906906128)\n", | |
"\tSELECT type, COUNT(*) as total, AVG(DATEDIFF(end_date, start_date)) as avg_duration FROM projects WHERE type IN ('Wind Farm', 'Hydro Plant') GROUP BY type; (距离: 2.3331053256988525)\n", | |
"\tSELECT AVG(patients.age) FROM patients INNER JOIN support_groups ON patients.id = support_groups.patient_id WHERE patients.condition = 'Schizophrenia' AND patients.country = 'Europe'; (距离: 2.3665666580200195)\n", | |
"\tINSERT INTO autonomous_testing (id, location, company, date, miles_driven) VALUES (1, 'San Francisco', 'Waymo', '2022-05-10', 500), (2, 'San Francisco', 'Cruise', '2022-05-11', 600); (距离: 2.4107885360717773)\n", | |
"\n", | |
"\n", | |
"簇 3 的相邻 SQL 查询:\n", | |
"\tSELECT trader_id, stock, SUM(price * quantity) as total_trade_value, AVG(price) as avg_price FROM trade_history GROUP BY trader_id, stock; (距离: 0.0)\n", | |
"\tSELECT SUM(value) FROM MilitaryEquipmentSales WHERE year = 2023 AND country LIKE 'Southeast%'; (距离: 1.9893628358840942)\n", | |
"\tSELECT region_sold, COUNT(DISTINCT Products.product_id) as product_count FROM Sales INNER JOIN Products ON Sales.product_id = Products.product_id GROUP BY region_sold; (距离: 1.993025541305542)\n", | |
"\tSELECT SUM(hours_flown) FROM aircraft_and_flight_hours WHERE units_manufactured > 500; (距离: 2.0590567588806152)\n", | |
"\tSELECT region, num_stalls, MAX(id) FROM charging_stations GROUP BY region, num_stalls; (距离: 2.1622533798217773)\n", | |
"\n", | |
"\n", | |
"簇 4 的相邻 SQL 查询:\n", | |
"\tSELECT type, cost FROM (SELECT type, cost, ROW_NUMBER() OVER (ORDER BY cost DESC) as rn FROM upgrades) sub WHERE rn = 1; (距离: 0.0)\n", | |
"\tSELECT DATE_PART('month', e.HireDate) as Month, COUNT(DISTINCT EmployeeID) as NumberHired FROM Employees e WHERE e.HireDate >= (CURRENT_DATE - INTERVAL '1 year') GROUP BY Month ORDER BY Month; (距离: 2.1936867237091064)\n", | |
"\tSELECT AVG(Age) FROM Employees WHERE LeadershipTraining = TRUE; (距离: 2.1996164321899414)\n", | |
"\tSELECT SUM(CargoShips.capacity) FROM CargoShips INNER JOIN Fleet ON CargoShips.fleet_id = Fleet.id; (距离: 2.267169952392578)\n", | |
"\tSELECT COUNT(*) FROM hospitals WHERE state = 'California'; (距离: 2.420313835144043)\n", | |
"\n", | |
"\n", | |
"簇 5 的相邻 SQL 查询:\n", | |
"\tSELECT SUM(spending) FROM defense.eu_humanitarian_assistance WHERE year BETWEEN 2019 AND 2021; (距离: 0.0)\n", | |
"\tSELECT SUM(Cost) FROM Training WHERE TrainingDate >= DATE_SUB(CURDATE(), INTERVAL 1 YEAR) AND Community IN ('Women in Tech', 'LGBTQ+', 'Minorities in STEM'); (距离: 1.9522778987884521)\n", | |
"\tSELECT country_code, project_type, COUNT(*) AS num_projects, SUM(completed) AS num_completed FROM project WHERE project_end_date >= (CURRENT_DATE - INTERVAL '5 years') GROUP BY country_code, project_type; (距离: 1.9620110988616943)\n", | |
"\tUPDATE brand_material_ratings SET rating = 5 WHERE brand = 'Brand C' AND material = 'recycled polyester'; (距离: 2.040379047393799)\n", | |
"\tSELECT vehicle_name, safety_rating FROM (SELECT vehicle_name, safety_rating, RANK() OVER (ORDER BY safety_rating DESC) as safety_rank FROM auto_show WHERE vehicle_name LIKE '%Autonomous%') AS auton_ranks WHERE safety_rank = 1; (距离: 2.0645792484283447)\n", | |
"\n", | |
"\n", | |
"簇 6 的相邻 SQL 查询:\n", | |
"\tSELECT SpeciesName, AVG(WaterTemp) as AvgTemp FROM SpeciesWaterTemp INNER JOIN FishSpecies ON SpeciesWaterTemp.SpeciesID = FishSpecies.SpeciesID WHERE MONTH(Date) = 2 GROUP BY SpeciesName; (距离: 0.0)\n", | |
"\tSELECT DATE_FORMAT(therapy_date, '%Y-%m') as quarter, COUNT(*) as count FROM therapy_sessions WHERE therapy_date BETWEEN DATE_SUB(CURRENT_DATE, INTERVAL 1 YEAR) AND CURRENT_DATE GROUP BY quarter; (距离: 1.6734617948532104)\n", | |
"\tSELECT COUNT(*) FROM biotech_startups WHERE location = 'USA' AND quarter IN (1, 2) AND year = 2022; (距离: 1.7494053840637207)\n", | |
"\tSELECT COUNT(*) FROM ( (SELECT * FROM mining_company) UNION ALL (SELECT * FROM drilling_company) ) AS total; (距离: 1.757635235786438)\n", | |
"\tSELECT state, MIN(mental_health_score) FROM students GROUP BY state; (距离: 1.7892229557037354)\n", | |
"\n", | |
"\n", | |
"簇 7 的相邻 SQL 查询:\n", | |
"\tDELETE FROM Program_Outcomes WHERE program_id = 1002; (距离: 0.0)\n", | |
"\tSELECT MAX(price) FROM thulium_prices WHERE country = 'Indonesia' AND year BETWEEN 2017 AND 2021; (距离: 2.253490924835205)\n", | |
"\tSELECT route_id, MIN(DATEDIFF('day', LAG(clean_date) OVER (PARTITION BY route_id ORDER BY clean_date), clean_date)) FROM buses GROUP BY route_id; (距离: 2.3461976051330566)\n", | |
"\t索引 10009 对应的原文不存在!\n", | |
"\tSELECT id, SUM(balance) OVER (ORDER BY id) FROM customers; (距离: 2.4693894386291504)\n", | |
"\n", | |
"\n", | |
"簇 8 的相邻 SQL 查询:\n", | |
"\tSELECT SUM(fare) FROM bus_routes WHERE route_name = 'Green Line'; (距离: 0.0)\n", | |
"\tSELECT city, AVG(price) FROM tickets t JOIN teams te ON t.team_id = te.team_id GROUP BY city; (距离: 3.078681707382202)\n", | |
"\tSELECT name FROM (SELECT name, sustainability_score, ROW_NUMBER() OVER (ORDER BY sustainability_score DESC) as rn FROM Consumers WHERE year = 2020) t WHERE rn <= 5; (距离: 3.140638828277588)\n", | |
"\tUPDATE ElectricVehicles SET Horsepower = 160 WHERE Make = 'Nissan' AND Model = 'Leaf'; (距离: 3.253568649291992)\n", | |
"\tSELECT COUNT(*) FROM Events WHERE city IN ('Los Angeles', 'San Francisco') AND type = 'Art Exhibition'; (距离: 3.3343262672424316)\n", | |
"\n", | |
"\n", | |
"簇 9 的相邻 SQL 查询:\n", | |
"\tSELECT AVG(Property_Size) FROM Inclusive_Housing WHERE Inclusive = 'Yes'; (距离: 0.0)\n", | |
"\tSELECT State, (SUM(Obese) / SUM(Population)) * 100 AS ObesityRate FROM MexicanStates GROUP BY State; (距离: 2.0751421451568604)\n", | |
"\tSELECT DATEPART(QUARTER, ApplicationDate) AS Quarter, Source, COUNT(*) FROM JobApplications WHERE ApplicationDate >= DATEADD(QUARTER, -1, GETDATE()) GROUP BY Source, DATEPART(QUARTER, ApplicationDate); (距离: 2.201944351196289)\n", | |
"\tSELECT Author, COUNT(*) AS Total_Papers FROM Research_Papers WHERE Autonomous_Driving_Research = TRUE GROUP BY Author; (距离: 2.266307830810547)\n", | |
"\tSELECT MAX(score) FROM financial_wellbeing WHERE country = 'USA'; (距离: 2.3806111812591553)\n", | |
"\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# 获取每个簇的相邻 SQL 查询\n", | |
"k_nearest_neighbors = 5 # 每个簇返回 5 个最相似的向量(即 SQL 查询)\n", | |
"for i in range(k):\n", | |
" # 找到与第 i 个簇中心最接近的向量\n", | |
" centroid_vector = np.expand_dims(centroids[i], axis=0).astype('float32') # 当前簇中心\n", | |
" distances, indices = index.search(centroid_vector, k_nearest_neighbors) # 查找最近的k个向量\n", | |
" \n", | |
" print(f\"簇 {i} 的相邻 SQL 查询:\")\n", | |
" for j in range(k_nearest_neighbors):\n", | |
" closest_index = indices[0][j] # 获取最近的向量索引\n", | |
" \n", | |
" # 检查索引是否在 text_dict 中\n", | |
" if closest_index in text_dict:\n", | |
" closest_text = text_dict[closest_index] # 获取该向量的原文\n", | |
" print(f\"\\t{closest_text} (距离: {distances[0][j]})\")\n", | |
" else:\n", | |
" print(f\"\\t索引 {closest_index} 对应的原文不存在!\")\n", | |
" print(\"\\n\")\n" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".venv", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.6" | |
}, | |
"orig_nbformat": 4 | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment