This class contains all methods shared by all BaseModels.
See also
Other R6 Classes for Developers:
AIFEBaseModel,
AIFEMaster,
ClassifiersBasedOnTextEmbeddings,
DataManagerClassifier,
LargeDataSetBase,
ModelsBasedOnTextEmbeddings,
TEClassifiersBasedOnProtoNet,
TEClassifiersBasedOnRegular,
TokenizerBase
Super classes
aifeducation::AIFEMaster -> aifeducation::AIFEBaseModel -> BaseModelCore
Methods
Inherited methods
aifeducation::AIFEMaster$get_all_fields()
aifeducation::AIFEMaster$get_documentation_license()
aifeducation::AIFEMaster$get_ml_framework()
aifeducation::AIFEMaster$get_model_config()
aifeducation::AIFEMaster$get_model_description()
aifeducation::AIFEMaster$get_model_info()
aifeducation::AIFEMaster$get_model_license()
aifeducation::AIFEMaster$get_package_versions()
aifeducation::AIFEMaster$get_private()
aifeducation::AIFEMaster$get_publication_info()
aifeducation::AIFEMaster$get_sustainability_data()
aifeducation::AIFEMaster$is_configured()
aifeducation::AIFEMaster$is_trained()
aifeducation::AIFEMaster$set_documentation_license()
aifeducation::AIFEMaster$set_model_description()
aifeducation::AIFEMaster$set_model_license()
Method create_from_hf()
Creates a BaseModel from a pretrained model.
Method train()
Trains a BaseModel.
Usage
BaseModelCore$train(
text_dataset,
p_mask = 0.15,
whole_word = TRUE,
val_size = 0.1,
n_epoch = 1L,
batch_size = 12L,
max_sequence_length = 250L,
full_sequences_only = FALSE,
min_seq_len = 50L,
learning_rate = 0.003,
sustain_track = FALSE,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15L,
sustain_log_level = "warning",
trace = TRUE,
pytorch_trace = 1L,
log_dir = NULL,
log_write_interval = 2L
)
Arguments
text_dataset (LargeDataSetForText): LargeDataSetForText object storing textual data.
p_mask (double): Ratio that determines the number of tokens used for masking. Allowed values: \(0.05 <= x <= 0.6\)
whole_word (bool): TRUE: whole word masking is applied. Only relevant if a WordPieceTokenizer is used. FALSE: token masking is used.
val_size (double): Value between 0 and 1 indicating the proportion of cases used for the validation sample during the estimation of the model. The remaining cases are part of the training data. Allowed values: \(0 < x < 1\)
n_epoch (int): Number of training epochs. Allowed values: \(1 <= x\)
batch_size (int): Size of the batches for training. Allowed values: \(1 <= x\)
max_sequence_length (int): Maximal number of tokens for every sequence. Allowed values: \(20 <= x\)
full_sequences_only (bool): TRUE for using only chunks with a sequence length equal to chunk_size.
min_seq_len (int): Only relevant if full_sequences_only = FALSE. Determines the minimal sequence length included in the training process. Allowed values: \(10 <= x\)
learning_rate (double): Initial learning rate for the training. Allowed values: \(0 < x <= 1\)
sustain_track (bool): If TRUE, energy consumption is tracked during training via the Python library 'codecarbon'.
sustain_iso_code (string): ISO code (Alpha-3 code) for the country. This variable must be set if sustainability should be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes. Allowed values: any
sustain_region (string): Region within a country. Only available for the USA and Canada. See the documentation of codecarbon for more information: https://mlco2.github.io/codecarbon/parameters.html. Allowed values: any
sustain_interval (int): Interval in seconds for measuring power usage. Allowed values: \(1 <= x\)
sustain_log_level (string): Level for printing information to the console. Allowed values: 'debug', 'info', 'warning', 'error', 'critical'
trace (bool): TRUE if information about the estimation phase should be printed to the console.
pytorch_trace (int): pytorch_trace = 0 does not print any information about the training process from PyTorch on the console. Allowed values: \(0 <= x <= 1\)
log_dir (string): Path to the directory where the log files should be saved. If no logging is desired, set this argument to NULL. Allowed values: any
log_write_interval (int): Time in seconds determining the interval in which the logger should try to update the log files. Only relevant if log_dir is not NULL. Allowed values: \(1 <= x\)
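The allowed values for train() can be checked before a potentially long training run starts. The following base R sketch defines a hypothetical helper, check_train_args(), which is not part of aifeducation. It only mirrors the documented ranges and the dependencies between arguments: min_seq_len matters only when full_sequences_only = FALSE, sustain_iso_code is required when sustain_track = TRUE, and log_write_interval matters only when log_dir is set.

```r
# Hypothetical helper, NOT part of aifeducation: validates arguments for
# BaseModelCore$train() against the documented "Allowed values".
# Returns a character vector of problems (empty if everything is valid).
check_train_args <- function(p_mask = 0.15,
                             val_size = 0.1,
                             n_epoch = 1L,
                             batch_size = 12L,
                             max_sequence_length = 250L,
                             full_sequences_only = FALSE,
                             min_seq_len = 50L,
                             learning_rate = 0.003,
                             sustain_track = FALSE,
                             sustain_iso_code = NULL,
                             sustain_interval = 15L,
                             sustain_log_level = "warning",
                             pytorch_trace = 1L,
                             log_dir = NULL,
                             log_write_interval = 2L) {
  problems <- character(0)
  add <- function(msg) problems <<- c(problems, msg)

  if (p_mask < 0.05 || p_mask > 0.6) add("p_mask must satisfy 0.05 <= x <= 0.6")
  if (val_size <= 0 || val_size >= 1) add("val_size must satisfy 0 < x < 1")
  if (n_epoch < 1) add("n_epoch must be >= 1")
  if (batch_size < 1) add("batch_size must be >= 1")
  if (max_sequence_length < 20) add("max_sequence_length must be >= 20")
  # min_seq_len is only used when shorter (partial) sequences are allowed
  if (!full_sequences_only && min_seq_len < 10) add("min_seq_len must be >= 10")
  if (learning_rate <= 0 || learning_rate > 1) add("learning_rate must satisfy 0 < x <= 1")
  # tracking energy consumption requires a country code
  if (sustain_track && is.null(sustain_iso_code)) {
    add("sustain_iso_code must be set when sustain_track = TRUE")
  }
  if (sustain_interval < 1) add("sustain_interval must be >= 1")
  if (!sustain_log_level %in% c("debug", "info", "warning", "error", "critical")) {
    add("sustain_log_level must be one of debug, info, warning, error, critical")
  }
  if (!pytorch_trace %in% c(0, 1)) add("pytorch_trace must be 0 or 1")
  # log_write_interval is only used when logging is enabled
  if (!is.null(log_dir) && log_write_interval < 1) add("log_write_interval must be >= 1")
  problems
}
```

With the default values, check_train_args() returns character(0). check_train_args(p_mask = 0.7, sustain_track = TRUE) returns two messages: one for the out-of-range p_mask and one for the missing sustain_iso_code.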
Method plot_training_history()
Method for requesting a plot of the training history. This method requires the R package 'ggplot2' to work.
Usage
BaseModelCore$plot_training_history(
x_min = NULL,
x_max = NULL,
y_min = NULL,
y_max = NULL,
ind_best_model = TRUE,
text_size = 10L
)
Arguments
x_min (int): Minimal value for the x-axis. Set to NULL for an automatic adjustment. Allowed values: any
x_max (int): Maximal value for the x-axis. Set to NULL for an automatic adjustment. Allowed values: any
y_min (int): Minimal value for the y-axis. Set to NULL for an automatic adjustment. Allowed values: any
y_max (int): Maximal value for the y-axis. Set to NULL for an automatic adjustment. Allowed values: any
ind_best_model (bool): If TRUE, the plot indicates the best states of the model according to the chosen measure.
text_size (int): Size of text elements. Allowed values: \(1 <= x\)
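Two behaviours of plot_training_history() can be illustrated in base R. First, an axis limit left as NULL falls back to the range of the data. Second, the best state is the epoch with the best value of the chosen measure. The sketch below uses hypothetical helpers (resolve_limits() and best_epoch()) and a made-up toy history. It assumes the measure is a loss, so lower values are better. It is not the package's implementation.

```r
# Hypothetical helpers, NOT part of aifeducation.

# NULL limits are replaced by the corresponding end of the data range
resolve_limits <- function(values, lower = NULL, upper = NULL) {
  auto <- range(values, na.rm = TRUE)
  c(if (is.null(lower)) auto[1] else lower,
    if (is.null(upper)) auto[2] else upper)
}

# Epoch with the best value of a measure (assumed: lower is better for a loss)
best_epoch <- function(history, measure = "val_loss", minimize = TRUE) {
  values <- history[[measure]]
  idx <- if (minimize) which.min(values) else which.max(values)
  history$epoch[idx]
}

# Toy training history, for illustration only
toy_history <- data.frame(
  epoch = 1:5,
  loss = c(2.1, 1.6, 1.3, 1.1, 1.0),
  val_loss = c(2.2, 1.8, 1.5, 1.6, 1.7)
)

best_epoch(toy_history)                           # epoch with the lowest val_loss
resolve_limits(toy_history$epoch, upper = 3)      # lower limit taken from the data
```

In this toy history, the validation loss rises after epoch 3 while the training loss keeps falling. Epoch 3 is therefore the state that ind_best_model = TRUE would highlight under the lower-is-better assumption.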
Method fill_mask()
Method for predicting the tokens behind mask tokens.
Method save()
Method for saving a model on disk.
Method load_from_disk()
Loads an object from disk and updates the object to the current version of the package.
Method set_publication_info()
Method for setting the bibliographic information of the model.
Method estimate_sustainability_inference_fill_mask()
Estimates the energy consumption for inference on the fill-mask task.
Usage
BaseModelCore$estimate_sustainability_inference_fill_mask(
text_dataset = NULL,
n_samples = NULL,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15L,
sustain_log_level = "warning",
trace = TRUE
)
Arguments
text_dataset (LargeDataSetForText): LargeDataSetForText object storing textual data.
n_samples (int): Number of samples. Allowed values: \(1 <= x\)
sustain_iso_code (string): ISO code (Alpha-3 code) for the country. This variable must be set if sustainability should be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes. Allowed values: any
sustain_region (string): Region within a country. Only available for the USA and Canada. See the documentation of codecarbon for more information: https://mlco2.github.io/codecarbon/parameters.html. Allowed values: any
sustain_interval (int): Interval in seconds for measuring power usage. Allowed values: \(1 <= x\)
sustain_log_level (string): Level for printing information to the console. Allowed values: 'debug', 'info', 'warning', 'error', 'critical'
trace (bool): TRUE if information about the estimation phase should be printed to the console.
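sustain_interval sets how often power usage is measured. As a simplified model, assume each power reading stays constant over its interval. The energy is then the sum of power times interval length, and dividing by n_samples gives a per-sample figure. This base R sketch uses a hypothetical helper and toy power readings. codecarbon's actual accounting is more involved: it tracks CPU, GPU and RAM separately, and its emissions depend on sustain_iso_code and sustain_region. The sketch ignores all of that.

```r
# Hypothetical helper, NOT part of aifeducation or codecarbon.
# Assumes each power reading (in watts) holds for one full measurement interval.
energy_kwh <- function(power_w, interval_s) {
  joules <- sum(power_w * interval_s)  # energy = power x time
  joules / 3.6e6                       # 1 kWh = 3.6e6 J
}

# Toy readings taken every 15 seconds (sustain_interval = 15L)
readings <- c(100, 120, 110)
total <- energy_kwh(readings, interval_s = 15)
per_sample <- total / 100  # e.g. n_samples = 100
```

The three toy readings give 4950 J, or 0.001375 kWh in total, which is 1.375e-5 kWh per sample.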
Method calc_flops_architecture_based()
Calculates FLOPS based on the model's architecture.
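A widely used rule of thumb for transformer encoders gives a rough sense of scale. The attention and feed-forward blocks hold about 12 * n_layers * d^2 weights, where d is the hidden size and the feed-forward width is assumed to be 4 * d. A forward pass then costs about 2 floating-point operations per weight per token, and training costs about 6. The base R sketch below implements only this approximation. It is an assumption for illustration and not necessarily how calc_flops_architecture_based() counts FLOPS.

```r
# Rule-of-thumb estimate for a BERT-like transformer encoder.
# Assumptions (not taken from aifeducation): feed-forward width = 4 * hidden size,
# embeddings and attention-score terms ignored, 1 multiply-add = 2 operations.
approx_params <- function(n_layers, hidden_size) {
  # per layer: 4 * d^2 (Q, K, V and output projections) + 8 * d^2 (two FFN matrices)
  12 * n_layers * hidden_size^2
}

approx_flops <- function(n_layers, hidden_size, n_tokens, training = FALSE) {
  n_params <- approx_params(n_layers, hidden_size)
  # forward pass ~ 2 operations per weight per token;
  # training adds a backward pass of roughly twice that cost
  factor <- if (training) 6 else 2
  factor * n_params * n_tokens
}

approx_params(12, 768)
approx_flops(12, 768, n_tokens = 250)
```

For a BERT-base-like configuration (12 layers, hidden size 768), approx_params(12, 768) gives 84,934,656 weights. A forward pass over 250 tokens comes to about 4.25e10 floating-point operations.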