Represents models based on MPNet.
References
Song, K., Tan, X., Qin, T., Lu, J., & Liu, T.-Y. (2020). MPNet: Masked and Permuted Pre-training for Language Understanding. doi:10.48550/arXiv.2004.09297
See also
Other Base Models:
BaseModelBert,
BaseModelDebertaV2,
BaseModelFunnel,
BaseModelModernBert,
BaseModelRoberta
Super classes
aifeducation::AIFEMaster -> aifeducation::AIFEBaseModel -> aifeducation::BaseModelCore -> BaseModelMPNet
Methods
Inherited methods
aifeducation::AIFEMaster$get_all_fields()
aifeducation::AIFEMaster$get_documentation_license()
aifeducation::AIFEMaster$get_ml_framework()
aifeducation::AIFEMaster$get_model_config()
aifeducation::AIFEMaster$get_model_description()
aifeducation::AIFEMaster$get_model_info()
aifeducation::AIFEMaster$get_model_license()
aifeducation::AIFEMaster$get_package_versions()
aifeducation::AIFEMaster$get_private()
aifeducation::AIFEMaster$get_publication_info()
aifeducation::AIFEMaster$get_sustainability_data()
aifeducation::AIFEMaster$is_configured()
aifeducation::AIFEMaster$is_trained()
aifeducation::AIFEMaster$set_documentation_license()
aifeducation::AIFEMaster$set_model_description()
aifeducation::AIFEMaster$set_model_license()
aifeducation::BaseModelCore$calc_flops_architecture_based()
aifeducation::BaseModelCore$count_parameter()
aifeducation::BaseModelCore$create_from_hf()
aifeducation::BaseModelCore$estimate_sustainability_inference_fill_mask()
aifeducation::BaseModelCore$fill_mask()
aifeducation::BaseModelCore$get_final_size()
aifeducation::BaseModelCore$get_flops_estimates()
aifeducation::BaseModelCore$get_model()
aifeducation::BaseModelCore$get_model_type()
aifeducation::BaseModelCore$get_n_layers()
aifeducation::BaseModelCore$get_special_tokens()
aifeducation::BaseModelCore$get_tokenizer_statistics()
aifeducation::BaseModelCore$load_from_disk()
aifeducation::BaseModelCore$plot_training_history()
aifeducation::BaseModelCore$save()
aifeducation::BaseModelCore$set_publication_info()
Method configure()
Configures a new object of this class. Please ensure that your chosen configuration complies with the following guideline:
hidden_size must be a multiple of num_attention_heads (for example, the defaults hidden_size = 768 and num_attention_heads = 12 yield 64 dimensions per head).
Usage
BaseModelMPNet$configure(
tokenizer,
max_position_embeddings = 512L,
hidden_size = 768L,
num_hidden_layers = 12L,
num_attention_heads = 12L,
intermediate_size = 3072L,
hidden_act = "GELU",
hidden_dropout_prob = 0.1,
attention_probs_dropout_prob = 0.1
)
Arguments
tokenizer (TokenizerBase)
Tokenizer for the model.
max_position_embeddings (int)
Number of maximum position embeddings. This parameter also determines the maximum length of a sequence which can be processed with the model. Allowed values: \(10 <= x <= 4048\)
hidden_size (int)
Number of neurons in each layer. This parameter determines the dimensionality of the resulting text embedding. Allowed values: \(1 <= x <= 2048\)
num_hidden_layers (int)
Number of hidden layers. Allowed values: \(1 <= x\)
num_attention_heads (int)
Number of attention heads for a self-attention layer. Only relevant if attention_type = 'multihead'. Allowed values: \(0 <= x\)
intermediate_size (int)
Size of the projection layer within each transformer encoder. Allowed values: \(1 <= x\)
hidden_act (string)
Name of the activation function. Allowed values: 'GELU', 'relu', 'silu', 'gelu_new'
hidden_dropout_prob (double)
Ratio of dropout. Allowed values: \(0 <= x <= 0.6\)
attention_probs_dropout_prob (double)
Ratio of dropout for attention probabilities. Allowed values: \(0 <= x <= 0.6\)
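Example
A minimal sketch of configuring a new model with the default settings. The tokenizer object my_tokenizer is a hypothetical placeholder, and creating the object via the standard R6 constructor BaseModelMPNet$new() is an assumption of this sketch; see the package's tokenizer classes for how to create a suitable tokenizer.
# my_tokenizer: a previously created tokenizer object (hypothetical)
base_model <- BaseModelMPNet$new()
base_model$configure(
  tokenizer = my_tokenizer,
  max_position_embeddings = 512L,
  hidden_size = 768L,        # must be a multiple of num_attention_heads
  num_hidden_layers = 12L,
  num_attention_heads = 12L, # 768 / 12 = 64 dimensions per head
  intermediate_size = 3072L,
  hidden_act = "GELU",
  hidden_dropout_prob = 0.1,
  attention_probs_dropout_prob = 0.1
)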
Method train()
Trains a BaseModel.
Usage
BaseModelMPNet$train(
text_dataset,
p_mask = 0.15,
p_perm = 0.15,
whole_word = TRUE,
val_size = 0.1,
n_epoch = 1L,
batch_size = 12L,
max_sequence_length = 250L,
full_sequences_only = FALSE,
min_seq_len = 50L,
learning_rate = 0.003,
sustain_track = FALSE,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15L,
sustain_log_level = "warning",
trace = TRUE,
pytorch_trace = 1L,
log_dir = NULL,
log_write_interval = 2L
)
Arguments
text_dataset (LargeDataSetForText)
LargeDataSetForText object storing textual data.
p_mask (double)
Ratio that determines the number of tokens used for masking. Allowed values: \(0.05 <= x <= 0.6\)
p_perm (double)
Ratio that determines the number of tokens used for permutation. Allowed values: \(0.05 <= x <= 0.6\)
whole_word (bool)
TRUE: whole word masking is applied. Only relevant if a WordPieceTokenizer is used.
FALSE: token masking is used.
val_size (double)
Value between 0 and 1 indicating the proportion of cases which should be used for the validation sample during the estimation of the model. The remaining cases are part of the training data. Allowed values: \(0 < x < 1\)
n_epoch (int)
Number of training epochs. Allowed values: \(1 <= x\)
batch_size (int)
Size of the batches for training. Allowed values: \(1 <= x\)
max_sequence_length (int)
Maximal number of tokens for every sequence. Allowed values: \(20 <= x\)
full_sequences_only (bool)
TRUE for using only chunks with a sequence length equal to chunk_size.
min_seq_len (int)
Only relevant if full_sequences_only = FALSE. Determines the minimal sequence length included in the training process. Allowed values: \(10 <= x\)
learning_rate (double)
Initial learning rate for the training. Allowed values: \(0 < x <= 1\)
sustain_track (bool)
If TRUE, energy consumption is tracked during training via the python library 'codecarbon'.
sustain_iso_code (string)
ISO code (Alpha-3-Code) for the country. This variable must be set if sustainability should be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes. Allowed values: any
sustain_region (string)
Region within a country. Only available for USA and Canada. See the documentation of codecarbon for more information: https://mlco2.github.io/codecarbon/parameters.html. Allowed values: any
sustain_interval (int)
Interval in seconds for measuring power usage. Allowed values: \(1 <= x\)
sustain_log_level (string)
Level for printing information to the console. Allowed values: 'debug', 'info', 'warning', 'error', 'critical'
trace (bool)
TRUE if information about the estimation phase should be printed to the console.
pytorch_trace (int)
pytorch_trace = 0 does not print any information about the training process from pytorch to the console. Allowed values: \(0 <= x <= 1\)
log_dir (string)
Path to the directory where the log files should be saved. If no logging is desired, set this argument to NULL. Allowed values: any
log_write_interval (int)
Time in seconds determining the interval in which the logger should try to update the log files. Only relevant if log_dir is not NULL. Allowed values: \(1 <= x\)
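Example
A minimal sketch of training the model. It assumes base_model is a configured BaseModelMPNet (as in the configure() example above) and my_texts is a LargeDataSetForText object; both names are hypothetical placeholders.
base_model$train(
  text_dataset = my_texts,
  p_mask = 0.15,            # share of tokens selected for masking
  p_perm = 0.15,            # share of tokens selected for permutation
  whole_word = TRUE,
  val_size = 0.1,           # 10% of cases held out for validation
  n_epoch = 2L,
  batch_size = 12L,
  max_sequence_length = 250L,
  learning_rate = 0.003,
  sustain_track = TRUE,     # track energy consumption with codecarbon
  sustain_iso_code = "DEU", # required when sustain_track = TRUE
  trace = TRUE
)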