SentenceTransformerModel.train_model

opensearch_py_ml.ml_models.SentenceTransformerModel.train_model(self, train_examples: List[List[str]], model_id: str | None = None, output_model_name: str | None = None, learning_rate: float = 2e-05, num_epochs: int = 10, batch_size: int = 32, verbose: bool = False, num_gpu: int = 0, percentile: float = 95)

Takes in training data and a sentence transformer model ID (URL) to train a custom semantic search model.

Parameters:
  • train_examples (List[List[str]]) – required, the training input for the sentence transformer model, given as a list of string lists

  • model_id (string) – optional, the model ID or URL used to download the sentence transformer model. If None, defaults to ‘sentence-transformers/msmarco-distilbert-base-tas-b’

  • output_model_name (string) – optional, the name of the trained custom model. If None, defaults to model_id + ‘.pt’

  • learning_rate (float) – optional, the learning rate used to train the model, default is 2e-5

  • num_epochs (int) – optional, the number of epochs used to train the model, default is 10

  • batch_size (int) – optional, batch size for training, default is 32

  • verbose (bool) – optional, whether to plot training progress and print additional logs. Default is False

  • num_gpu (int) – optional, the number of GPUs to use for training. Default is 0

  • percentile (float) – optional, to save memory during training, all passages are truncated beyond a certain max_length. Most mid-sized transformers have a maximum input length of 512 tokens, but some corpora contain much shorter documents, so the word lengths of all documents are computed, sorted in increasing order, and max_length is set to cover {percentile}% of the documents (see the sketch below). Default is 95
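
A minimal sketch of the percentile-based truncation logic described above, assuming word length is approximated by whitespace splitting; the function name choose_max_length and the corpus argument are illustrative, not part of the library API:

    import math

    def choose_max_length(corpus, percentile=95.0, hard_limit=512):
        # Word length of every document, sorted in increasing order.
        lengths = sorted(len(doc.split()) for doc in corpus)
        # Index of the document sitting at the given percentile.
        idx = math.ceil(len(lengths) * percentile / 100.0) - 1
        # Truncate everything beyond this length, capped at the
        # assumed 512-token model limit.
        return min(lengths[idx], hard_limit)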

Returns:

the trained model in TorchScript format.

Return type:

.pt file
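
A hypothetical end-to-end call for reference; the constructor arguments (folder_path, overwrite), the output path, and the [query, passage] layout of each training pair are assumptions not stated on this page:

    from opensearch_py_ml.ml_models import SentenceTransformerModel

    # Hypothetical setup: folder_path and overwrite are assumed
    # constructor options, and the training pairs below are synthetic.
    model = SentenceTransformerModel(
        folder_path="/tmp/sentence_model", overwrite=True
    )

    # Each inner list is assumed to pair a query with a relevant passage.
    train_examples = [
        ["how to bake bread", "Mix flour, water, yeast and salt, then bake."],
        ["what is opensearch", "OpenSearch is an open-source search suite."],
    ]

    model.train_model(
        train_examples,
        model_id="sentence-transformers/msmarco-distilbert-base-tas-b",
        output_model_name="my_custom_model.pt",
        num_epochs=10,
        batch_size=32,
    )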