
Base class for classifiers that rely on numerical representations of texts instead of words and use the ProtoNet architecture and its corresponding training techniques.
Source: R/obj_TEClassifiersBasedOnProtoNet.R
Base class for classifiers that take EmbeddedText or LargeDataSetForTextEmbeddings as input and use the ProtoNet architecture and its corresponding training techniques.
Objects of this class contain fields and methods used in several other classes in 'AI for Education'.
This class is not designed for direct application and should only be used by developers.
See also
Other R6 Classes for Developers:
AIFEBaseModel, ClassifiersBasedOnTextEmbeddings, LargeDataSetBase, ModelsBasedOnTextEmbeddings, TEClassifiersBasedOnRegular
Super classes
aifeducation::AIFEBaseModel -> aifeducation::ModelsBasedOnTextEmbeddings -> aifeducation::ClassifiersBasedOnTextEmbeddings -> TEClassifiersBasedOnProtoNet
Methods
Inherited methods
aifeducation::AIFEBaseModel$count_parameter()
aifeducation::AIFEBaseModel$get_all_fields()
aifeducation::AIFEBaseModel$get_documentation_license()
aifeducation::AIFEBaseModel$get_ml_framework()
aifeducation::AIFEBaseModel$get_model_description()
aifeducation::AIFEBaseModel$get_model_info()
aifeducation::AIFEBaseModel$get_model_license()
aifeducation::AIFEBaseModel$get_package_versions()
aifeducation::AIFEBaseModel$get_private()
aifeducation::AIFEBaseModel$get_publication_info()
aifeducation::AIFEBaseModel$get_sustainability_data()
aifeducation::AIFEBaseModel$is_configured()
aifeducation::AIFEBaseModel$is_trained()
aifeducation::AIFEBaseModel$load()
aifeducation::AIFEBaseModel$set_documentation_license()
aifeducation::AIFEBaseModel$set_model_description()
aifeducation::AIFEBaseModel$set_model_license()
aifeducation::AIFEBaseModel$set_publication_info()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model_name()
aifeducation::ClassifiersBasedOnTextEmbeddings$adjust_target_levels()
aifeducation::ClassifiersBasedOnTextEmbeddings$check_embedding_model()
aifeducation::ClassifiersBasedOnTextEmbeddings$check_feature_extractor_object_type()
aifeducation::ClassifiersBasedOnTextEmbeddings$load_from_disk()
aifeducation::ClassifiersBasedOnTextEmbeddings$plot_coding_stream()
aifeducation::ClassifiersBasedOnTextEmbeddings$plot_training_history()
aifeducation::ClassifiersBasedOnTextEmbeddings$predict()
aifeducation::ClassifiersBasedOnTextEmbeddings$requires_compression()
aifeducation::ClassifiersBasedOnTextEmbeddings$save()
Method train()
Method for training the neural net.
Training includes an early-stopping routine: if loss < 0.0001, Accuracy = 1.00, and Average Iota = 1.00, training stops. The history then repeats the values of the last trained epoch for the remaining epochs.
After training, the model with the best values for Average Iota, Accuracy, and Loss on the validation data set is used as the final model.
Usage
TEClassifiersBasedOnProtoNet$train(
data_embeddings = NULL,
data_targets = NULL,
data_folds = 5,
data_val_size = 0.25,
loss_pt_fct_name = "MultiWayContrastiveLoss",
use_sc = FALSE,
sc_method = "knnor",
sc_min_k = 1,
sc_max_k = 10,
use_pl = FALSE,
pl_max_steps = 3,
pl_max = 1,
pl_anchor = 1,
pl_min = 0,
sustain_track = TRUE,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15,
epochs = 40,
batch_size = 35,
Ns = 5,
Nq = 3,
loss_alpha = 0.5,
loss_margin = 0.05,
sampling_separate = FALSE,
sampling_shuffle = TRUE,
trace = TRUE,
ml_trace = 1,
log_dir = NULL,
log_write_interval = 10,
n_cores = auto_n_cores(),
lr_rate = 0.001,
lr_warm_up_ratio = 0.02,
optimizer = "AdamW"
)
Arguments
data_embeddings
EmbeddedText, LargeDataSetForTextEmbeddings
Object of class EmbeddedText or LargeDataSetForTextEmbeddings.
data_targets
factor
containing the labels for the cases stored in the embeddings. The factor must be named and has to use the same names as used in the embeddings.
data_folds
int
determining the number of cross-fold samples. Allowed values: 1 <= x
data_val_size
double
between 0 and 1, indicating the proportion of cases which should be used for the validation sample during the estimation of the model. The remaining cases are part of the training data. Allowed values: 0 < x < 1
loss_pt_fct_name
string
Name of the loss function to use during training. Allowed values: 'MultiWayContrastiveLoss'
use_sc
bool
TRUE if the estimation should integrate synthetic cases. FALSE if not.
sc_method
string
containing the method for generating synthetic cases. Allowed values: 'knnor'
sc_min_k
int
determining the minimal number of k which is used for creating synthetic units. Allowed values: 1 <= x
sc_max_k
int
determining the maximal number of k which is used for creating synthetic units. Allowed values: 1 <= x
use_pl
bool
TRUE if the estimation should integrate pseudo-labeling. FALSE if not.
pl_max_steps
int
determining the maximum number of steps during pseudo-labeling. Allowed values: 1 <= x
pl_max
double
setting the maximal level of confidence for considering a case for pseudo-labeling. Allowed values: 0 < x <= 1
pl_anchor
double
indicating the reference point for sorting the new cases of every label. Allowed values: 0 <= x <= 1
pl_min
double
setting the minimal level of confidence for considering a case for pseudo-labeling. Allowed values: 0 <= x < 1
sustain_track
bool
If TRUE, energy consumption is tracked during training via the python library 'codecarbon'.
sustain_iso_code
string
ISO code (Alpha-3-Code) for the country. This variable must be set if sustainability should be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes. Allowed values: any
sustain_region
string
Region within a country. Only available for the USA and Canada. See the documentation of codecarbon for more information: https://mlco2.github.io/codecarbon/parameters.html. Allowed values: any
sustain_interval
int
Interval in seconds for measuring power usage. Allowed values: 1 <= x
epochs
int
Number of training epochs. Allowed values: 1 <= x
batch_size
int
Size of the batches for training. Allowed values: 1 <= x
Ns
int
Number of cases for every class in the sample. Allowed values: 1 <= x
Nq
int
Number of cases for every class in the query. Allowed values: 1 <= x
loss_alpha
double
Value between 0 and 1 indicating how strongly the loss should focus on pulling cases to their corresponding prototypes versus pushing cases away from other prototypes. The higher the value, the more the loss concentrates on pulling cases to their corresponding prototypes. Allowed values: 0 <= x <= 1
loss_margin
double
Value greater than 0 indicating the minimal distance of every case from prototypes of other classes. Please note that, in contrast to the original work by Zhang et al. (2019), this implementation reaches better performance if the margin is a magnitude lower (e.g., 0.05 instead of 0.5). Allowed values: 0 <= x <= 1
sampling_separate
bool
If TRUE, the cases for every class are divided into a data set for sample and one for query. These are never mixed. If FALSE, sample and query cases are drawn from the same data pool. That is, a case can be part of the sample in one epoch and part of the query in another epoch. It is ensured that a case is never part of sample and query at the same time. In addition, it is ensured that every case exists only once during a training step.
sampling_shuffle
bool
If TRUE, cases are randomly drawn from the data during every step. If FALSE, the cases are not shuffled.
trace
bool
TRUE if information about the estimation phase should be printed to the console.
ml_trace
int
ml_trace = 0 does not print any information about the training process from pytorch to the console. Allowed values: 0 <= x <= 1
log_dir
string
Path to the directory where the log files should be saved. If no logging is desired, set this argument to NULL. Allowed values: any
log_write_interval
int
Time in seconds determining the interval in which the logger should try to update the log files. Only relevant if log_dir is not NULL. Allowed values: 1 <= x
n_cores
int
Number of cores which should be used during the calculation of synthetic cases. Only relevant if use_sc = TRUE. Allowed values: 1 <= x
lr_rate
double
Initial learning rate for the training. Allowed values: 0 < x <= 1
lr_warm_up_ratio
double
Proportion of the training epochs used for warm up. Allowed values: 0 < x < 0.5
optimizer
string
determining the optimizer used for training. Allowed values: 'Adam', 'RMSprop', 'AdamW', 'SGD'
loss_balance_class_weights
bool
If TRUE, class weights are generated based on the frequencies of the training data with the method Inverse Class Frequency. If FALSE, each class has the weight 1.
loss_balance_sequence_length
bool
If TRUE, sample weights are generated for the length of sequences based on the frequencies of the training data with the method Inverse Class Frequency. If FALSE, each sequence length has the weight 1.
Details
sc_max_k: All values from sc_min_k up to sc_max_k are used successively. If sc_max_k is too high, the value is reduced to a number that allows the calculation of synthetic units.
pl_anchor: With the help of this value, the new cases are sorted. For this aim, the distance from the anchor is calculated and all cases are arranged in ascending order.
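A minimal usage sketch for training (hypothetical: `classifier` stands for an object of a concrete subclass of TEClassifiersBasedOnProtoNet, and `embeddings` and the named factor `targets` are assumed to exist):

```r
# Hypothetical sketch: training a ProtoNet-based classifier.
# `classifier`, `embeddings` (EmbeddedText or LargeDataSetForTextEmbeddings),
# and the named factor `targets` are illustrative names, not package exports.
classifier$train(
  data_embeddings = embeddings,
  data_targets = targets,
  data_folds = 5,
  epochs = 40,
  batch_size = 35,
  Ns = 5,                   # cases per class in the sample set
  Nq = 3,                   # cases per class in the query set
  loss_margin = 0.05,       # a magnitude lower than in Zhang et al. (2019)
  sustain_iso_code = "DEU"  # required because sustain_track = TRUE by default
)
```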
Method predict_with_samples()
Method for predicting the class of given data (query) based on provided examples (sample).
Usage
TEClassifiersBasedOnProtoNet$predict_with_samples(
newdata,
batch_size = 32,
ml_trace = 1,
embeddings_s = NULL,
classes_s = NULL
)
Arguments
newdata
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all cases which should be predicted. They form the query set.
batch_size
int
Batch size.
ml_trace
int
ml_trace = 0 does not print any information about the prediction process from pytorch to the console. Allowed values: 0 <= x <= 1
embeddings_s
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all reference examples. They form the sample set.
classes_s
Named factor containing the classes for every case within embeddings_s.
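A hedged usage sketch (all object names are illustrative):

```r
# Hypothetical sketch: predicting query cases from labeled sample cases.
# `classifier`, `query_embeddings`, `sample_embeddings`, and the named
# factor `sample_classes` are assumed to exist.
predictions <- classifier$predict_with_samples(
  newdata = query_embeddings,       # query set to classify
  embeddings_s = sample_embeddings, # reference examples (sample set)
  classes_s = sample_classes,       # labels for the sample set
  batch_size = 32
)
```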
Method embed()
Method for embedding documents. Please do not confuse this type of embeddings with the embeddings of texts created by an object of class TextEmbeddingModel. These embeddings embed documents according to their similarity to specific classes.
Usage
TEClassifiersBasedOnProtoNet$embed(
embeddings_q = NULL,
embeddings_s = NULL,
classes_s = NULL,
batch_size = 32,
ml_trace = 1
)
Arguments
embeddings_q
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all cases which should be embedded into the classification space.
embeddings_s
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all reference examples. They form the sample set. If set to NULL, the trained prototypes are used.
classes_s
Named factor containing the classes for every case within embeddings_s. If set to NULL, the trained prototypes are used.
batch_size
int
Batch size.
ml_trace
int
ml_trace = 0 does not print any information about the process from pytorch to the console. Allowed values: 0 <= x <= 1
Returns
Returns a list containing the following elements:
embeddings_q: embeddings for the cases (query sample).
distances_q: matrix containing the distance of every query case to every prototype.
embeddings_prototypes: embeddings of the prototypes which were learned during training. They represent the centers of the different classes.
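A hedged usage sketch (object names are illustrative). With embeddings_s and classes_s left as NULL, the trained prototypes are used:

```r
# Hypothetical sketch: embedding query cases into the classification space.
# `classifier` and `query_embeddings` are assumed to exist.
result <- classifier$embed(
  embeddings_q = query_embeddings,
  embeddings_s = NULL,  # NULL: fall back to the trained prototypes
  classes_s = NULL,
  batch_size = 32
)
result$embeddings_q          # embeddings of the query cases
result$distances_q           # distance of every query case to every prototype
result$embeddings_prototypes # learned class centers
```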
Method plot_embeddings()
Method for creating a plot to visualize embeddings and their corresponding centers (prototypes).
Usage
TEClassifiersBasedOnProtoNet$plot_embeddings(
embeddings_q,
classes_q = NULL,
embeddings_s = NULL,
classes_s = NULL,
batch_size = 12,
alpha = 0.5,
size_points = 3,
size_points_prototypes = 8,
inc_unlabeled = TRUE,
inc_margin = TRUE
)
Arguments
embeddings_q
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all cases which should be embedded into the classification space.
classes_q
Named factor containing the true classes for every case. Please note that the names must match the names/ids in embeddings_q.
embeddings_s
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all reference examples. They form the sample set. If set to NULL, the trained prototypes are used.
classes_s
Named factor containing the classes for every case within embeddings_s. If set to NULL, the trained prototypes are used.
batch_size
int
Batch size.
alpha
float
Value indicating how transparent the points should be (important if many points overlap). Does not apply to points representing prototypes.
size_points
int
Size of the points excluding the points for prototypes.
size_points_prototypes
int
Size of the points representing prototypes.
inc_unlabeled
bool
If TRUE, the plot includes unlabeled cases as data points.
inc_margin
bool
If TRUE, the plot includes the margin around every prototype. Adding the margin requires a trained model. If the model is not trained, this argument is treated as set to FALSE.
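A hedged usage sketch (object names are illustrative):

```r
# Hypothetical sketch: visualizing query embeddings and class prototypes.
# `classifier`, `query_embeddings`, and the named factor `query_classes`
# are assumed to exist.
classifier$plot_embeddings(
  embeddings_q = query_embeddings,
  classes_q = query_classes,
  alpha = 0.5,        # transparency for overlapping points
  inc_unlabeled = TRUE,
  inc_margin = TRUE   # margins are drawn only for a trained model
)
```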