SentenceTransformerModel.train
- opensearch_py_ml.ml_models.SentenceTransformerModel.train(self, read_path: str, overwrite: bool = False, output_model_name: str | None = None, zip_file_name: str | None = None, compute_environment: str | None = None, num_machines: int = 1, num_gpu: int = 0, learning_rate: float = 2e-05, num_epochs: int = 10, batch_size: int = 32, verbose: bool = False, percentile: float = 95) → None
Read the synthetic queries and use them to fine-tune/train (and save) a sentence transformer model.
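The zipped synthetic-queries input described below can be produced with the standard library alone. A minimal sketch (the file names and query contents here are illustrative, not mandated by the API):

```python
import pickle
import zipfile

# Synthetic queries in the expected format: a list of dicts with
# "query", "probability" and "passages" keys ("probability" is
# optional for training). Contents are illustrative.
synthetic_queries = [
    {"query": "what is opensearch", "probability": 0.92,
     "passages": "OpenSearch is a community-driven search and analytics suite."},
    {"query": "how are documents indexed", "probability": 0.87,
     "passages": "Documents are indexed through the bulk API."},
]

# Pickle the list, then wrap the pickle in a zip archive that can be
# passed to train() via read_path.
with open("synthetic_queries.pkl", "wb") as f:
    pickle.dump(synthetic_queries, f)

with zipfile.ZipFile("synthetic_queries.zip", "w") as zf:
    zf.write("synthetic_queries.pkl")
```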
Parameters
- param read_path:
required, path to the zipped file that contains the generated queries; if None, an exception is raised. The zipped file should contain a pickled file holding a list of dictionaries with keys named 'query', 'probability' and 'passages'. For example: [{'query': q1, 'probability': p1, 'passages': pa1}, ...]. 'probability' is not required for training purposes
- type read_path:
string
- param overwrite:
optional; the unzipped query files are stored in the synthetic_queries/ folder in the current directory. Defaults to False; if the folder is not empty, an exception is raised recommending that users either clean up the folder or set overwrite to True
- type overwrite:
bool
- param output_model_name:
the name of the trained custom model. If None, defaults to model_id + '.pt'
- type output_model_name:
string
- param zip_file_name:
optional, file name for the zip file. If None, defaults to model_id + '.zip'
- type zip_file_name:
string
- param compute_environment:
optional, compute environment type to run the model. If None, defaults to LOCAL_MACHINE
- type compute_environment:
string
- param num_machines:
optional, number of machines to run the model. If None, defaults to 1
- type num_machines:
int
- param num_gpu:
optional, number of GPUs to run the model. If None, defaults to 0. If the number of GPUs > 1, HuggingFace Accelerate is used to launch distributed training
- type num_gpu:
int
- param learning_rate:
optional, learning rate used to train the model, default is 2e-5
- type learning_rate:
float
- param num_epochs:
optional, number of epochs to train the model, default is 10
- type num_epochs:
int
- param batch_size:
optional, batch size for training, default is 32
- type batch_size:
int
- param verbose:
optional, whether to plot the training progress. Defaults to False
- type verbose:
bool
- param percentile:
we find the maximum length of {percentile}% of the documents; default is 95. Since this length is measured in words rather than tokens, we multiply it by 1.4 to account for the fact that one word in the English vocabulary roughly translates to 1.3 to 1.5 tokens
- type percentile:
float
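The percentile-based length cap described above can be sketched as follows. This is a hypothetical helper for illustration, not the library's actual implementation; it uses a nearest-rank percentile over word counts and the 1.4 word-to-token scaling factor:

```python
import math

def approx_max_length(passages, percentile=95.0):
    """Estimate a length cap: take the word count at the given
    percentile of the corpus and multiply by 1.4, since one English
    word is roughly 1.3 to 1.5 tokens. Illustrative sketch only."""
    lengths = sorted(len(p.split()) for p in passages)
    # Nearest-rank percentile: the smallest length that covers
    # `percentile` percent of the documents.
    rank = max(1, math.ceil(percentile / 100.0 * len(lengths)))
    return lengths[rank - 1] * 1.4

# For documents of 2, 3 and 5 words, the 95th-percentile word count
# is 5, so the estimated cap is 5 * 1.4 = 7.0.
docs = ["one two", "one two three", "one two three four five"]
cap = approx_max_length(docs, percentile=95)
```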
Returns
- return:
no return value expected
- rtype:
None