
Abstract class for all classifiers that use numerical representations of texts instead of words.
Source: R/obj_ClassifiersBasedOnTextEmbeddings.R
Base class for classifiers relying on EmbeddedText or LargeDataSetForTextEmbeddings generated with a TextEmbeddingModel.
Objects of this class contain fields and methods used in several other classes in 'AI for Education'.
This class is not designed for direct use and should only be used by developers.
See also
Other R6 Classes for Developers: AIFEBaseModel, LargeDataSetBase, ModelsBasedOnTextEmbeddings, TEClassifiersBasedOnProtoNet, TEClassifiersBasedOnRegular
Super classes
aifeducation::AIFEBaseModel -> aifeducation::ModelsBasedOnTextEmbeddings -> ClassifiersBasedOnTextEmbeddings
Public fields
feature_extractor
('list()')
List for storing information and objects about the feature_extractor.
reliability
('list()')
List for storing central reliability measures of the last training.
reliability$test_metric: Array containing the reliability measures for the test data for every fold and step (in case of pseudo-labeling).
reliability$test_metric_mean: Array containing the reliability measures for the test data. The values represent the mean values for every fold.
reliability$raw_iota_objects: List containing all iota_object generated with the package iotarelr for every fold at the end of the last training for the test data.
reliability$raw_iota_objects$iota_objects_end: List of objects with class iotarelr_iota2 containing the estimated iota reliability of the second generation for the final model for every fold for the test data.
reliability$raw_iota_objects$iota_objects_end_free: List of objects with class iotarelr_iota2 containing the estimated iota reliability of the second generation for the final model for every fold for the test data. Please note that the model is estimated without forcing the Assignment Error Matrix to be in line with the assumption of weak superiority.
reliability$iota_object_end: Object of class iotarelr_iota2 as a mean of the individual objects for every fold for the test data.
reliability$iota_object_end_free: Object of class iotarelr_iota2 as a mean of the individual objects for every fold. Please note that the model is estimated without forcing the Assignment Error Matrix to be in line with the assumption of weak superiority.
reliability$standard_measures_end: Object of class list containing the final measures for precision, recall, and f1 for every fold.
reliability$standard_measures_mean: matrix containing the mean measures for precision, recall, and f1.
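The relation between per-fold measures and their mean can be illustrated with a minimal base R sketch. The helper standard_measures() and the toy folds below are hypothetical and not part of the package; the sketch only assumes that precision, recall, and f1 are computed per class for every fold and then averaged across folds.

```r
# Hypothetical helper (not part of aifeducation): precision, recall, and f1
# per class, computed from true and predicted labels of one fold.
standard_measures <- function(true, pred) {
  lv <- levels(true)
  res <- matrix(
    NA_real_,
    nrow = length(lv), ncol = 3,
    dimnames = list(lv, c("precision", "recall", "f1"))
  )
  for (k in lv) {
    tp <- sum(pred == k & true == k) # true positives for class k
    fp <- sum(pred == k & true != k) # false positives for class k
    fn <- sum(pred != k & true == k) # false negatives for class k
    precision <- if (tp + fp > 0) tp / (tp + fp) else NA_real_
    recall <- if (tp + fn > 0) tp / (tp + fn) else NA_real_
    f1 <- if (is.na(precision) || is.na(recall) || precision + recall == 0) {
      NA_real_
    } else {
      2 * precision * recall / (precision + recall)
    }
    res[k, ] <- c(precision, recall, f1)
  }
  res
}

# Toy data for two folds (purely illustrative).
lv <- c("a", "b")
folds <- list(
  list(
    true = factor(c("a", "a", "b", "b"), levels = lv),
    pred = factor(c("a", "b", "b", "b"), levels = lv)
  ),
  list(
    true = factor(c("a", "a", "b", "b"), levels = lv),
    pred = factor(c("a", "a", "a", "b"), levels = lv)
  )
)

# One matrix per fold, then the element-wise mean across folds.
measures_per_fold <- lapply(folds, function(f) standard_measures(f$true, f$pred))
measures_mean <- Reduce(`+`, measures_per_fold) / length(measures_per_fold)
print(measures_mean)
```

In this sketch, measures_per_fold plays the role of a list with one entry per fold and measures_mean the role of a single matrix of averaged values.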
Methods
Inherited methods
aifeducation::AIFEBaseModel$count_parameter()
aifeducation::AIFEBaseModel$get_all_fields()
aifeducation::AIFEBaseModel$get_documentation_license()
aifeducation::AIFEBaseModel$get_ml_framework()
aifeducation::AIFEBaseModel$get_model_description()
aifeducation::AIFEBaseModel$get_model_info()
aifeducation::AIFEBaseModel$get_model_license()
aifeducation::AIFEBaseModel$get_package_versions()
aifeducation::AIFEBaseModel$get_private()
aifeducation::AIFEBaseModel$get_publication_info()
aifeducation::AIFEBaseModel$get_sustainability_data()
aifeducation::AIFEBaseModel$is_configured()
aifeducation::AIFEBaseModel$is_trained()
aifeducation::AIFEBaseModel$load()
aifeducation::AIFEBaseModel$set_documentation_license()
aifeducation::AIFEBaseModel$set_model_description()
aifeducation::AIFEBaseModel$set_model_license()
aifeducation::AIFEBaseModel$set_publication_info()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model_name()
Method predict()
Method for predicting new data with a trained neural net.
Arguments
newdata
Object of class EmbeddedText or LargeDataSetForTextEmbeddings for which predictions should be made. In addition, this method accepts objects of class array and datasets.arrow_dataset.Dataset. However, these should be used only by developers.
batch_size
int
Size of batches.
ml_trace
int
ml_trace=0 does not print any information on the process from the machine learning framework.
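To make the role of batch_size concrete, the following base R sketch splits case indices into consecutive batches. The helper make_batches() is hypothetical and only illustrates the general idea of batch-wise processing; how the package and the underlying machine learning framework actually form batches is not documented here.

```r
# Hypothetical helper (not part of aifeducation): split the indices of
# n_cases cases into consecutive batches of at most batch_size cases.
make_batches <- function(n_cases, batch_size) {
  idx <- seq_len(n_cases)
  unname(split(idx, ceiling(idx / batch_size)))
}

batches <- make_batches(n_cases = 10, batch_size = 4)
print(lengths(batches)) # number of cases in each batch
```

A larger batch_size means fewer passes through the model but more memory per pass; the last batch may be smaller than the others.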
Method check_embedding_model()
Method for checking if the provided text embeddings are created with the same TextEmbeddingModel as the classifier.
Usage
ClassifiersBasedOnTextEmbeddings$check_embedding_model(
text_embeddings,
require_compressed = FALSE
)
Arguments
text_embeddings
Object of class EmbeddedText or LargeDataSetForTextEmbeddings.
require_compressed
TRUE if a compressed version of the embeddings is necessary. Compressed embeddings are created by an object of class TEFeatureExtractor.
Returns
TRUE if the underlying TextEmbeddingModel is the same, FALSE if the models differ.
Method check_feature_extractor_object_type()
Method for checking an object of class TEFeatureExtractor.
Arguments
feature_extractor
Object of class TEFeatureExtractor
Method requires_compression()
Method for checking if provided text embeddings must be compressed via a TEFeatureExtractor before processing.
Arguments
text_embeddings
Object of class EmbeddedText, LargeDataSetForTextEmbeddings, array or datasets.arrow_dataset.Dataset.
Method save()
Method for saving a model.
Method load_from_disk()
Loads an object from disk and updates it to the current version of the package.
Method adjust_target_levels()
Method for transforming the levels of a factor into numbers corresponding to the model's definition.
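The idea of mapping factor levels to the numbers a model expects can be sketched in base R as follows. The helper levels_to_index() and its arguments are hypothetical and do not reproduce the package's implementation; in particular, the zero-based numbering is an assumption made here because Python-based machine learning frameworks usually index classes from 0.

```r
# Hypothetical helper (not part of aifeducation): map the levels of a factor
# to the integer codes defined by the order of model_levels.
levels_to_index <- function(x, model_levels, zero_based = TRUE) {
  # Re-create the factor with the level order of the model, so that the
  # integer codes do not depend on the level order of the input.
  x <- factor(as.character(x), levels = model_levels)
  idx <- as.integer(x)
  # Assumption: class indices start at 0 on the framework side.
  if (zero_based) idx - 1L else idx
}

# The input factor has alphabetical levels ("high", "low", "medium"),
# while the model is assumed to define the order low < medium < high.
labels <- factor(c("medium", "low", "high", "low"))
codes <- levels_to_index(labels, model_levels = c("low", "medium", "high"))
print(codes)
```

Values that do not occur in model_levels become NA in this sketch, which makes mismatches between the data and the model definition visible.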
Method plot_training_history()
Method for requesting a plot of the training history. This method requires the R package 'ggplot2' to work.
Usage
ClassifiersBasedOnTextEmbeddings$plot_training_history(
final_training = FALSE,
pl_step = NULL,
measure = "loss",
y_min = NULL,
y_max = NULL,
add_min_max = TRUE,
text_size = 10
)
Arguments
final_training
bool
If FALSE, the values of the performance estimation are used. If TRUE, only the epochs of the final training are used.
pl_step
int
Number of the step during pseudo labeling to plot. Only relevant if the model was trained with active pseudo labeling.
measure
string
Measure to plot. Allowed values:
"avg_iota" = Average Iota
"loss" = Loss
"accuracy" = Accuracy
"balanced_accuracy" = Balanced Accuracy
y_min
Minimal value for the y-axis. Set to NULL for an automatic adjustment.
y_max
Maximal value for the y-axis. Set to NULL for an automatic adjustment.
add_min_max
bool
If TRUE, the minimal and maximal values during performance estimation are part of the plot. If FALSE, only the mean values are shown. This parameter is ignored if final_training=TRUE.
text_size
Size of the text.
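What add_min_max adds to the plot can be illustrated with a small base R sketch that summarizes a measure across folds for every epoch. The loss values and the object names are hypothetical; the sketch only assumes that performance estimation yields one value per fold and epoch, and it does not use the package's internal history format.

```r
# Hypothetical loss values from performance estimation:
# rows = folds, columns = epochs.
loss <- matrix(
  c(
    0.9, 0.6, 0.4,
    1.1, 0.7, 0.5,
    1.0, 0.8, 0.3
  ),
  nrow = 3, byrow = TRUE,
  dimnames = list(paste0("fold_", 1:3), paste0("epoch_", 1:3))
)

# Per-epoch summary across folds: the mean line plus the min/max band.
history <- data.frame(
  epoch = seq_len(ncol(loss)),
  min = apply(loss, 2, min),
  mean = colMeans(loss),
  max = apply(loss, 2, max)
)
print(history)
```

With add_min_max = FALSE only the mean column of such a summary would be relevant; for the final training there is a single run, so no minimum or maximum across folds exists.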
Method plot_coding_stream()
Method for requesting a plot of the coding stream. The plot shows how the cases of the different categories/classes are assigned to the available classes/categories. The visualization is helpful for analyzing the consequences of coding errors.
Usage
ClassifiersBasedOnTextEmbeddings$plot_coding_stream(
label_categories_size = 3,
key_size = 0.5,
text_size = 10
)
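The information shown in a coding stream can be approximated with a cross-tabulation of true and assigned categories. The following base R sketch uses hypothetical toy data and does not reproduce how the package derives the plot from its iota objects; it only illustrates the kind of flows the visualization displays.

```r
# Hypothetical toy data: true categories and the categories assigned
# by a classifier.
lv <- c("a", "b", "c")
true_cat <- factor(c("a", "a", "a", "b", "b", "c"), levels = lv)
assigned <- factor(c("a", "a", "b", "b", "c", "c"), levels = lv)

# Number of cases flowing from every true category to every assigned category.
flows <- table(true = true_cat, assigned = assigned)

# Share of every true category that ends up in each assigned category.
flow_shares <- prop.table(flows, margin = 1)
print(flow_shares)
```

Off-diagonal cells correspond to coding errors: in this toy example one third of the cases of category "a" is assigned to category "b".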