@aneesha
Last active August 9, 2023 00:48
Semantic Search with Sentence-BERT
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "SiameseBERT-SemanticSearch.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "liIk8f980xT5",
"colab_type": "text"
},
"source": [
"# Sentence Embeddings using Siamese BERT-Networks\n",
"---\n",
"This Google Colab notebook shows how to use the Sentence Transformers Python library to create BERT sentence embeddings and run a semantic search over them.\n",
"\n",
"The Sentence Transformers library is available on [PyPI](https://pypi.org/project/sentence-transformers/) and [GitHub](https://github.com/UKPLab/sentence-transformers). It implements the method from the EMNLP 2019 paper \"[Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://www.aclweb.org/anthology/D19-1410.pdf)\" by Nils Reimers and Iryna Gurevych.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bzljhyTQEZds",
"colab_type": "text"
},
"source": [
"## Install Sentence Transformer Library"
]
},
{
"cell_type": "code",
"metadata": {
"id": "AmxRYxNDvn6y",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 768
},
"outputId": "81694439-a0c9-45f1-88e6-ea77aa94f39a"
},
"source": [
"# Install the library using pip\n",
"!pip install sentence-transformers"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting sentence-transformers\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/c9/91/c85ddef872d5bb39949386930c1f834ac382e145fcd30155b09d6fb65c5a/sentence-transformers-0.2.5.tar.gz (49kB)\n",
"\u001b[K |████████████████████████████████| 51kB 1.7MB/s \n",
"\u001b[?25hCollecting transformers==2.3.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/50/10/aeefced99c8a59d828a92cc11d213e2743212d3641c87c82d61b035a7d5c/transformers-2.3.0-py3-none-any.whl (447kB)\n",
"\u001b[K |████████████████████████████████| 450kB 7.2MB/s \n",
"\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from sentence-transformers) (4.28.1)\n",
"Requirement already satisfied: torch>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from sentence-transformers) (1.3.1)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from sentence-transformers) (1.17.5)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from sentence-transformers) (0.22.1)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from sentence-transformers) (1.4.1)\n",
"Requirement already satisfied: nltk in /usr/local/lib/python3.6/dist-packages (from sentence-transformers) (3.2.5)\n",
"Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers==2.3.0->sentence-transformers) (1.10.47)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers==2.3.0->sentence-transformers) (2019.12.20)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers==2.3.0->sentence-transformers) (2.21.0)\n",
"Collecting sacremoses\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a6/b4/7a41d630547a4afd58143597d5a49e07bfd4c42914d8335b2a5657efc14b/sacremoses-0.0.38.tar.gz (860kB)\n",
"\u001b[K |████████████████████████████████| 870kB 26.9MB/s \n",
"\u001b[?25hCollecting sentencepiece\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n",
"\u001b[K |████████████████████████████████| 1.0MB 42.6MB/s \n",
"\u001b[?25hRequirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->sentence-transformers) (0.14.1)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from nltk->sentence-transformers) (1.12.0)\n",
"Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers==2.3.0->sentence-transformers) (0.9.4)\n",
"Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers==2.3.0->sentence-transformers) (0.2.1)\n",
"Requirement already satisfied: botocore<1.14.0,>=1.13.47 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers==2.3.0->sentence-transformers) (1.13.47)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.3.0->sentence-transformers) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.3.0->sentence-transformers) (2019.11.28)\n",
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.3.0->sentence-transformers) (1.24.3)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==2.3.0->sentence-transformers) (2.8)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==2.3.0->sentence-transformers) (7.0)\n",
"Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.47->boto3->transformers==2.3.0->sentence-transformers) (0.15.2)\n",
"Requirement already satisfied: python-dateutil<3.0.0,>=2.1; python_version >= \"2.7\" in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.47->boto3->transformers==2.3.0->sentence-transformers) (2.6.1)\n",
"Building wheels for collected packages: sentence-transformers, sacremoses\n",
" Building wheel for sentence-transformers (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for sentence-transformers: filename=sentence_transformers-0.2.5-cp36-none-any.whl size=64943 sha256=973f7a74060aad6ed6f617b3879e3340babb4d0ecd2d47508f1538ef02d1a587\n",
" Stored in directory: /root/.cache/pip/wheels/b4/ce/39/5bbda8ac34eb52df8c6531382ca077773fbfcbfb6386e5d66c\n",
" Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for sacremoses: filename=sacremoses-0.0.38-cp36-none-any.whl size=884629 sha256=63585f2eb5e65adb473e2e282bda1bdc423a86bfde2a175308c56709911f6678\n",
" Stored in directory: /root/.cache/pip/wheels/6d/ec/1a/21b8912e35e02741306f35f66c785f3afe94de754a0eaf1422\n",
"Successfully built sentence-transformers sacremoses\n",
"Installing collected packages: sacremoses, sentencepiece, transformers, sentence-transformers\n",
"Successfully installed sacremoses-0.0.38 sentence-transformers-0.2.5 sentencepiece-0.1.85 transformers-2.3.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eIAKz6KVEndZ",
"colab_type": "text"
},
"source": [
"## Load the BERT Model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "5IO_j2Ofv5pq",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 80
},
"outputId": "717b9097-4c19-49b0-97cb-b8c31c84d3ec"
},
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
"# Load a pretrained Sentence-BERT model. Models trained on Natural Language Inference (NLI) are listed at\n",
"# https://github.com/UKPLab/sentence-transformers/blob/master/docs/pretrained-models/nli-models.md and models trained on\n",
"# Semantic Textual Similarity (STS) at https://github.com/UKPLab/sentence-transformers/blob/master/docs/pretrained-models/sts-models.md\n",
"\n",
"model = SentenceTransformer('bert-base-nli-mean-tokens')"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"100%|██████████| 405M/405M [00:18<00:00, 21.8MB/s]\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I7_Ib3ITEwgO",
"colab_type": "text"
},
"source": [
"## Set Up a Corpus"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rEa50nefv7HY",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "be7fcea8-15ac-4ab6-ebdf-55e507a7d496"
},
"source": [
"# The corpus is a list of sentences; longer documents would first be split into sentences.\n",
"\n",
"sentences = ['Absence of sanity',\n",
"             'Lack of saneness',\n",
"             'A man is eating food.',\n",
"             'A man is eating a piece of bread.',\n",
"             'The girl is carrying a baby.',\n",
"             'A man is riding a horse.',\n",
"             'A woman is playing violin.',\n",
"             'Two men pushed carts through the woods.',\n",
"             'A man is riding a white horse on an enclosed ground.',\n",
"             'A monkey is playing drums.',\n",
"             'A cheetah is running behind its prey.']\n",
"\n",
"# Each sentence is encoded as a 1-D vector with 768 dimensions\n",
"sentence_embeddings = model.encode(sentences)\n",
"\n",
"print('Sample BERT embedding vector - length', len(sentence_embeddings[0]))\n",
"\n",
"print('Sample BERT embedding vector - note includes negative values', sentence_embeddings[0])"
],
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"text": [
"Sample BERT embedding vector - length 768\n",
"Sample BERT embedding vector - note includes negative values [ 2.95403183e-01 2.91810960e-01 2.16480136e+00 2.20419839e-01\n",
" -1.30865965e-02 1.01950336e+00 1.51298153e+00 2.34131858e-01\n",
" 2.73058057e-01 1.35122865e-01 -1.11313331e+00 -1.25884697e-01\n",
" 1.45378590e-01 9.77708638e-01 1.39352250e+00 4.57705110e-01\n",
" -5.82131505e-01 -7.24940956e-01 -3.61734480e-01 -2.27514893e-01\n",
" 1.66630689e-02 2.04861969e-01 6.55133009e-01 -1.29376340e+00\n",
" -7.26099372e-01 -1.91135541e-01 -3.07211697e-01 -1.30278623e+00\n",
" -1.42963886e+00 5.67453261e-03 3.54811519e-01 4.83712614e-01\n",
" 6.65388048e-01 5.33848584e-01 6.40496492e-01 5.90408444e-01\n",
" 7.83850178e-02 -1.07759213e+00 -1.24676526e-01 -3.98406386e-01\n",
" 7.36313939e-01 5.28293252e-01 5.63290775e-01 4.14546162e-01\n",
" 4.49179649e-01 -9.58786383e-02 1.45424592e+00 -2.69145072e-01\n",
" -2.44059876e-01 -1.10387051e+00 -2.00923532e-01 -2.17437861e-03\n",
" 1.83387971e+00 1.06518459e+00 -5.11945903e-01 -1.11248529e+00\n",
" 5.59790373e-01 -5.89609563e-01 1.07621944e+00 7.49265671e-01\n",
" 4.32666481e-01 1.76307142e-01 -1.72125660e-02 1.19170345e-01\n",
" -8.37448120e-01 1.88445479e-01 -7.46789854e-03 1.70312412e-02\n",
" -4.29175943e-01 -5.72340369e-01 6.22608125e-01 -3.38023841e-01\n",
" 1.03712715e-01 2.42106587e-01 -9.82075334e-01 -2.63344377e-01\n",
" 4.15540874e-01 2.79155314e-01 -1.03487372e-01 7.32837498e-01\n",
" 2.78023511e-01 7.52160132e-01 8.08235168e-01 -4.87504005e-01\n",
" -7.64050961e-01 -7.08029419e-02 5.39888680e-01 2.24032193e-01\n",
" -1.66969836e+00 -2.01102495e-01 2.13853925e-01 1.08926618e+00\n",
" 5.79520047e-01 -3.32589626e-01 -9.75302875e-01 2.37982087e-02\n",
" -1.68780282e-01 2.63479292e-01 1.64727420e-01 2.03110695e-01\n",
" 8.12745094e-03 3.75387013e-01 4.53516781e-01 4.21480648e-02\n",
" 4.72985119e-01 3.53581369e-01 -1.55092627e-01 3.71253103e-01\n",
" 3.54824960e-01 5.11555187e-02 -9.36767399e-01 1.42013264e+00\n",
" 7.63064623e-01 -8.58592212e-01 -5.74659705e-01 2.43800841e-02\n",
" -1.39292669e+00 3.14162791e-01 -2.08897144e-02 4.01286274e-01\n",
" 8.56514335e-01 2.12013125e-01 1.93166375e-01 -8.65448266e-02\n",
" 3.66492569e-01 -5.34847438e-01 9.08364177e-01 -1.47459909e-01\n",
" -4.34119612e-01 1.38952568e-01 -2.85755783e-01 8.91503990e-01\n",
" -9.49378788e-01 -6.12362102e-02 -1.63513094e-01 2.15337463e-02\n",
" 9.33672637e-02 -1.83846831e-01 -7.05079511e-02 -9.50030535e-02\n",
" -7.78035462e-01 5.08425236e-01 -2.38957599e-01 1.18628956e-01\n",
" -2.22939938e-01 -4.55788314e-01 7.77875900e-01 5.58182418e-01\n",
" -6.46435559e-01 -3.76423895e-01 -9.94759917e-01 -2.24633545e-01\n",
" 2.56237477e-01 3.92683744e-01 -3.80205333e-01 -5.66550314e-01\n",
" 1.06151199e+00 6.39036715e-01 2.45125726e-01 -2.10276559e-01\n",
" -2.03608587e-01 5.58223248e-01 5.87122440e-02 4.09846961e-01\n",
" 9.85415801e-02 5.68384044e-02 -3.91428530e-01 2.61983126e-01\n",
" 1.85061485e-01 -7.71071732e-01 6.81889176e-01 -6.14751518e-01\n",
" -9.09989655e-01 -5.37011623e-01 -2.67421097e-01 6.41145855e-02\n",
" -1.47129208e-01 8.36788297e-01 2.73965627e-01 2.99047649e-01\n",
" 1.75318837e-01 -8.45175162e-02 1.32096362e+00 -1.32269931e+00\n",
" -8.96274865e-01 -7.21393704e-01 9.03015211e-02 -5.28525472e-01\n",
" -5.90585284e-02 -1.73461944e-01 -1.08447266e+00 -9.68358815e-01\n",
" 7.07607031e-01 -1.08554244e+00 7.90785030e-02 -1.23736322e-01\n",
" 1.04828671e-01 -4.38560009e-01 -9.64857712e-02 1.21586382e-01\n",
" 7.08318591e-01 -7.40591824e-01 -6.13734312e-02 5.97689033e-01\n",
" 5.36595702e-01 9.54430461e-01 5.82716949e-02 -7.02314198e-01\n",
" -5.21191716e-01 -5.31055272e-01 -6.87113464e-01 -2.28840038e-01\n",
" -1.25095084e-01 3.82895499e-01 9.40875232e-01 -1.53620028e+00\n",
" 3.08551461e-01 -1.07018912e+00 -9.22982693e-02 1.83376744e-01\n",
" 2.36132234e-01 -9.44777846e-01 3.63419414e-01 1.02015033e-01\n",
" -2.61932522e-01 1.31475377e+00 -8.66960660e-02 4.19597805e-01\n",
" -5.64485192e-01 -2.40892217e-01 9.56280604e-02 2.88403451e-01\n",
" -1.26559651e+00 9.68216509e-02 -6.86096549e-01 -9.73237216e-01\n",
" 4.98169512e-01 -6.72291875e-01 6.18023753e-01 -2.36299157e-01\n",
" -1.19642895e-02 2.23920777e-01 1.29233646e+00 9.04313385e-01\n",
" -6.94556832e-01 2.12914258e-01 1.41126677e-01 -1.09073651e+00\n",
" -2.61027157e-01 3.72416407e-01 5.65627873e-01 -7.63028026e-01\n",
" 2.54305422e-01 2.96842039e-01 6.39281869e-01 -3.68397057e-01\n",
" -1.50474444e-01 -2.06995815e-01 4.89718974e-01 -1.07649386e+00\n",
" -4.69321489e-01 5.64345777e-01 -8.54324996e-01 2.01307058e-01\n",
" -3.64126265e-01 -4.83238310e-01 -2.99138606e-01 -2.37890765e-01\n",
" -7.14595199e-01 3.88846397e-02 2.88149297e-01 -8.43381286e-01\n",
" -1.32528901e-01 3.20764244e-01 -5.64771175e-01 -6.50783896e-01\n",
" 1.23112619e+00 -3.79267991e-01 -5.36719501e-01 -1.18963256e-01\n",
" 5.57513535e-01 7.38234401e-01 -1.65188718e+00 4.81852919e-01\n",
" -9.48715389e-01 9.54266250e-01 6.12537324e-01 2.52186596e-01\n",
" -2.26502508e-01 -1.06311001e-01 -6.21922791e-01 7.16752052e-01\n",
" -3.86202455e-01 -7.54139066e-01 5.86001515e-01 -2.63744980e-01\n",
" 2.03416556e-01 -5.31250715e-01 -9.58163381e-01 7.24046588e-01\n",
" -6.24141991e-01 -1.32546872e-01 -7.14519739e-01 2.18565226e-01\n",
" -6.77284122e-01 -1.44099727e-01 -3.61371309e-01 8.55807483e-01\n",
" 4.30027187e-01 5.20425677e-01 -1.25135791e+00 8.63247830e-03\n",
" -8.13418508e-01 -1.89723641e-01 6.97539449e-01 -2.71261573e-01\n",
" -9.60209668e-01 -5.67387223e-01 -1.97102293e-01 -8.45597088e-01\n",
" 4.09596503e-01 4.65882063e-01 1.00616157e-01 1.79212242e-01\n",
" 3.18223983e-01 -1.94156840e-01 -2.31451839e-02 5.31907439e-01\n",
" 2.07857490e-01 -1.38748586e-01 -1.25763610e-01 -1.67028809e+00\n",
" 7.72697479e-02 3.29649597e-01 9.59621310e-01 9.71129060e-01\n",
" -8.59887302e-01 -7.97947407e-01 -4.50175434e-01 6.09351099e-01\n",
" -9.92598385e-02 -1.26160526e+00 -4.80852902e-01 -4.37261537e-02\n",
" 7.61089265e-01 -4.90454495e-01 -1.16635013e+00 -1.26680553e+00\n",
" 4.23249960e-01 -3.28321755e-01 1.13249719e-01 1.19094634e+00\n",
" -4.76775408e-01 2.77292818e-01 -6.62052512e-01 -6.26967430e-01\n",
" -3.28043163e-01 4.62100357e-01 4.22713980e-02 1.11225937e-02\n",
" 3.49018514e-01 3.41067076e-01 -2.20341116e-01 -8.18507969e-01\n",
" 6.06628835e-01 -6.28642023e-01 2.82574117e-01 4.81881082e-01\n",
" 1.42555833e+00 -1.52616878e-03 -5.17759204e-01 -2.46807709e-02\n",
" -5.39778233e-01 2.85306722e-01 -2.21777245e-01 -1.85012728e-01\n",
" 3.11429322e-01 -2.07644105e-01 -2.50214398e-01 -9.53149140e-01\n",
" 4.12619442e-01 3.92308757e-02 1.41120523e-01 -3.55187565e-01\n",
" 1.18068099e+00 -3.14618677e-01 4.79557097e-01 -4.70073700e-01\n",
" 1.77134991e-01 6.75846696e-01 9.84487832e-01 2.68136203e-01\n",
" -7.01508764e-03 -2.85394371e-01 -1.52261570e-01 -1.11626387e+00\n",
" 1.02338985e-01 5.77064395e-01 3.41381252e-01 4.19948399e-01\n",
" 4.03284848e-01 2.95105875e-01 9.24367547e-01 1.26249242e+00\n",
" 1.20447159e+00 9.47556794e-01 4.87714596e-02 4.61912721e-01\n",
" -8.18414986e-01 4.46733892e-01 6.60865068e-01 4.44435418e-01\n",
" -9.46657002e-01 3.16685140e-01 -5.51689267e-01 -1.82610482e-01\n",
" -1.46252900e-01 -1.01644254e+00 7.33198643e-01 1.08742762e+00\n",
" -6.23244762e-01 -6.87371850e-01 -1.63677886e-01 2.05873415e-01\n",
" -1.28120184e-04 1.58518422e+00 -2.15279460e-01 -3.93540442e-01\n",
" -2.25850105e-01 -4.96896729e-02 -3.16225171e-01 -2.08421387e-02\n",
" -7.36599624e-01 5.96843779e-01 -6.45120263e-01 -5.47669291e-01\n",
" 1.46693140e-01 -1.00796354e+00 2.46947721e-01 -1.16965592e-01\n",
" 1.06088769e+00 1.71407703e-02 1.60032183e-01 -1.61021389e-02\n",
" -4.65260029e-01 -3.54801744e-01 -7.89501548e-01 -5.38766980e-01\n",
" 4.41756576e-01 -1.14354692e-01 2.15809137e-01 3.88403803e-01\n",
" -5.96565425e-01 6.83732331e-01 1.11033058e+00 8.61530006e-01\n",
" -1.48576707e-01 1.11748576e+00 3.42844248e-01 -8.33445191e-02\n",
" 3.36354747e-02 3.10089946e-01 -1.54736900e+00 -4.20448691e-01\n",
" 8.92356411e-02 -1.90214276e-01 -7.60470182e-02 -9.25956726e-01\n",
" -2.31246114e-01 3.78083646e-01 -9.29461479e-01 -2.05470651e-01\n",
" -1.10455379e-01 2.51403034e-01 8.28987211e-02 4.65330124e-01\n",
" -1.59002924e+00 1.67836659e-02 -1.03855811e-01 4.42655265e-01\n",
" 5.59934914e-01 4.51579571e-01 -7.67547116e-02 -4.73889589e-01\n",
" -1.16150534e+00 2.94715017e-01 2.40399241e-01 3.64251673e-01\n",
" 5.29475212e-01 2.42404942e-03 1.62730403e-02 -1.22207522e-01\n",
" -1.02022350e+00 -1.01030040e+00 -3.02590966e-01 -2.43045136e-01\n",
" -7.07274854e-01 1.77668497e-01 2.10999519e-01 7.66401470e-01\n",
" 3.03304732e-01 -9.95232444e-03 -5.04402518e-01 -5.96487880e-01\n",
" 4.28530633e-01 7.56300837e-02 1.18148172e+00 2.74924375e-02\n",
" 1.09997904e+00 1.69377640e-01 1.15801789e-01 7.44199634e-01\n",
" -1.96190163e-01 -5.51071048e-01 -3.15140635e-01 -7.55082488e-01\n",
" -5.86989045e-01 4.44423974e-01 -3.26272547e-01 -5.73182702e-01\n",
" -4.43318665e-01 3.59878451e-01 -4.29275706e-02 8.31856728e-01\n",
" 6.22972071e-01 -2.17173025e-01 -9.20350194e-01 9.19905782e-01\n",
" -7.29426295e-02 5.15981987e-02 -9.65983048e-02 -1.27199262e-01\n",
" -3.96581441e-01 4.10287194e-02 2.23799735e-01 2.97161460e-01\n",
" 2.92265236e-01 -5.09368777e-01 -5.68877220e-01 4.27781045e-01\n",
" 4.16459739e-01 4.65236485e-01 1.06364763e+00 -6.52427852e-01\n",
" -8.28918576e-01 3.47210914e-02 3.42104435e-01 1.88819021e-01\n",
" 5.04421771e-01 -2.34061807e-01 1.80037186e-01 4.29086864e-01\n",
" 2.30355531e-01 -2.83437073e-01 2.34264821e-01 -5.11488378e-01\n",
" 4.76539314e-01 1.56475782e-01 -1.47166222e-01 -1.02309191e+00\n",
" -5.59814811e-01 -3.14807624e-01 1.36164622e-02 2.48336405e-01\n",
" -5.14447033e-01 -9.95983899e-01 2.44898498e-02 -2.81420350e-02\n",
" 2.93588758e-01 -6.17680728e-01 2.71538556e-01 -6.88656926e-01\n",
" 5.12018561e-01 -5.74781522e-02 -2.89483726e-01 -2.38752604e-01\n",
" 5.62780797e-01 -1.02241468e+00 -6.51534200e-01 2.27026194e-01\n",
" 3.18955034e-01 8.17963704e-02 -5.93574420e-02 -1.17566586e+00\n",
" -4.10225168e-02 -2.96685398e-01 5.82193732e-01 -7.70743072e-01\n",
" -7.15863481e-02 4.76808727e-01 3.29380572e-01 -5.99579334e-01\n",
" 5.35179069e-03 -3.21955502e-01 1.17016971e+00 -1.79935068e-01\n",
" -5.21227658e-01 -2.49460310e-01 4.84306246e-01 -2.64465153e-01\n",
" -2.03607529e-01 3.04992557e-01 -1.25715005e+00 7.55809188e-01\n",
" -6.03417039e-01 8.71127471e-02 -1.06220879e-01 -5.40654659e-01\n",
" -5.11184096e-01 -4.32208218e-02 6.84743106e-01 -9.05072093e-01\n",
" 4.51821312e-02 -2.63025343e-01 -7.73667634e-01 -7.93841958e-01\n",
" -3.95872891e-01 2.98925154e-02 1.12445605e+00 8.02642941e-01\n",
" -3.85727108e-01 -7.29575157e-01 6.11615479e-01 -1.75381213e-01\n",
" -1.70760036e-01 -1.05162525e+00 1.05051017e+00 -6.94110632e-01\n",
" 6.68203294e-01 -1.40023530e-01 4.58647013e-01 5.88184372e-02\n",
" 3.39238942e-01 2.20731258e-01 -4.91491079e-01 8.13486099e-01\n",
" 1.32252109e+00 1.83042541e-01 2.82913625e-01 -4.13827658e-01\n",
" 2.75093645e-01 -1.03679979e+00 4.73541558e-01 6.67740345e-01\n",
" -7.56594315e-02 9.01946127e-02 -8.08835685e-01 -1.11946911e-02\n",
" -4.65080068e-02 7.15481877e-01 -1.27312958e-01 -6.92539990e-01\n",
" -7.13321805e-01 2.09888488e-01 -3.56735557e-01 -2.66319543e-01\n",
" -4.60457742e-01 7.38253538e-03 -8.48066330e-01 -8.13465118e-01\n",
" 2.13946342e-01 -5.31533599e-01 -8.76827091e-02 3.71386945e-01\n",
" 4.14129198e-01 8.61194253e-01 8.82488936e-02 3.36408019e-01\n",
" 8.24467689e-02 1.00937831e+00 1.82786733e-01 4.32668254e-02\n",
" 1.73484877e-01 -1.22647047e+00 -1.27435565e-01 -7.31793106e-01\n",
" 9.08650219e-01 -8.02064896e-01 2.60243356e-01 8.27428102e-02\n",
" 1.52959555e-01 1.81248933e-01 -7.75716007e-01 3.96855772e-01\n",
" 6.91959858e-01 -2.00028941e-01 -1.84013590e-01 7.82088488e-02\n",
" -4.08993483e-01 -3.83011028e-02 8.61753106e-01 -2.74388075e-01\n",
" -9.74847972e-02 -1.02602720e+00 -1.26833722e-01 1.11962497e+00\n",
" 6.53884351e-01 1.42315373e-01 -4.75912750e-01 -3.21260281e-02\n",
" -2.65989333e-01 -1.16977885e-01 -5.93443096e-01 6.56192124e-01\n",
" -5.73520660e-01 -7.07489491e-01 -7.29676127e-01 6.20043557e-03\n",
" 6.79842353e-01 4.15460885e-01 -3.41880143e-01 2.28170443e+00\n",
" -1.02978444e+00 4.97357070e-01 -4.68666613e-01 1.19644988e+00\n",
" -6.38581872e-01 -1.35883436e-01 -7.98879743e-01 -5.45745015e-01\n",
" -4.36387926e-01 -4.08294976e-01 -2.27469355e-01 -1.78951710e-01\n",
" 5.60200289e-02 2.35111445e-01 -8.29315543e-01 1.66124683e-02\n",
" -8.07581782e-01 -6.34335399e-01 3.33200425e-01 2.99568564e-01\n",
" -8.49243045e-01 5.52447960e-02 8.02986920e-01 4.63941514e-01\n",
" 6.54758275e-01 -5.37097454e-04 -2.04843923e-01 9.76368427e-01\n",
" -3.53015751e-01 9.11225975e-01 7.22659290e-01 -3.19766104e-01\n",
" 2.27475837e-01 -2.76821464e-01 4.26936716e-01 3.22493613e-01\n",
" -4.22099531e-01 2.96951920e-01 6.81850255e-01 1.47002387e+00\n",
" -2.75375154e-02 -7.22703218e-01 3.05495150e-02 -5.42968288e-02\n",
" -5.06122649e-01 4.57347259e-02 8.26341063e-02 4.99086320e-01\n",
" 9.00193095e-01 -8.83195698e-01 -9.96071637e-01 -2.98155367e-01\n",
" -4.14106518e-01 -5.26974797e-01 -5.91103375e-01 -2.92363405e-01]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QLZaLMDuGhEh",
"colab_type": "text"
},
"source": [
"## Perform Semantic Search"
]
},
{
"cell_type": "code",
"metadata": {
"id": "D6wkgQmywt0k",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 221
},
"outputId": "2e038dec-abbe-4628-94a7-dc00d38d8a03"
},
"source": [
"#@title Semantic Search Form\n",
"\n",
"# Code adapted from https://github.com/UKPLab/sentence-transformers/blob/master/examples/application_semantic_search.py\n",
"\n",
"# scipy provides the pairwise cosine distance used below\n",
"import scipy.spatial\n",
"\n",
"query = 'Nobody has sane thoughts' #@param {type: 'string'}\n",
"\n",
"queries = [query]\n",
"query_embeddings = model.encode(queries)\n",
"\n",
"# For each query, find the closest corpus sentences by cosine similarity\n",
"number_top_matches = 3 #@param {type: \"number\"}\n",
"\n",
"print(\"Semantic Search Results\")\n",
"\n",
"for query, query_embedding in zip(queries, query_embeddings):\n",
"    # cdist returns cosine distance (1 - cosine similarity); smaller means closer\n",
"    distances = scipy.spatial.distance.cdist([query_embedding], sentence_embeddings, \"cosine\")[0]\n",
"\n",
"    # Pair each corpus index with its distance and sort ascending\n",
"    results = zip(range(len(distances)), distances)\n",
"    results = sorted(results, key=lambda x: x[1])\n",
"\n",
"    print(\"\\n\\n======================\\n\\n\")\n",
"    print(\"Query:\", query)\n",
"    print(\"\\nTop %d most similar sentences in corpus:\" % number_top_matches)\n",
"\n",
"    for idx, distance in results[0:number_top_matches]:\n",
"        print(sentences[idx].strip(), \"(Cosine Score: %.4f)\" % (1-distance))"
],
"execution_count": 27,
"outputs": [
{
"output_type": "stream",
"text": [
"Semantic Search Results\n",
"\n",
"\n",
"======================\n",
"\n",
"\n",
"Query: Nobody has sane thoughts\n",
"\n",
"Top 5 most similar sentences in corpus:\n",
"Lack of saneness (Cosine Score: 0.8958)\n",
"Absence of sanity (Cosine Score: 0.8744)\n",
"A man is riding a horse. (Cosine Score: 0.1705)\n"
],
"name": "stdout"
}
]
}
]
}
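
The search cell in the notebook ranks corpus sentences by cosine distance between embedding vectors. The ranking step on its own can be sketched in plain Python, without downloading a model. The helper names `cosine_similarity` and `top_k` and the 3-D toy vectors below are assumptions made for illustration; they stand in for the 768-dimensional vectors that `model.encode` returns and are not part of the Sentence Transformers library.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query_vec, corpus_vecs, k=3):
    """Return the k (index, score) pairs with the highest cosine similarity."""
    scored = [(i, cosine_similarity(query_vec, v)) for i, v in enumerate(corpus_vecs)]
    # Highest similarity first, which matches sorting cosine distance ascending.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Hypothetical toy embeddings (3-D instead of 768-D) for four corpus entries.
toy_corpus = [
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
]
toy_query = [1.0, 0.05, 0.0]

results = top_k(toy_query, toy_corpus, k=2)
for idx, score in results:
    print(idx, "(Cosine Score: %.4f)" % score)
```

With real embeddings, `toy_corpus` would be replaced by `sentence_embeddings` and `toy_query` by one row of `query_embeddings`. The score printed here is the similarity itself, which corresponds to the `1 - distance` value the notebook reports.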