
Represents models based on MPNet.

Value

Returns a new object of this class.

References

Song, K., Tan, X., Qin, T., Lu, J. & Liu, T.-Y. (2020). MPNet: Masked and Permuted Pre-training for Language Understanding. doi:10.48550/arXiv.2004.09297

Methods



Method configure()

Configures a new object of this class.

Usage

BaseModelMPNet$configure(
  tokenizer,
  max_position_embeddings = 512L,
  hidden_size = 768L,
  num_hidden_layers = 12L,
  num_attention_heads = 12L,
  intermediate_size = 3072L,
  hidden_act = "GELU",
  hidden_dropout_prob = 0.1,
  attention_probs_dropout_prob = 0.1
)

Arguments

tokenizer

TokenizerBase Tokenizer for the model.

max_position_embeddings

int Maximum number of position embeddings. This parameter also determines the maximum length of a sequence that the model can process. Allowed values: 10 <= x <= 4048

hidden_size

int Number of neurons in each layer. This parameter determines the dimensionality of the resulting text embedding. Allowed values: 1 <= x <= 2048

num_hidden_layers

int Number of hidden layers. Allowed values: 1 <= x

num_attention_heads

int Number of attention heads in each self-attention layer. Allowed values: 0 <= x

intermediate_size

int Size of the intermediate (projection) layer within each transformer encoder layer. Allowed values: 1 <= x

hidden_act

string Name of the activation function. Allowed values: 'GELU', 'relu', 'silu', 'gelu_new'

hidden_dropout_prob

double Dropout ratio for the hidden layers. Allowed values: 0 <= x <= 0.6

attention_probs_dropout_prob

double Ratio of dropout for attention probabilities. Allowed values: 0 <= x <= 0.6
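The allowed ranges above can be checked before calling configure(). The following is a minimal sketch in base R; `check_mpnet_config()` is a hypothetical helper and not part of the package API. The divisibility check between `hidden_size` and `num_attention_heads` is an assumption based on how standard multi-head attention splits the hidden dimension across heads; it is not stated on this page.

```r
# Hypothetical helper (not part of the package): returns a character vector
# of problems; an empty vector means all checks passed.
check_mpnet_config <- function(max_position_embeddings = 512L,
                               hidden_size = 768L,
                               num_hidden_layers = 12L,
                               num_attention_heads = 12L,
                               intermediate_size = 3072L,
                               hidden_act = "GELU",
                               hidden_dropout_prob = 0.1,
                               attention_probs_dropout_prob = 0.1) {
  problems <- character(0)
  in_range <- function(x, lo, hi = Inf) x >= lo && x <= hi

  # Ranges copied from the argument documentation above
  if (!in_range(max_position_embeddings, 10, 4048))
    problems <- c(problems, "max_position_embeddings must be in [10, 4048]")
  if (!in_range(hidden_size, 1, 2048))
    problems <- c(problems, "hidden_size must be in [1, 2048]")
  if (!in_range(num_hidden_layers, 1))
    problems <- c(problems, "num_hidden_layers must be >= 1")
  if (!in_range(num_attention_heads, 0))
    problems <- c(problems, "num_attention_heads must be >= 0")
  if (!in_range(intermediate_size, 1))
    problems <- c(problems, "intermediate_size must be >= 1")
  if (!hidden_act %in% c("GELU", "relu", "silu", "gelu_new"))
    problems <- c(problems, "hidden_act must be 'GELU', 'relu', 'silu' or 'gelu_new'")
  if (!in_range(hidden_dropout_prob, 0, 0.6))
    problems <- c(problems, "hidden_dropout_prob must be in [0, 0.6]")
  if (!in_range(attention_probs_dropout_prob, 0, 0.6))
    problems <- c(problems, "attention_probs_dropout_prob must be in [0, 0.6]")

  # Assumption: each head gets hidden_size / num_attention_heads dimensions
  if (num_attention_heads > 0 && hidden_size %% num_attention_heads != 0)
    problems <- c(problems, "hidden_size should be divisible by num_attention_heads")

  problems
}

check_mpnet_config()                                     # character(0)
check_mpnet_config(hidden_size = 100L, num_attention_heads = 12L)
```

With the defaults, each of the 12 heads works on 768 / 12 = 64 dimensions.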

Returns

Returns nothing.


Method train()

Trains a BaseModel.

Usage

BaseModelMPNet$train(
  text_dataset,
  p_mask = 0.15,
  p_perm = 0.15,
  whole_word = TRUE,
  val_size = 0.1,
  n_epoch = 1L,
  batch_size = 12L,
  max_sequence_length = 250L,
  full_sequences_only = FALSE,
  min_seq_len = 50L,
  learning_rate = 0.003,
  sustain_track = FALSE,
  sustain_iso_code = NULL,
  sustain_region = NULL,
  sustain_interval = 15L,
  sustain_log_level = "warning",
  trace = TRUE,
  pytorch_trace = 1L,
  log_dir = NULL,
  log_write_interval = 2L
)
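How `val_size`, `batch_size` and `n_epoch` interact can be illustrated with a little arithmetic. The sketch below assumes that a fraction `val_size` of the sequences is held out for validation (rounded down) and that the remaining sequences are split into batches of `batch_size`; the exact splitting and rounding used by train() are not documented here, and `plan_training()` is a hypothetical helper.

```r
# Hypothetical helper (not part of the package): rough training plan under
# the assumption that val_size sequences are held out and the rest batched.
plan_training <- function(n_sequences, val_size = 0.1,
                          batch_size = 12L, n_epoch = 1L) {
  n_val <- floor(n_sequences * val_size)        # rounding down is an assumption
  n_train <- n_sequences - n_val
  steps_per_epoch <- ceiling(n_train / batch_size)  # last batch may be partial
  list(n_train = n_train,
       n_val = n_val,
       steps_per_epoch = steps_per_epoch,
       total_steps = steps_per_epoch * n_epoch)
}

str(plan_training(1000))
```

With 1000 sequences and the defaults, this gives 900 training and 100 validation sequences, and 900 / 12 = 75 optimizer steps per epoch.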

Arguments

text_dataset
p_mask
p_perm
whole_word
val_size
n_epoch
batch_size
max_sequence_length
full_sequences_only
min_seq_len
learning_rate
sustain_track
sustain_iso_code
sustain_region
sustain_interval
sustain_log_level
trace
pytorch_trace
log_dir
log_write_interval
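The pre-training objective follows the MPNet paper cited above: the token positions of a sequence are randomly permuted, the last part of the permuted order is predicted, and mask tokens stand in at every predicted position so the model still sees the full sequence length. The sketch below only builds the training inputs and targets for one sequence; it does not model attention. It uses a single hypothetical fraction `p_pred` (15% in the paper), because how `p_mask` and `p_perm` map onto this scheme is not stated on this page. `mpnet_permute_mask()` and the `"<mask>"` token string are hypothetical.

```r
# Hypothetical helper (not part of the package): builds MPNet-style
# permuted-prediction data for one tokenized sequence.
mpnet_permute_mask <- function(tokens, p_pred = 0.15, mask_token = "<mask>") {
  n <- length(tokens)
  n_pred <- max(1L, round(n * p_pred))   # number of tokens to predict
  z <- sample.int(n)                      # random permutation of positions
  pred <- z[(n - n_pred + 1):n]           # last part of the permutation
  input <- tokens
  input[pred] <- mask_token               # masks keep the full length visible
  list(permutation = z,
       predicted_positions = pred,        # in prediction order
       targets = tokens[pred],
       input = input)
}

set.seed(1)
toks <- paste0("t", 1:20)
res <- mpnet_permute_mask(toks)
res$input
res$targets
```

During training, the targets would be predicted one after another in the permuted order, each conditioned on the unmasked tokens and the previously predicted ones. Refinements such as whole-word selection (`whole_word`) are omitted.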

Returns

Returns nothing.


Method clone()

The objects of this class are cloneable with this method.

Usage

BaseModelMPNet$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.