Base class for regular classifiers relying on EmbeddedText or LargeDataSetForTextEmbeddings as input
Source: R/obj_TEClassifiersBasedOnRegular.R
TEClassifiersBasedOnRegular.Rd
Abstract class for all regular classifiers that use numerical representations of texts instead of words.
Objects of this class contain fields and methods used in several other classes in 'AI for Education'.
This class is not designed for direct application and should only be used by developers.
See also
Other R6 Classes for Developers:
AIFEBaseModel,
AIFEMaster,
BaseModelCore,
ClassifiersBasedOnTextEmbeddings,
DataManagerClassifier,
LargeDataSetBase,
ModelsBasedOnTextEmbeddings,
TEClassifiersBasedOnProtoNet,
TokenizerBase
Super classes
aifeducation::AIFEMaster -> aifeducation::AIFEBaseModel -> aifeducation::ModelsBasedOnTextEmbeddings -> aifeducation::ClassifiersBasedOnTextEmbeddings -> TEClassifiersBasedOnRegular
Methods
Inherited methods
aifeducation::AIFEMaster$get_all_fields()
aifeducation::AIFEMaster$get_documentation_license()
aifeducation::AIFEMaster$get_ml_framework()
aifeducation::AIFEMaster$get_model_config()
aifeducation::AIFEMaster$get_model_description()
aifeducation::AIFEMaster$get_model_info()
aifeducation::AIFEMaster$get_model_license()
aifeducation::AIFEMaster$get_package_versions()
aifeducation::AIFEMaster$get_private()
aifeducation::AIFEMaster$get_publication_info()
aifeducation::AIFEMaster$get_sustainability_data()
aifeducation::AIFEMaster$is_configured()
aifeducation::AIFEMaster$is_trained()
aifeducation::AIFEMaster$set_documentation_license()
aifeducation::AIFEMaster$set_model_description()
aifeducation::AIFEMaster$set_model_license()
aifeducation::AIFEMaster$set_publication_info()
aifeducation::AIFEBaseModel$count_parameter()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model_name()
aifeducation::ClassifiersBasedOnTextEmbeddings$adjust_target_levels()
aifeducation::ClassifiersBasedOnTextEmbeddings$check_embedding_model()
aifeducation::ClassifiersBasedOnTextEmbeddings$check_feature_extractor_object_type()
aifeducation::ClassifiersBasedOnTextEmbeddings$load_from_disk()
aifeducation::ClassifiersBasedOnTextEmbeddings$plot_coding_stream()
aifeducation::ClassifiersBasedOnTextEmbeddings$plot_training_history()
aifeducation::ClassifiersBasedOnTextEmbeddings$predict()
aifeducation::ClassifiersBasedOnTextEmbeddings$requires_compression()
aifeducation::ClassifiersBasedOnTextEmbeddings$save()
Method train()
Method for training a neural net.
Training includes a routine for early stopping: training stops if loss < 0.0001, Accuracy = 1.00, and Average Iota = 1.00. The history uses the values of the last trained epoch for the remaining epochs.
After training, the model with the best values for Average Iota, Accuracy, and Loss on the validation data set is used as the final model.
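The stopping criterion described above can be sketched as a simple check (a minimal illustration; the function and argument names are hypothetical and not part of the package API):

```r
# Hypothetical sketch of the early-stopping check described above.
# Training stops once the loss is practically zero and both
# accuracy and Average Iota are perfect.
should_stop_early <- function(loss, accuracy, avg_iota) {
  loss < 0.0001 && accuracy == 1.00 && avg_iota == 1.00
}

should_stop_early(loss = 0.00005, accuracy = 1.0, avg_iota = 1.0)  # TRUE
should_stop_early(loss = 0.2, accuracy = 0.95, avg_iota = 0.9)     # FALSE
```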
Usage
TEClassifiersBasedOnRegular$train(
data_embeddings = NULL,
data_targets = NULL,
data_folds = 5L,
data_val_size = 0.25,
loss_balance_class_weights = TRUE,
loss_balance_sequence_length = TRUE,
loss_cls_fct_name = "FocalLoss",
use_sc = FALSE,
sc_method = "knnor",
sc_min_k = 1L,
sc_max_k = 10L,
use_pl = FALSE,
pl_max_steps = 3L,
pl_max = 1,
pl_anchor = 1,
pl_min = 0,
sustain_track = TRUE,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15L,
sustain_log_level = "warning",
epochs = 40L,
batch_size = 32L,
trace = TRUE,
ml_trace = 1L,
log_dir = NULL,
log_write_interval = 10L,
n_cores = auto_n_cores(),
lr_rate = 0.001,
lr_warm_up_ratio = 0.02,
optimizer = "AdamW"
)
Arguments
data_embeddings
EmbeddedText, LargeDataSetForTextEmbeddings. Object of class EmbeddedText or LargeDataSetForTextEmbeddings.

data_targets
factor containing the labels for the cases stored in the embeddings. The factor must be named and has to use the same names as used in the embeddings.

data_folds
int determining the number of cross-fold samples. Allowed values: 1 <= x

data_val_size
double between 0 and 1, indicating the proportion of cases which should be used for the validation sample during the estimation of the model. The remaining cases are part of the training data. Allowed values: 0 < x < 1

loss_balance_class_weights
bool. If TRUE, class weights are generated based on the frequencies of the training data with the method Inverse Class Frequency. If FALSE, each class has the weight 1.

loss_balance_sequence_length
bool. If TRUE, sample weights are generated for the length of sequences based on the frequencies of the training data with the method Inverse Class Frequency. If FALSE, each sequence length has the weight 1.

loss_cls_fct_name
string. Name of the loss function to use during training. Allowed values: 'FocalLoss', 'CrossEntropyLoss'

use_sc
bool. TRUE if the estimation should integrate synthetic cases. FALSE if not.

sc_method
string containing the method for generating synthetic cases. Allowed values: 'knnor'

sc_min_k
int determining the minimal number of k which is used for creating synthetic units. Allowed values: 1 <= x

sc_max_k
int determining the maximal number of k which is used for creating synthetic units. Allowed values: 1 <= x

use_pl
bool. TRUE if the estimation should integrate pseudo-labeling. FALSE if not.

pl_max_steps
int determining the maximum number of steps during pseudo-labeling. Allowed values: 1 <= x

pl_max
double setting the maximal level of confidence for considering a case for pseudo-labeling. Allowed values: 0 < x <= 1

pl_anchor
double indicating the reference point for sorting the new cases of every label. Allowed values: 0 <= x <= 1

pl_min
double setting the minimal level of confidence for considering a case for pseudo-labeling. Allowed values: 0 <= x < 1

sustain_track
bool. If TRUE, energy consumption is tracked during training via the python library 'codecarbon'.

sustain_iso_code
string. ISO code (Alpha-3-Code) for the country. This variable must be set if sustainability should be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes. Allowed values: any

sustain_region
string. Region within a country. Only available for USA and Canada. See the documentation of codecarbon for more information: https://mlco2.github.io/codecarbon/parameters.html. Allowed values: any

sustain_interval
int. Interval in seconds for measuring power usage. Allowed values: 1 <= x

sustain_log_level
string (default "warning").

epochs
int. Number of training epochs. Allowed values: 1 <= x

batch_size
int. Size of the batches for training. Allowed values: 1 <= x

trace
bool. TRUE if information about the estimation phase should be printed to the console.

ml_trace
int. ml_trace = 0 does not print any information about the training process from pytorch on the console. Allowed values: 0 <= x <= 1

log_dir
string. Path to the directory where the log files should be saved. If no logging is desired, set this argument to NULL. Allowed values: any

log_write_interval
int. Time in seconds determining the interval in which the logger should try to update the log files. Only relevant if log_dir is not NULL. Allowed values: 1 <= x

n_cores
int. Number of cores which should be used during the calculation of synthetic cases. Only relevant if use_sc = TRUE. Allowed values: 1 <= x

lr_rate
double. Initial learning rate for the training. Allowed values: 0 < x <= 1

lr_warm_up_ratio
double. Ratio of the training epochs used for warm up. Allowed values: 0 < x < 0.5

optimizer
string determining the optimizer used for training. Allowed values: 'Adam', 'RMSprop', 'AdamW', 'SGD'
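The arguments above can be combined in a call like the following sketch. Since this class is abstract, `classifier` stands for an object of a concrete subclass; `embeddings` and `targets` are assumed to exist and to match by name:

```r
# Sketch only: 'classifier', 'embeddings', and 'targets' are placeholders,
# not objects provided by the package.
classifier$train(
  data_embeddings  = embeddings,  # EmbeddedText or LargeDataSetForTextEmbeddings
  data_targets     = targets,     # named factor; names match the embeddings
  data_folds       = 5L,
  epochs           = 40L,
  batch_size       = 32L,
  sustain_track    = TRUE,
  sustain_iso_code = "DEU",       # required when sustain_track = TRUE
  trace            = TRUE
)
```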
Details
sc_max_k: All values from sc_min_k up to sc_max_k are successively used. If the value of sc_max_k is too high, it is reduced to a number that allows the calculation of synthetic units.
pl_anchor: With the help of this value, the new cases are sorted. For this aim, the distance from the anchor is calculated and all cases are arranged in ascending order.
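The sorting by pl_anchor can be illustrated in plain R (a minimal sketch; the confidence values and case names are made up):

```r
# Illustration of the pl_anchor sorting described above: cases are
# ordered by their absolute distance from the anchor, ascending,
# so cases closest to the anchor come first.
confidences <- c(a = 0.95, b = 0.72, c = 0.88, d = 0.60)
pl_anchor <- 1

distances <- abs(confidences - pl_anchor)
sorted_cases <- names(sort(distances))
sorted_cases
# "a" "c" "b" "d"
```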