Demo Notebook for Sentence Transformer Model Training, Saving and Uploading to OpenSearch



This notebook introduces the technique of synthetic data generation and shows how it can be used to obtain a deep learning model for search that is custom-built for a given set of documents.

Deep learning models are very powerful and have been shown to improve the state of the art in several disciplines and tasks. However, they need a lot of labelled training data, which is often hard to obtain. In this notebook, we show how pre-trained large language models can be used to circumvent this issue.

We focus on the task of passage retrieval, i.e. the corpus consists of passages that are searched at run time given a user query. This search can be performed by transformers such as BERT, as long as the model is trained on a labelled dataset of (query, relevant passage) pairs. Such a BERT model can be used for semantic search.
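Below is a minimal sketch of the search step itself. It assumes the query and the passages have already been encoded into vectors by a trained model, so that retrieval reduces to ranking passages by the dot product between each passage embedding and the query embedding. The toy vectors and the `rank_passages` helper are hypothetical and are not part of opensearch_py_ml.

```python
from typing import List, Sequence, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def rank_passages(query_vec: Sequence[float],
                  passage_vecs: List[Sequence[float]],
                  top_k: int = 3) -> List[Tuple[int, float]]:
    """Return (passage index, score) pairs sorted by descending dot-product score."""
    scored = [(i, dot(query_vec, p)) for i, p in enumerate(passage_vecs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


# Toy embeddings (hypothetical): passage 1 points most nearly in the query's direction.
passages = [[0.1, 0.9], [0.8, 0.2], [-0.5, 0.5]]
query = [1.0, 0.0]
result = rank_passages(query, passages, top_k=2)
print(result)
```

In a real deployment, the embeddings come from the fine-tuned model and the ranking is done by the search engine. The point here is only that "relevance" means a high dot product, which is exactly what the training step later optimizes.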

Synthetic query generation

In the absence of such labelled data, we provide a synthetic query generator (SQG) model that can be used to create synthetic queries given a passage. The SQG model is a large transformer model that has been trained to generate human-like queries given a passage. Thus it can be used to create a labelled dataset of (synthetic query, passage) pairs. A BERT model can be trained on this synthetic data and used for semantic search. In fact, we find that such synthetically trained models beat the current state-of-the-art models. Note that the resulting BERT model is a customized model, since it has been trained on a specific corpus (and the corresponding synthetic queries).
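To make the shape of the synthetic training set concrete, here is a sketch that turns generated queries into (query, passage) pairs. The `generated` mapping is a hypothetical stand-in for the SQG model's output. In this notebook, the actual generation and saving are handled by `generate_synthetic_queries`.

```python
from typing import Dict, List, Tuple


def build_training_pairs(generated: Dict[str, List[str]]) -> List[Tuple[str, str]]:
    """Flatten {passage: [synthetic queries]} into (query, passage) training pairs."""
    pairs: List[Tuple[str, str]] = []
    for passage, queries in generated.items():
        for query in queries:
            # every synthetic query is labelled with the passage it was generated from
            pairs.append((query, passage))
    return pairs


# Hypothetical SQG output for two passages.
generated = {
    "The Eiffel Tower is in Paris.": ["where is the eiffel tower", "eiffel tower city"],
    "Water boils at 100 degrees Celsius at sea level.": ["boiling point of water",
                                                         "what temperature does water boil"],
}
pairs = build_training_pairs(generated)
for query, passage in pairs:
    print(f"{query!r} -> {passage!r}")
```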

This notebook provides an end-to-end guide for generating synthetic queries and fine-tuning a sentence transformer model on them using opensearch_py_ml. It consists of the following steps:

Step 1: Import packages and set up client

Step 2: Import the data/passages for synthetic query generation

Step 3: Generate Synthetic Queries

Step 4: Read synthetic queries and train/fine-tune a Hugging Face sentence transformer model

Step 5: Upload the model to OpenSearch cluster

Steps 3 and 4 are compute intensive, and we recommend running them on a machine with 4 or more GPUs, such as an EC2 p3.8xlarge or p3.16xlarge instance.

Step 1: Import packages, set up client and define helper functions

Install the packages required by opensearch_py_ml.sentence_transformer_model. Install opensearch-py and opensearch-py-ml through PyPI. The generate script used below is released with the Synthetic Query Generation model.

Please refer to the official PyTorch installation instructions to install torch correctly for your environment.

Please install the following packages from the terminal if you haven't already. They can also be installed from the notebook by uncommenting the pip line and executing it.

# Download for Generate Synthetic Queries

import urllib.request
urllib.request.urlretrieve("", "")
('', <http.client.HTTPMessage at 0x10bdbe940>)
# !pip install pandas matplotlib numpy torch accelerate sentence_transformers tqdm transformers opensearch-py opensearch-py-ml detoxify datasets
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings("ignore", message="Unverified HTTPS request")
import opensearch_py_ml as oml
from opensearchpy import OpenSearch
import generate
from generate import Synthetic_Query_Generation
from opensearch_py_ml.ml_models import SentenceTransformerModel
import boto3, json
import pandas as pd, numpy as np
from datasets import load_dataset
import gc, torch
# import mlcommon to later upload the model to OpenSearch Cluster
from opensearch_py_ml.ml_commons import MLCommonClient
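CLUSTER_URL = 'https://localhost:9200'

def get_os_client(cluster_url = CLUSTER_URL,
                  username='admin',
                  password='< admin password >'):  # placeholder: replace with your credentials
    '''
    Get OpenSearch client
    :param cluster_url: cluster URL (defaults to CLUSTER_URL)
    :param username: user name for HTTP basic authentication
    :param password: password for HTTP basic authentication
    :return: OpenSearch client
    '''
    client = OpenSearch(
        hosts=[cluster_url],
        http_auth=(username, password),
        verify_certs=False  # local cluster with a self-signed certificate
    )
    return client

client = get_os_client()

def myselect(x):
    # return the passage marked as selected (relevant) for this MS MARCO row, or "-1" if none is selected
    if max(x["passages"]["is_selected"]) == 1:
        return x["passages"]["passage_text"][np.argmax(x["passages"]["is_selected"])]
    return "-1"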

Step 2: Import the data/passages for synthetic query generation

There are three supported options for reading datasets:

  • Option 1: read from a jsonl file in a local data folder

  • Option 2: read from a list of passages

  • Option 3: read from OpenSearch client by index_name

For the purpose of this notebook we will demonstrate option 2: read from a list of passages.
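Option 1 is not used in this notebook. As a rough illustration of what it involves, the sketch below reads passages from a local JSON Lines file. The file name, its content and the `passage` field are hypothetical, so check the generate script for the layout it actually expects. The result is a plain list of strings, the same shape that Option 2 works with.

```python
import json
import tempfile
from pathlib import Path
from typing import List


def read_passages_jsonl(path: str, field: str = "passage") -> List[str]:
    """Read a JSON Lines file (one JSON object per line) and collect `field` from each record.

    Blank lines are skipped; a record without `field` raises KeyError.
    """
    passages: List[str] = []
    with Path(path).open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            passages.append(json.loads(line)[field])
    return passages


# Usage with a throwaway file; the content and the "passage" field name are hypothetical.
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "passages.jsonl"
    sample.write_text('{"passage": "first"}\n\n{"passage": "second"}\n', encoding="utf-8")
    loaded = read_passages_jsonl(str(sample))
print(loaded)  # ['first', 'second']
```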

We use the MS MARCO dataset of passages as our example dataset.

2.1) Load the data and convert it into a pandas DataFrame

[ ]:
dataset = load_dataset("ms_marco","v1.1")
df = pd.DataFrame.from_dict(dataset["validation"])
[ ]:
df["passage"] = df.apply(lambda x: myselect(x), axis = 1)
df = df[["query","passage"]][df.passage != "-1"]
[ ]:
# Setting print options to display full columns

pd.set_option('display.max_columns', None)
pd.set_option('display.expand_frame_repr', None)
pd.set_option('max_colwidth', None)

The dataset looks like this:

[ ]:

The MS MARCO dataset has real queries for its passages, but we will pretend that it does not and generate synthetic queries for each passage.

2.2) Convert the data into a list of strings and instantiate an object of the class Synthetic_Query_Generation

[ ]:
sample_passages = list(df.passage.values)
[ ]:
ss = Synthetic_Query_Generation(sentences = sample_passages[:8])

Step 3: Generate synthetic queries

[ ]:
three_step_query = ss.generate_synthetic_queries(num_machines = 1,
                                                 tokenize_data = True,
                                                 tokenizer_max_length =  300,
                                                 total_queries = 10,
                                                 numseq = 5,
                                                 num_gpu = 0,
                                                 toxic_cutoff = 0.01,
                                                 tokens_to_word_ratio = 0.6)

The cell above performs several actions. We describe them step by step:

1) Convert the data into a form that can be consumed by the synthetic query generator (SQG) model. This amounts to tokenizing the data with a tokenizer. The SQG model is a fine-tuned version of the GPT2-XL model, and the tokenizer is the GPT-2 tokenizer.

2) The tokenizer has a max input length of 512 tokens. Every passage is tokenized with the special token <|startoftext|> added to the beginning and QRY: appended to the end. Note that tokenization is a time-intensive process and the script saves the tokenized data after the first pass, so we recommend setting tokenize_data = False on subsequent runs.

3) Load the SQG model, i.e. a 1.5B-parameter GPT2-XL model that has been trained to ask questions given passages. This model has been made publicly available and can be found here:

4) Once the model has been loaded and the data has been tokenized, the model starts generating queries. "total_queries" is the number of synthetic queries generated for every passage, and "numseq" is the number of queries the model generates at a time. Ideally total_queries = numseq, but this can lead to out-of-memory issues. So set numseq to an integer that is around 10 or less and divides total_queries evenly. For example, the cell above uses total_queries = 10 and numseq = 5, so each passage's queries are generated in two rounds of five.

5) tokens_to_word_ratio is a float used to convert between the length of a document in tokens and its length in words. It is used when truncating documents during the tokenization phase. Most words are split into one or more tokens, so a document that is 300 tokens long might only be 200 words long. This ratio, 200/300 = 2/3 ≈ 0.667, is the tokens_to_word_ratio. For passages from a dataset such as Wikipedia this ratio is around 0.65 to 0.7, but for domain-specific datasets it could be as small as 0.5.

6) The script also needs to know the number of GPUs and the number of machines/nodes it can use. Since we are using a single-node instance with no GPUs, we pass num_gpu = 0 and num_machines = 1. Our recommended setting is 1 machine/node with at least 4 (ideally 8) GPUs.

7) The script then begins to generate queries and displays a progress bar. We create total_queries queries per passage. Empirically, we find that generating more queries leads to better performance, but the returns diminish while the total inference time keeps growing with total_queries.

8) After generating the queries, the function uses a publicly available package called Detoxify to remove inappropriate queries from the dataset. "toxic_cutoff" is a float: the script rejects all queries that have a toxicity score greater than toxic_cutoff.

9) Finally, the synthetic queries and their corresponding passages are saved to a zip file in the current working directory.
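The bookkeeping behind items 2, 4, 5 and 8 of the list above can be sketched in plain Python. This is not the code in the generate script. The function names are hypothetical, the exact spacing around the special tokens is an assumption, and the toxicity scores are passed in as numbers rather than computed (in the real pipeline they come from Detoxify).

```python
import math
from typing import List, Sequence


def truncate_words(passage: str, tokenizer_max_length: int,
                   tokens_to_word_ratio: float) -> str:
    """Item 5: turn a token budget into a word budget and cut the passage to it."""
    max_words = math.floor(tokenizer_max_length * tokens_to_word_ratio)
    return " ".join(passage.split()[:max_words])


def build_prompt(passage: str) -> str:
    """Item 2: wrap a passage with the special tokens (spacing here is an assumption)."""
    return f"<|startoftext|> {passage} QRY:"


def generation_rounds(total_queries: int, numseq: int) -> int:
    """Item 4: number of generation rounds per passage; numseq must divide total_queries."""
    if numseq <= 0 or total_queries % numseq != 0:
        raise ValueError("numseq must be a positive divisor of total_queries")
    return total_queries // numseq


def filter_toxic(queries: Sequence[str], toxicity: Sequence[float],
                 toxic_cutoff: float) -> List[str]:
    """Item 8: keep queries whose toxicity score is not greater than toxic_cutoff."""
    return [q for q, score in zip(queries, toxicity) if score <= toxic_cutoff]


passage = " ".join(f"word{i}" for i in range(400))
short = truncate_words(passage, tokenizer_max_length=300, tokens_to_word_ratio=0.5)
print(len(short.split()))                                          # 150
print(generation_rounds(total_queries=10, numseq=5))               # 2
print(filter_toxic(["q1", "q2", "q3"], [0.001, 0.5, 0.01], 0.01))  # ['q1', 'q3']
```

Toxicity scores of this kind can be obtained with, for example, `Detoxify('original').predict(queries)["toxicity"]`. How the generate script calls Detoxify is not shown in this notebook, so treat that call as an assumption.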

Note – Please restart the kernel and rerun the notebook if it raises CUDA-related errors.

This is what the sample queries look like:

[ ]:
# initialize a SentenceTransformerModel object

custom_model = SentenceTransformerModel(folder_path="/Volumes/workplace/upload_content/model_files/", overwrite = True)

# read the synthetic queries zip file from read_path into a DataFrame
df = custom_model.read_queries(read_path = '/Volumes/workplace/upload_content/', overwrite = True)


Step 4: Read synthetic queries and train/fine-tune a Hugging Face sentence transformer model on synthetic data

With a synthetic queries zip file, users can fine-tune a sentence transformer model.

The SentenceTransformerModel class initiates an object for training, exporting and configuring the model. Please visit SentenceTransformerModel for the API reference.

The train function imports the synthetic queries, loads sentence transformer examples and trains the model using a Hugging Face sentence transformer model. Please visit SentenceTransformerModel.train for the API reference.

[ ]:
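# clean up cached memory before training to free up space
import gc, torch

gc.collect()
torch.cuda.empty_cache()  # releases cached GPU memory; does nothing if CUDA was never initialized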

[ ]:

training = custom_model.train(read_path = '/Volumes/workplace/upload_content/', output_model_name = '', zip_file_name= '', overwrite = True, num_epochs = 10, verbose = False)

The following are some important points about the training cell executed above:

  1. The input to the training script consists of (query, passage) pairs. The model is trained to maximize the dot product between queries and their relevant passages while minimizing the dot product between queries and irrelevant passages. This is also known as contrastive learning. We implement it using in-batch negatives and a symmetric loss, as described below.

  2. To make full use of the GPUs, we collect training samples into batches before sending them to the model. Each batch contains B randomly selected training samples (q, p). Thus, within a batch, each query has one relevant passage and B-1 irrelevant passages. Similarly, every passage has one relevant query and B-1 irrelevant queries. For every relevant query-passage pair we minimize the following expression, called the loss:

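  3. For a batch of B samples, the loss is defined as \(\text{loss} = C(q, p) + C(p, q)\), where \(C(q, p) = -\sum_{i=1}^{B} \log \left( \frac{\exp(q_i \cdot p_i)}{\sum_{j=1}^{B} \exp(q_i \cdot p_j)} \right)\) and \(C(p, q)\) is defined analogously, with the roles of queries and passages swapped.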

  4. The model truncates documents beyond 512 tokens. If the corpus contains documents shorter than 512 tokens, the model's max length can be reduced accordingly. Shorter sequences take less memory and therefore allow for bigger batch sizes. The max length can be adjusted with the “percentile” argument.

  5. We use a batch size of 32 per device. Larger batch sizes provide more in-batch negative samples and lead to better performance, but they can also cause out-of-memory issues. Shorter sequences use less memory, so if the documents in the corpus are short, feel free to experiment with larger batch sizes.

  6. The model is trained with the AdamW optimizer for 10 epochs, using a learning rate of 2e-5 and a linear learning-rate schedule with 10,000 warmup steps.
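The following plain-Python sketch mirrors points 3, 4 and 6 above: the symmetric in-batch loss, a max length derived from a percentile of passage lengths, and a linear warmup-then-decay learning-rate schedule. It is a didactic re-implementation, not the opensearch_py_ml training code. In particular, how the library turns the “percentile” argument into a max length is assumed here (a nearest-rank percentile capped at 512). The schedule follows the common shape of a linear ramp up to the base rate followed by a linear decay to zero.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def _dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def contrastive_term(anchors: List[Vector], candidates: List[Vector]) -> float:
    """C(a, c) = -sum_i log(exp(a_i . c_i) / sum_j exp(a_i . c_j)).

    Row i treats candidates[i] as the positive and the other B-1 rows as in-batch negatives.
    """
    total = 0.0
    for i, anchor in enumerate(anchors):
        scores = [_dot(anchor, c) for c in candidates]
        peak = max(scores)  # log-sum-exp trick for numerical stability
        log_norm = peak + math.log(sum(math.exp(s - peak) for s in scores))
        total -= scores[i] - log_norm
    return total


def symmetric_loss(queries: List[Vector], passages: List[Vector]) -> float:
    """loss = C(q, p) + C(p, q) over one batch of aligned (q_i, p_i) pairs."""
    return contrastive_term(queries, passages) + contrastive_term(passages, queries)


def percentile_max_length(lengths: Sequence[int], percentile: int, cap: int = 512) -> int:
    """Nearest-rank percentile of passage lengths, never above the model limit (assumed rule)."""
    if not lengths or not 0 < percentile <= 100:
        raise ValueError("need non-empty lengths and 0 < percentile <= 100")
    ordered = sorted(lengths)
    rank = math.ceil(percentile * len(ordered) / 100)  # 1-based nearest rank
    return min(ordered[rank - 1], cap)


def linear_warmup_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Ramp linearly from 0 to base_lr over warmup_steps, then decay linearly to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


q = [[1.0, 0.0], [0.0, 1.0]]
p = [[1.0, 0.0], [0.0, 1.0]]
print(symmetric_loss(q, p))                       # equals 4 * (log(1 + e) - 1), about 1.253
print(percentile_max_length(range(1, 101), 95))  # 95
print(linear_warmup_lr(5_000, 2e-5, 10_000, 50_000))
```

Scaling up the matching vectors, or making each q_i closer to its own p_i than to the other passages, lowers the loss. Swapping the passages so that each query is paired with the wrong one raises it. This is the contrastive behaviour described in point 1.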

Step 5: Upload the model to OpenSearch cluster

After generating a model zip file, users need to describe the model configuration in an ml-commons_model_config.json file. The make_model_config_json function in the SentenceTransformerModel class parses the configuration from the Hugging Face config.json file. If users would like to use a different configuration than that of the pre-trained sentence transformer, make_model_config_json provides arguments to change the configuration content and generates an ml-commons_model_config.json file. Please visit SentenceTransformerModel.make_model_config_json for the API reference.
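To give a feel for what such a configuration file contains, here is a hedged sketch that derives a small ML Commons-style config from the fields of a Hugging Face `config.json`. The output field names (`name`, `version`, `model_format`, and `model_config` with `model_type`, `embedding_dimension`, `framework_type`) follow the ML Commons model registration request as we understand it. The function name, model name and input values are hypothetical. `make_model_config_json` remains the supported way to produce this file.

```python
import json
import tempfile
from pathlib import Path


def sketch_model_config(hf_config: dict, model_name: str, version: str = "1.0.0") -> dict:
    """Build an ML Commons-style model config from a Hugging Face config dict.

    Assumes the sentence embedding size equals the transformer's hidden_size,
    which holds for plain mean pooling but not for models with an extra Dense layer.
    """
    return {
        "name": model_name,
        "version": version,
        "model_format": "TORCH_SCRIPT",
        "model_config": {
            "model_type": hf_config["model_type"],
            "embedding_dimension": hf_config["hidden_size"],
            "framework_type": "sentence_transformers",
        },
    }


# Hypothetical excerpt of a Hugging Face config.json for a small BERT model.
hf_config = {"model_type": "bert", "hidden_size": 384}
config = sketch_model_config(hf_config, model_name="my-custom-model")

with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "ml-commons_model_config.json"
    out.write_text(json.dumps(config, indent=2), encoding="utf-8")
    reloaded = json.loads(out.read_text(encoding="utf-8"))
print(reloaded["model_config"]["embedding_dimension"])  # 384
```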

In general, the ML Commons client supports uploading sentence transformer models. Given a zip file that contains the model in TorchScript format and a tokenizer configuration file in JSON format, the upload_model function connects to OpenSearch through the ML Commons client and uploads the model. Please visit MLCommonClient.upload_model for the API reference.
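The log printed by `upload_model` reports the number of chunks and a hash of the model file. As a rough sketch of that idea, and not the library's implementation, the snippet below splits a file into fixed-size chunks and hashes it while reading. The 10 MB default chunk size, the use of SHA-256 and the function names are all assumptions made for illustration.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Iterator, Tuple

CHUNK_SIZE = 10 * 1024 * 1024  # assumed chunk size in bytes, for illustration only


def iter_chunks(path: Path, chunk_size: int = CHUNK_SIZE) -> Iterator[bytes]:
    """Yield consecutive chunk_size-byte pieces of a file (the last one may be shorter)."""
    with path.open("rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            yield chunk


def describe_upload(path: Path, chunk_size: int = CHUNK_SIZE) -> Tuple[int, str]:
    """Return (number of chunks, SHA-256 hex digest) for the file, hashing as it is read."""
    digest = hashlib.sha256()
    count = 0
    for chunk in iter_chunks(path, chunk_size):
        digest.update(chunk)
        count += 1
        # a real client would send `chunk` to the cluster at this point
    return count, digest.hexdigest()


with tempfile.TemporaryDirectory() as tmp:
    model_file = Path(tmp) / "model.zip"
    model_file.write_bytes(b"x" * 2500)  # stand-in for a real model archive
    n_chunks, sha256_hex = describe_upload(model_file, chunk_size=1000)
print(n_chunks, sha256_hex)  # 3 chunks: 1000 + 1000 + 500 bytes
```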

[ ]:
# users need to prepare an ml-commons_model_config.json file to configure the model (model name, etc.);
# SentenceTransformerModel.make_model_config_json is a helper function that can generate this file
# connect to the ML Commons client with the OpenSearch client
import opensearch_py_ml as oml
from opensearch_py_ml.ml_commons import MLCommonClient
ml_client = MLCommonClient(client)
# upload the model to the OpenSearch cluster, using the model zip file path and the model config file generated above

model_path = '/Volumes/workplace/upload_content/'
model_config_path = '/Volumes/workplace/upload_content/model_config.json'
ml_client.upload_model( model_path, model_config_path, isVerbose=True)
Total number of chunks 10
Sha1 value of the model file:  61fd5a1425960681da49d084dca0e52fd0fabcc0f2e1c4d57c4e20e193bde483
Model meta data was created successfully. Model Id:  lGFG9IUBTo3f8n5R8nM6
uploading chunk 1 of 10
Model id: {'status': 'Uploaded'}
uploading chunk 2 of 10
Model id: {'status': 'Uploaded'}
uploading chunk 3 of 10
Model id: {'status': 'Uploaded'}
uploading chunk 4 of 10
Model id: {'status': 'Uploaded'}
uploading chunk 5 of 10
Model id: {'status': 'Uploaded'}
uploading chunk 6 of 10
Model id: {'status': 'Uploaded'}
uploading chunk 7 of 10
Model id: {'status': 'Uploaded'}
uploading chunk 8 of 10
Model id: {'status': 'Uploaded'}
uploading chunk 9 of 10
Model id: {'status': 'Uploaded'}
uploading chunk 10 of 10
Model id: {'status': 'Uploaded'}
Model uploaded successfully