This class contains the methods shared by all BaseModels.
See also
Other R6 Classes for Developers:
AIFEBaseModel,
AIFEMaster,
ClassifiersBasedOnTextEmbeddings,
DataManagerClassifier,
LargeDataSetBase,
ModelsBasedOnTextEmbeddings,
TEClassifiersBasedOnProtoNet,
TEClassifiersBasedOnRegular,
TokenizerBase
Super classes
aifeducation::AIFEMaster -> aifeducation::AIFEBaseModel -> BaseModelCore
Methods
Inherited methods
aifeducation::AIFEMaster$get_all_fields()
aifeducation::AIFEMaster$get_documentation_license()
aifeducation::AIFEMaster$get_ml_framework()
aifeducation::AIFEMaster$get_model_config()
aifeducation::AIFEMaster$get_model_description()
aifeducation::AIFEMaster$get_model_info()
aifeducation::AIFEMaster$get_model_license()
aifeducation::AIFEMaster$get_package_versions()
aifeducation::AIFEMaster$get_private()
aifeducation::AIFEMaster$get_publication_info()
aifeducation::AIFEMaster$get_sustainability_data()
aifeducation::AIFEMaster$is_configured()
aifeducation::AIFEMaster$is_trained()
aifeducation::AIFEMaster$set_documentation_license()
aifeducation::AIFEMaster$set_model_description()
aifeducation::AIFEMaster$set_model_license()
Method create_from_hf()
Creates a BaseModel from a pretrained model.
Method train()
Trains a BaseModel.
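The following is a conceptual sketch, not the package's implementation, of two data-preparation steps that several train() arguments in the Usage section appear to control: splitting tokenized text into chunks (max_sequence_length, full_sequences_only, min_seq_len) and selecting tokens for masked-language-model training (p_mask, whole_word). The helpers chunk_tokens() and select_mask_positions() and the word_ids input are hypothetical. The sketch assumes that full_sequences_only = TRUE keeps only chunks of exactly max_sequence_length tokens, that otherwise a shorter chunk is kept only if it has at least min_seq_len tokens, and that whole_word = TRUE masks all sub-word pieces of a selected word together.

```r
# Hypothetical helper (not part of aifeducation): split a vector of token ids
# into consecutive chunks of at most max_sequence_length tokens and filter
# them according to the assumed meaning of full_sequences_only / min_seq_len.
chunk_tokens <- function(tokens, max_sequence_length = 250L,
                         full_sequences_only = FALSE, min_seq_len = 50L) {
  if (length(tokens) == 0L) {
    return(list())
  }
  starts <- seq(1L, length(tokens), by = max_sequence_length)
  chunks <- lapply(starts, function(s) {
    tokens[s:min(s + max_sequence_length - 1L, length(tokens))]
  })
  lens <- lengths(chunks)
  keep <- if (full_sequences_only) {
    lens == max_sequence_length  # only complete chunks
  } else {
    lens >= min_seq_len          # drop chunks that are too short
  }
  chunks[keep]
}

# Hypothetical helper (not part of aifeducation): choose token positions to
# mask. word_ids maps each token to the word it belongs to, e.g.
# c(1, 1, 2, 3, 3) for five sub-word tokens forming three words.
select_mask_positions <- function(word_ids, p_mask = 0.15, whole_word = TRUE) {
  if (whole_word) {
    # Select about p_mask of the words, then mask every piece of them.
    words <- unique(word_ids)
    n_words <- max(1L, round(p_mask * length(words)))
    chosen <- words[sample.int(length(words), n_words)]
    which(word_ids %in% chosen)
  } else {
    # Select about p_mask of the individual tokens.
    n_tok <- max(1L, round(p_mask * length(word_ids)))
    sort(sample.int(length(word_ids), n_tok))
  }
}
```

For example, a 620-token text with the default chunk settings yields chunks of 250, 250 and 120 tokens; with full_sequences_only = TRUE, only the two 250-token chunks remain.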
Usage
BaseModelCore$train(
text_dataset,
p_mask = 0.15,
whole_word = TRUE,
val_size = 0.1,
n_epoch = 1L,
batch_size = 12L,
max_sequence_length = 250L,
full_sequences_only = FALSE,
min_seq_len = 50L,
learning_rate = 0.003,
sustain_track = FALSE,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15L,
sustain_log_level = "warning",
trace = TRUE,
pytorch_trace = 1L,
log_dir = NULL,
log_write_interval = 2L
)
Method plot_training_history()
Method for requesting a plot of the training history. This method requires the R package 'ggplot2' to work.
Method save()
Method for saving a model to disk.
Method load_from_disk()
Loads an object from disk and updates the object to the current version of the package.
Method set_publication_info()
Method for setting the bibliographic information of the model.
Method estimate_sustainability_inference_fill_mask()
Estimates the energy consumption of inference for the fill-mask task.