Function for creating a new transformer based on Funnel Transformer
Source: R/transformer_funnel.R
create_funnel_model.Rd
This function creates a transformer configuration based on the Funnel Transformer base architecture and a vocabulary based on WordPiece, using the Python libraries 'transformers' and 'tokenizers'.
Usage
create_funnel_model(
ml_framework = aifeducation_config$get_framework(),
model_dir,
vocab_raw_texts = NULL,
vocab_size = 30522,
vocab_do_lower_case = FALSE,
max_position_embeddings = 512,
hidden_size = 768,
target_hidden_size = 64,
block_sizes = c(4, 4, 4),
num_attention_heads = 12,
intermediate_size = 3072,
num_decoder_layers = 2,
pooling_type = "mean",
hidden_act = "gelu",
hidden_dropout_prob = 0.1,
attention_probs_dropout_prob = 0.1,
activation_dropout = 0,
sustain_track = TRUE,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15,
trace = TRUE,
pytorch_safetensors = TRUE
)
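For illustration, a call might look like the following. This is a minimal sketch, not the package's official example: the directory path and raw texts are placeholders, and it assumes 'aifeducation' is installed together with a working Python environment providing 'transformers' and 'tokenizers'.

```r
# Illustrative sketch only. Assumes 'aifeducation' is installed and configured
# with a Python backend ('transformers', 'tokenizers'). Paths and texts are
# placeholders.
library(aifeducation)

# Placeholder raw texts used to build the WordPiece vocabulary
example_texts <- c(
  "This is the first document used to build the vocabulary.",
  "A second document with some additional words."
)

create_funnel_model(
  ml_framework = "pytorch",
  model_dir = "my_funnel_model",   # directory where config and vocab are saved
  vocab_raw_texts = example_texts,
  vocab_size = 30522,
  block_sizes = c(4, 4, 4),
  num_attention_heads = 12,
  target_hidden_size = 64,
  sustain_track = FALSE,           # energy tracking disabled for this sketch
  trace = TRUE
)
```

After the call, the model directory contains the configuration and vocabulary files rather than a trained model; training is a separate step (see the Note below on train_tune_funnel_model).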
Arguments
- ml_framework
string
Framework to use for training and inference. ml_framework="tensorflow" for 'tensorflow' and ml_framework="pytorch" for 'pytorch'.
- model_dir
string
Path to the directory where the model should be saved.
- vocab_raw_texts
vector
containing the raw texts for creating the vocabulary.
- vocab_size
int
Size of the vocabulary.
- vocab_do_lower_case
bool
TRUE if all words/tokens should be lower case.
- max_position_embeddings
int
Number of maximal position embeddings. This parameter also determines the maximum length of a sequence which can be processed with the model.
- hidden_size
int
Initial number of neurons in each layer.
- target_hidden_size
int
Number of neurons in the final layer. This parameter determines the dimensionality of the resulting text embedding.
- block_sizes
vector
of int determining the number and sizes of each block.
- num_attention_heads
int
Number of attention heads.
- intermediate_size
int
Number of neurons in the intermediate layer of the attention mechanism.
- num_decoder_layers
int
Number of decoding layers.
- pooling_type
string
"mean" for pooling with mean and "max" for pooling with maximum values.
- hidden_act
string
Name of the activation function.
- hidden_dropout_prob
double
Ratio of dropout.
- attention_probs_dropout_prob
double
Ratio of dropout for attention probabilities.
- activation_dropout
float
Dropout probability between the layers of the feed-forward blocks.
- sustain_track
bool
If TRUE, energy consumption is tracked during training via the Python library codecarbon.
- sustain_iso_code
string
ISO code (Alpha-3-Code) for the country. This variable must be set if sustainability should be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes.
- sustain_region
string
Region within a country. Only available for the USA and Canada. See the documentation of codecarbon for more information: https://mlco2.github.io/codecarbon/parameters.html
- sustain_interval
integer
Interval in seconds for measuring power usage.
- trace
bool
TRUE if information about the progress should be printed to the console.
- pytorch_safetensors
bool
If TRUE, a 'pytorch' model is saved in safetensors format. If FALSE, or if 'safetensors' is not available, it is saved in the standard pytorch format (.bin). Only relevant for pytorch models.
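As a rough illustration of how block_sizes and num_decoder_layers relate to overall model depth (using the default values from Usage; the arithmetic below reflects the Funnel Transformer encoder/decoder layout rather than any function in the package):

```r
# With the defaults above, the encoder consists of three blocks of four
# layers each; the decoder adds its own layers on top.
block_sizes <- c(4, 4, 4)
num_decoder_layers <- 2

encoder_layers <- sum(block_sizes)                    # 12 encoder layers
total_layers  <- encoder_layers + num_decoder_layers  # 14 layers overall
```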
Value
This function does not return an object. Instead, the configuration and the vocabulary of the new model are saved on disk.
Note
The model uses a configuration with truncate_seq=TRUE to avoid implementation problems with tensorflow.
To train the model, pass the directory of the model to the function train_tune_funnel_model.
The model is created with separate_cls=TRUE, truncate_seq=TRUE, and pool_q_only=TRUE.
This model uses a WordPiece tokenizer like BERT and can be trained with whole word masking. The transformers library may show a warning, which can be ignored.
References
Dai, Z., Lai, G., Yang, Y. & Le, Q. V. (2020). Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing. doi:10.48550/arXiv.2006.03236
Hugging Face documentation https://huggingface.co/docs/transformers/model_doc/funnel#funnel-transformer
See also
Other Transformer:
create_bert_model(), create_deberta_v2_model(), create_longformer_model(), create_roberta_model(), train_tune_bert_model(), train_tune_deberta_v2_model(), train_tune_funnel_model(), train_tune_longformer_model(), train_tune_roberta_model()