Function for training and fine-tuning a RoBERTa model
Source: R/transformer_roberta.R
This function can be used to train or fine-tune a transformer based on the RoBERTa architecture with the help of the python libraries 'transformers', 'datasets', and 'tokenizers'.
Usage
train_tune_roberta_model(
ml_framework = aifeducation_config$get_framework(),
output_dir,
model_dir_path,
raw_texts,
p_mask = 0.15,
val_size = 0.1,
n_epoch = 1,
batch_size = 12,
chunk_size = 250,
full_sequences_only = FALSE,
min_seq_len = 50,
learning_rate = 0.03,
n_workers = 1,
multi_process = FALSE,
sustain_track = TRUE,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15,
trace = TRUE,
keras_trace = 1,
pytorch_trace = 1,
pytorch_safetensors = TRUE
)
Arguments
- ml_framework: string. Framework to use for training and inference. ml_framework="tensorflow" for 'tensorflow' and ml_framework="pytorch" for 'pytorch'.
- output_dir: string. Path to the directory where the final model should be saved. If the directory does not exist, it will be created.
- model_dir_path: string. Path to the directory where the original model is stored.
- raw_texts: vector. Vector containing the raw texts for training.
- p_mask: double. Ratio determining the number of words/tokens used for masking.
- val_size: double. Ratio determining the share of token chunks used for validation.
- n_epoch: int. Number of epochs for training.
- batch_size: int. Size of the batches.
- chunk_size: int. Size of every chunk for training.
- full_sequences_only: bool. TRUE to use only chunks with a sequence length equal to chunk_size.
- min_seq_len: int. Only relevant if full_sequences_only=FALSE. Determines the minimal sequence length for inclusion in the training process.
- learning_rate: double. Learning rate for the Adam optimizer.
- n_workers: int. Number of workers. Only relevant if ml_framework="tensorflow".
- multi_process: bool. TRUE if multiple processes should be activated. Only relevant if ml_framework="tensorflow".
- sustain_track: bool. If TRUE, energy consumption is tracked during training via the python library 'codecarbon'.
- sustain_iso_code: string. ISO code (Alpha-3-Code) of the country. This variable must be set if sustainability should be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes.
- sustain_region: string. Region within a country. Only available for the USA and Canada. See the documentation of codecarbon for more information: https://mlco2.github.io/codecarbon/parameters.html.
- sustain_interval: integer. Interval in seconds for measuring power usage.
- trace: bool. TRUE if information on the progress should be printed to the console.
- keras_trace: int. keras_trace=0 prints no information about the training process from keras to the console, keras_trace=1 prints a progress bar, and keras_trace=2 prints one line of information for every epoch. Only relevant if ml_framework="tensorflow".
- pytorch_trace: int. pytorch_trace=0 prints no information about the training process from pytorch to the console; pytorch_trace=1 prints a progress bar.
- pytorch_safetensors: bool. If TRUE, a 'pytorch' model is saved in safetensors format. If FALSE, or if 'safetensors' is not available, the model is saved in the standard pytorch format (.bin). Only relevant for pytorch models.
Value
This function does not return an object. Instead, the trained or fine-tuned model is saved to disk.
Note
Pre-trained models that can be fine-tuned with this function are available at https://huggingface.co/. New models can be created via the function create_roberta_model.
Training of this model makes use of dynamic masking, i.e. the masking pattern is generated anew every time a sequence is passed to the model instead of being fixed once during preprocessing (see the sketch below).
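The following R sketch, which is not part of the package API, illustrates the idea behind dynamic masking; the helper mask_tokens(), the token ids, and the mask_id value are hypothetical.

# Illustrative sketch only: dynamic masking re-samples the masked
# positions on every call, so the same chunk receives a different
# mask in every epoch.
mask_tokens <- function(token_ids, p_mask = 0.15, mask_id = 4L) {
  labels <- token_ids
  # Draw new random positions each time the function is called.
  positions <- which(runif(length(token_ids)) < p_mask)
  token_ids[positions] <- mask_id
  list(input_ids = token_ids, labels = labels)
}

chunk <- c(12L, 87L, 5L, 301L, 44L, 9L, 250L, 17L)
mask_tokens(chunk) # different masked positions on every call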
References
Liu, Y., Ott, M., Goyal, N., Du, J., Joshi, M., Chen, D., Levy, O., Lewis, M., Zettlemoyer, L., & Stoyanov, V. (2019). RoBERTa: A Robustly Optimized BERT Pretraining Approach. doi:10.48550/arXiv.1907.11692
Hugging Face documentation: https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaConfig
See also
Other Transformer: create_bert_model(), create_deberta_v2_model(), create_funnel_model(), create_longformer_model(), create_roberta_model(), train_tune_bert_model(), train_tune_deberta_v2_model(), train_tune_funnel_model(), train_tune_longformer_model()
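Examples
A minimal sketch of a typical call, assuming a base model was previously created with create_roberta_model; all paths and texts are hypothetical placeholders, and the call is wrapped in if (FALSE) because it requires a working python environment.
if (FALSE) {
example_texts <- c(
  "This is the first training document.",
  "This is the second training document."
)
train_tune_roberta_model(
  ml_framework = "pytorch",
  output_dir = "roberta_fine_tuned",     # hypothetical target directory
  model_dir_path = "roberta_base_model", # hypothetical directory of the original model
  raw_texts = example_texts,
  p_mask = 0.15,
  val_size = 0.1,
  n_epoch = 2,
  batch_size = 12,
  chunk_size = 250,
  learning_rate = 0.03,
  sustain_track = TRUE,
  sustain_iso_code = "DEU",
  trace = TRUE
)
}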