Function for training and fine-tuning a Funnel Transformer model
Source: R/transformer_funnel.R
train_tune_funnel_model.Rd
This function trains or fine-tunes a transformer based on the Funnel Transformer architecture with the help of the Python libraries 'transformers', 'datasets', and 'tokenizers'.
Usage
train_tune_funnel_model(
  ml_framework = aifeducation_config$get_framework(),
  output_dir,
  model_dir_path,
  raw_texts,
  p_mask = 0.15,
  whole_word = TRUE,
  val_size = 0.1,
  n_epoch = 1,
  batch_size = 12,
  chunk_size = 250,
  min_seq_len = 50,
  full_sequences_only = FALSE,
  learning_rate = 0.003,
  n_workers = 1,
  multi_process = FALSE,
  sustain_track = TRUE,
  sustain_iso_code = NULL,
  sustain_region = NULL,
  sustain_interval = 15,
  trace = TRUE,
  keras_trace = 1,
  pytorch_trace = 1,
  pytorch_safetensors = TRUE
)
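A minimal sketch of a call, based only on the signature above. The directory names and example texts are hypothetical placeholders, and the call assumes an installed version of 'aifeducation' that exports this function together with a working Python environment. The guard skips training when those assumptions do not hold.

```r
# Hypothetical inputs: replace with your own paths and texts.
model_dir  <- "models/funnel_base"       # placeholder; e.g. a model made with create_funnel_model()
output_dir <- "models/funnel_finetuned"  # placeholder; created if it does not exist
raw_texts <- c(
  "First example document used for masked language modelling.",
  "Second example document with domain-specific vocabulary."
)

# Train only if the package exports the function and the base model directory exists.
can_train <- requireNamespace("aifeducation", quietly = TRUE) &&
  "train_tune_funnel_model" %in% getNamespaceExports("aifeducation") &&
  dir.exists(model_dir)

if (can_train) {
  aifeducation::train_tune_funnel_model(
    ml_framework = "pytorch",
    output_dir = output_dir,
    model_dir_path = model_dir,
    raw_texts = raw_texts,
    p_mask = 0.15,
    whole_word = TRUE,
    n_epoch = 1,
    batch_size = 12,
    chunk_size = 250,
    sustain_track = TRUE,
    sustain_iso_code = "DEU",  # must be set when sustain_track = TRUE
    trace = TRUE
  )
}
```

Arguments not listed in the call keep the defaults shown in the usage block.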
Arguments
- ml_framework
string
Framework to use for training and inference. ml_framework="tensorflow" for 'tensorflow' and ml_framework="pytorch" for 'pytorch'.
- output_dir
string
Path to the directory where the final model should be saved. If the directory does not exist, it will be created.
- model_dir_path
string
Path to the directory where the original model is stored.
- raw_texts
vector
Vector containing the raw texts for training.
- p_mask
double
Ratio determining the number of words/tokens to mask.
- whole_word
bool
TRUE if whole word masking should be applied. If FALSE, token masking is used.
- val_size
double
Ratio determining the share of token chunks used for validation.
- n_epoch
int
Number of epochs for training.
- batch_size
int
Size of batches.
- chunk_size
int
Size of every chunk for training.
- min_seq_len
int
Only relevant if full_sequences_only=FALSE. Minimum sequence length required for inclusion in the training process.
- full_sequences_only
bool
TRUE if only token sequences with a length equal to chunk_size should be used for training.
- learning_rate
double
Learning rate for the Adam optimizer.
- n_workers
int
Number of workers.
- multi_process
bool
TRUE if multiple processes should be activated.
- sustain_track
bool
If TRUE, energy consumption is tracked during training via the Python library codecarbon.
- sustain_iso_code
string
ISO code (Alpha-3-Code) for the country. This variable must be set if sustainability should be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes.
- sustain_region
Region within a country. Only available for USA and Canada. See the documentation of codecarbon for more information: https://mlco2.github.io/codecarbon/parameters.html
- sustain_interval
integer
Interval in seconds for measuring power usage.
- trace
bool
TRUE if information on the progress should be printed to the console.
- keras_trace
int
keras_trace=0 does not print any information about the training process from keras on the console. keras_trace=1 prints a progress bar. keras_trace=2 prints one line of information for every epoch.
- pytorch_trace
int
pytorch_trace=0 does not print any information about the training process from pytorch on the console. pytorch_trace=1 prints a progress bar.
- pytorch_safetensors
bool
If TRUE, a 'pytorch' model is saved in safetensors format. If FALSE, or if 'safetensors' is not available, it is saved in the standard pytorch format (.bin). Only relevant for pytorch models.
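The interaction of chunk_size, min_seq_len, and full_sequences_only can be illustrated with a small base R sketch. The helper filter_chunks() is hypothetical and not part of 'aifeducation'. It assumes that token sequences are cut into consecutive chunks and that "minimal sequence length" means "greater than or equal to"; the package's internal logic may differ.

```r
# Hypothetical helper (not part of aifeducation): split a token vector into
# consecutive chunks and keep only those that satisfy the length rules.
filter_chunks <- function(tokens, chunk_size = 250, min_seq_len = 50,
                          full_sequences_only = FALSE) {
  # Assign every token to a chunk holding at most chunk_size tokens
  chunk_id <- ceiling(seq_along(tokens) / chunk_size)
  chunks <- unname(split(tokens, chunk_id))
  lens <- lengths(chunks)
  if (full_sequences_only) {
    keep <- lens == chunk_size   # only complete chunks
  } else {
    keep <- lens >= min_seq_len  # assumption: inclusive lower bound
  }
  chunks[keep]
}

tokens <- paste0("tok", seq_len(560))  # 560 dummy tokens

lengths(filter_chunks(tokens))                              # 250 250 60
lengths(filter_chunks(tokens, full_sequences_only = TRUE))  # 250 250
lengths(filter_chunks(tokens, min_seq_len = 100))           # 250 250
```

With the defaults, the trailing chunk of 60 tokens is kept because it is at least min_seq_len long. It is dropped once full_sequences_only = TRUE or min_seq_len exceeds 60.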
Value
This function does not return an object. Instead, the trained or fine-tuned model is saved to disk in output_dir.
Note
If aug_vocab_by > 0, the raw text is used for training a WordPiece tokenizer. At the end of this process, entries that are not part of the original vocabulary are added to it. This is in an experimental state.
Pre-trained models that can be fine-tuned with this function are available at https://huggingface.co/.
New models can be created via the function create_funnel_model.
Training of the model makes use of dynamic masking.
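Dynamic masking means that the masked positions are drawn anew each time a sequence is presented to the model, rather than being fixed once before training. The base R sketch below illustrates the idea at token level only. The helper mask_tokens() and the mask string are hypothetical placeholders; the actual masking is carried out by the Python libraries and may follow additional rules.

```r
# Hypothetical helper (not part of aifeducation): replace each token by a
# mask string with probability p_mask, independently of earlier draws.
mask_tokens <- function(tokens, p_mask = 0.15, mask_token = "<mask>") {
  masked <- runif(length(tokens)) < p_mask
  tokens[masked] <- mask_token  # placeholder mask string
  tokens
}

set.seed(42)  # only for a reproducible illustration
tokens <- c("the", "funnel", "transformer", "compresses", "hidden", "states")

# A fresh mask is sampled in every epoch, so the model sees
# different masked positions for the same text.
for (epoch in 1:3) {
  print(mask_tokens(tokens, p_mask = 0.3))
}
```

Whole word masking (whole_word = TRUE) differs in that all sub-word tokens belonging to one word are masked together; the sketch above does not cover that case.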
References
Dai, Z., Lai, G., Yang, Y. & Le, Q. V. (2020). Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing. doi:10.48550/arXiv.2006.03236
Hugging Face documentation https://huggingface.co/docs/transformers/model_doc/funnel#funnel-transformer
See also
Other Transformer: create_bert_model(), create_deberta_v2_model(), create_funnel_model(), create_longformer_model(), create_roberta_model(), train_tune_bert_model(), train_tune_deberta_v2_model(), train_tune_longformer_model(), train_tune_roberta_model()