
Base class for classifiers relying on EmbeddedText or LargeDataSetForTextEmbeddings generated with a TextEmbeddingModel.

Objects of this class contain fields and methods used by several other classes in 'AI for Education'.

This class is not designed for a direct application and should only be used by developers.

Value

A new object of this class.

Super classes

aifeducation::AIFEBaseModel -> aifeducation::ModelsBasedOnTextEmbeddings -> ClassifiersBasedOnTextEmbeddings

Public fields

feature_extractor

('list()')
List for storing information and objects related to the feature_extractor.

reliability

('list()')

List for storing central reliability measures of the last training.

  • reliability$test_metric: Array containing the reliability measures for the test data for every fold and step (in case of pseudo-labeling).

  • reliability$test_metric_mean: Array containing the reliability measures for the test data. The values represent the mean values across all folds.

  • reliability$raw_iota_objects: List containing all iota objects generated with the package iotarelr for every fold at the end of the last training, computed on the test data.

  • reliability$raw_iota_objects$iota_objects_end: List of objects with class iotarelr_iota2 containing the estimated iota reliability of the second generation for the final model for every fold for the test data.

  • reliability$raw_iota_objects$iota_objects_end_free: List of objects with class iotarelr_iota2 containing the estimated iota reliability of the second generation for the final model for every fold for the test data. Please note that the model is estimated without forcing the Assignment Error Matrix to be in line with the assumption of weak superiority.

  • reliability$iota_object_end: Object of class iotarelr_iota2 representing the mean of the individual objects of all folds for the test data.

  • reliability$iota_object_end_free: Object of class iotarelr_iota2 representing the mean of the individual objects of all folds for the test data. Please note that the model is estimated without forcing the Assignment Error Matrix to be in line with the assumption of weak superiority.

  • reliability$standard_measures_end: Object of class list containing the final measures for precision, recall, and F1 for every fold.

  • reliability$standard_measures_mean: Matrix containing the mean values of precision, recall, and F1 across all folds.
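The relation between reliability$standard_measures_end (one set of measures per fold) and reliability$standard_measures_mean (their mean) can be illustrated with a minimal base R sketch. It assumes that the true and assigned categories of the test data are available as factors for every fold and computes per-class precision, recall, and F1. The helper standard_measures_sketch() and the toy data are hypothetical and not part of 'AI for Education'; the package's internal computation may differ, for example in how empty classes are handled.

```r
# Hypothetical helper (not part of aifeducation): per-class precision,
# recall, and F1 for a single fold.
standard_measures_sketch <- function(truth, pred) {
  pred <- factor(pred, levels = levels(truth))
  tab <- table(truth, pred)   # rows: true classes, columns: assigned classes
  tp <- diag(tab)             # correctly assigned cases per class
  precision <- tp / colSums(tab)
  recall <- tp / rowSums(tab)
  f1 <- 2 * precision * recall / (precision + recall)
  cbind(precision = precision, recall = recall, f1 = f1)
}

# Toy test data for two folds (hypothetical).
lv <- c("a", "b")
folds <- list(
  list(truth = factor(c("a", "a", "b", "b"), levels = lv),
       pred  = factor(c("a", "b", "b", "b"), levels = lv)),
  list(truth = factor(c("a", "b", "b", "a"), levels = lv),
       pred  = factor(c("a", "b", "a", "a"), levels = lv))
)

# One matrix per fold, analogous to standard_measures_end ...
measures_end <- lapply(folds, function(f) standard_measures_sketch(f$truth, f$pred))
# ... and their element-wise mean, analogous to standard_measures_mean.
measures_mean <- Reduce(`+`, measures_end) / length(measures_end)
print(measures_mean)
```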

Methods

Inherited methods


Method predict()

Method for predicting new data with a trained neural net.

Usage

ClassifiersBasedOnTextEmbeddings$predict(
  newdata,
  batch_size = 32,
  ml_trace = 1
)

Arguments

newdata

Object of class EmbeddedText or LargeDataSetForTextEmbeddings for which predictions should be made. In addition, this method accepts objects of class array and datasets.arrow_dataset.Dataset. However, these should only be used by developers.

batch_size

int Size of batches.

ml_trace

int If ml_trace=0, the machine learning framework does not print any information on the process.

Returns

Returns a data.frame containing the predictions and the probabilities of the different labels for each case.
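The role of batch_size can be sketched in base R: cases are scored in chunks of batch_size rows, and the class probabilities of all chunks are combined into one data.frame with a predicted category per case. Everything in this sketch is hypothetical: newdata is a plain cases x features matrix instead of an EmbeddedText or LargeDataSetForTextEmbeddings object, the "network" is a fixed softmax layer built by make_scorer(), and the column name predicted_category is illustrative. The real method delegates this work to the machine learning framework.

```r
# Hypothetical stand-in for a trained network: a fixed linear layer plus softmax.
make_scorer <- function(weights, classes) {
  function(x) {
    logits <- x %*% weights
    probs <- exp(logits - apply(logits, 1, max))  # subtract row maxima for stability
    probs <- probs / rowSums(probs)               # each row sums to 1
    colnames(probs) <- classes
    probs
  }
}

# Score newdata in batches and return probabilities plus the most likely class.
predict_in_batches_sketch <- function(newdata, score_fn, batch_size = 32) {
  n <- nrow(newdata)
  starts <- seq(1, n, by = batch_size)
  probs <- do.call(rbind, lapply(starts, function(s) {
    idx <- s:min(s + batch_size - 1, n)
    score_fn(newdata[idx, , drop = FALSE])
  }))
  out <- as.data.frame(probs)
  out$predicted_category <- factor(
    colnames(probs)[max.col(probs, ties.method = "first")],
    levels = colnames(probs)
  )
  rownames(out) <- rownames(newdata)
  out
}

set.seed(1)
emb <- matrix(rnorm(10 * 4), nrow = 10, dimnames = list(paste0("case_", 1:10), NULL))
scorer <- make_scorer(matrix(rnorm(4 * 3), nrow = 4), c("neg", "neu", "pos"))
res <- predict_in_batches_sketch(emb, scorer, batch_size = 4)
head(res)
```

In this sketch every case is scored independently, so the batch size only affects memory use and speed, not the resulting probabilities.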


Method check_embedding_model()

Method for checking if the provided text embeddings are created with the same TextEmbeddingModel as the classifier.

Usage

ClassifiersBasedOnTextEmbeddings$check_embedding_model(
  text_embeddings,
  require_compressed = FALSE
)

Arguments

text_embeddings

Object of class EmbeddedText or LargeDataSetForTextEmbeddings.

require_compressed

TRUE if a compressed version of the embeddings is necessary. Compressed embeddings are created by an object of class TEFeatureExtractor.

Returns

TRUE if the underlying TextEmbeddingModel is the same. FALSE if the models differ.
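Conceptually, this check compares the identifying information that the classifier stores about its TextEmbeddingModel with the information carried by the embeddings. The sketch below represents both as plain lists; the field names model_name and model_date and the function check_embedding_model_sketch() are hypothetical, and the require_compressed case is not modelled.

```r
# Hypothetical: compare only the fields that identify the TextEmbeddingModel.
check_embedding_model_sketch <- function(classifier_info, embedding_info,
                                         fields = c("model_name", "model_date")) {
  identical(classifier_info[fields], embedding_info[fields])
}

clf_info  <- list(model_name = "bert_example", model_date = "2024-01-15")
emb_same  <- list(model_name = "bert_example", model_date = "2024-01-15", n_cases = 250)
emb_other <- list(model_name = "roberta_example", model_date = "2024-01-15")

check_embedding_model_sketch(clf_info, emb_same)   # TRUE: extra fields are ignored
check_embedding_model_sketch(clf_info, emb_other)  # FALSE: different model
```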


Method check_feature_extractor_object_type()

Method for checking an object of class TEFeatureExtractor.

Usage

ClassifiersBasedOnTextEmbeddings$check_feature_extractor_object_type(
  feature_extractor
)

Arguments

feature_extractor

Object of class TEFeatureExtractor

Returns

This method does not return anything. It raises an error if

  • the object is NULL

  • the object does not rely on the same machine learning framework as the classifier

  • the object is not trained.
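The three conditions can be expressed as a chain of stop() calls. In this sketch the feature extractor is a hypothetical plain list with the fields ml_framework and trained; the real method inspects an object of class TEFeatureExtractor, and its error messages differ.

```r
# Hypothetical sketch: raise an error for each of the three documented conditions.
check_feature_extractor_sketch <- function(feature_extractor, classifier_framework) {
  if (is.null(feature_extractor)) {
    stop("The feature extractor is NULL.")
  }
  if (!identical(feature_extractor$ml_framework, classifier_framework)) {
    stop("The feature extractor and the classifier use different machine learning frameworks.")
  }
  if (!isTRUE(feature_extractor$trained)) {
    stop("The feature extractor is not trained.")
  }
  invisible(NULL)  # nothing is returned when all checks pass
}

fe <- list(ml_framework = "pytorch", trained = TRUE)
check_feature_extractor_sketch(fe, "pytorch")          # passes silently
try(check_feature_extractor_sketch(fe, "tensorflow"))  # error: framework mismatch
```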


Method requires_compression()

Method for checking if provided text embeddings must be compressed via a TEFeatureExtractor before processing.

Usage

ClassifiersBasedOnTextEmbeddings$requires_compression(text_embeddings)

Arguments

text_embeddings

Object of class EmbeddedText, LargeDataSetForTextEmbeddings, array or datasets.arrow_dataset.Dataset.

Returns

Returns TRUE if compression is necessary and FALSE if not.
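One plausible decision rule, stated here as an assumption rather than the package's actual logic: compression is needed when the classifier relies on a feature extractor and the embeddings do not yet have the extractor's output dimension. The sketch assumes embeddings stored as a three-dimensional array (cases x chunks x features) and a hypothetical feature extractor list with the field n_features_out.

```r
# Hypothetical rule for deciding whether embeddings still need compression.
requires_compression_sketch <- function(text_embeddings, feature_extractor = NULL) {
  # Without a feature extractor the embeddings are used as they are.
  if (is.null(feature_extractor)) {
    return(FALSE)
  }
  n_features <- dim(text_embeddings)[3]  # assumed layout: cases x chunks x features
  n_features != feature_extractor$n_features_out
}

fe <- list(n_features_out = 128)
raw_embeddings <- array(0, dim = c(5, 2, 768))
compressed_embeddings <- array(0, dim = c(5, 2, 128))

requires_compression_sketch(raw_embeddings, fe)         # TRUE
requires_compression_sketch(compressed_embeddings, fe)  # FALSE
requires_compression_sketch(raw_embeddings)             # FALSE
```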


Method save()

Method for saving a model.

Usage

ClassifiersBasedOnTextEmbeddings$save(dir_path, folder_name)

Arguments

dir_path

string Path of the directory where the model should be saved.

folder_name

string Name of the folder that should be created within the directory.

Returns

Function does not return a value. It saves the model to disk.


Method load_from_disk()

Loads an object from disk and updates it to the current version of the package.

Usage

ClassifiersBasedOnTextEmbeddings$load_from_disk(dir_path)

Arguments

dir_path

Path where the object is stored.

Returns

Method does not return anything. It loads an object from disk.
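The division of labour between save() (a directory plus a new folder inside it) and load_from_disk() (the path of that folder) can be mirrored with a base R round trip. The sketch uses saveRDS() and readRDS() with a hypothetical file name model.rds and a hypothetical version field package_version; the package's actual on-disk layout and update logic are not reproduced. Unlike the method, which updates the object itself, load_sketch() returns the loaded list.

```r
# Hypothetical save: create dir_path/folder_name and write the model into it.
save_sketch <- function(model, dir_path, folder_name) {
  target <- file.path(dir_path, folder_name)
  dir.create(target, recursive = TRUE, showWarnings = FALSE)
  saveRDS(model, file.path(target, "model.rds"))
  invisible(target)
}

# Hypothetical load: read the model and bring it to the current version.
load_sketch <- function(dir_path, current_version) {
  model <- readRDS(file.path(dir_path, "model.rds"))
  # Placeholder for the update step: only the version tag is changed here.
  if (!identical(model$package_version, current_version)) {
    model$package_version <- current_version
  }
  model
}

path <- save_sketch(list(weights = 1:3, package_version = "0.9.0"),
                    dir_path = tempdir(), folder_name = "my_classifier")
restored <- load_sketch(path, current_version = "1.0.0")
str(restored)
```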


Method adjust_target_levels()

Method transforms the levels of a factor into numbers corresponding to the model's definition.

Usage

ClassifiersBasedOnTextEmbeddings$adjust_target_levels(data_targets)

Arguments

data_targets

factor containing the labels for cases stored in embeddings. The factor must be named and has to use the same names as the cases in the embeddings.

Returns

Method returns a factor containing the numerical representation of categories/classes.
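A minimal sketch of such a mapping, assuming that the model stores its categories as a character vector in a fixed order and uses zero-based codes (as is common for Python-based frameworks). The function name adjust_target_levels_sketch() and the zero-based convention are assumptions for illustration; case names and missing labels are carried through unchanged.

```r
# Hypothetical mapping of category labels to the model's numeric codes.
adjust_target_levels_sketch <- function(data_targets, model_levels) {
  if (is.null(names(data_targets))) {
    stop("data_targets must be named with the case IDs used in the embeddings.")
  }
  unknown <- setdiff(levels(droplevels(data_targets)), model_levels)
  if (length(unknown) > 0) {
    stop("Unknown categories: ", paste(unknown, collapse = ", "))
  }
  # Position of each label in the model's level order, shifted to start at 0.
  codes <- match(as.character(data_targets), model_levels) - 1L
  out <- factor(codes, levels = seq_along(model_levels) - 1L)
  names(out) <- names(data_targets)
  out
}

targets <- factor(c(case_1 = "pos", case_2 = "neg", case_3 = NA, case_4 = "neu"))
res <- adjust_target_levels_sketch(targets, model_levels = c("neg", "neu", "pos"))
print(res)
```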


Method plot_training_history()

Method for requesting a plot of the training history. This method requires the R package 'ggplot2' to work.

Usage

ClassifiersBasedOnTextEmbeddings$plot_training_history(
  final_training = FALSE,
  pl_step = NULL,
  measure = "loss",
  y_min = NULL,
  y_max = NULL,
  add_min_max = TRUE,
  text_size = 10
)

Arguments

final_training

bool If FALSE, the values of the performance estimation are used. If TRUE, only the epochs of the final training are used.

pl_step

int Number of the step during pseudo labeling to plot. Only relevant if the model was trained with active pseudo labeling.

measure

string Measure to plot. Allowed values:

  • "avg_iota" = Average Iota

  • "loss" = Loss

  • "accuracy" = Accuracy

  • "balanced_accuracy" = Balanced Accuracy

y_min

Minimal value for the y-axis. Set to NULL for an automatic adjustment.

y_max

Maximal value for the y-axis. Set to NULL for an automatic adjustment.

add_min_max

bool If TRUE, the minimal and maximal values during performance estimation are part of the plot. If FALSE, only the mean values are shown. Parameter is ignored if final_training=TRUE.

text_size

Size of the text.

Returns

Returns a plot of class ggplot visualizing the training process.
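With final_training=FALSE and add_min_max=TRUE, the plot shows the mean of a measure over the folds of the performance estimation together with its minimum and maximum. A minimal base R sketch of that summary, assuming the history of one measure is available as a folds x epochs matrix (a hypothetical layout with toy values), is shown below. The resulting columns could then be drawn with ggplot2, for example as a line for the mean and a ribbon between the minimum and maximum.

```r
# Hypothetical summary of one measure (e.g. loss) across folds, per epoch.
summarise_history_sketch <- function(history) {
  data.frame(
    epoch = seq_len(ncol(history)),
    mean  = colMeans(history),         # shown in every case
    min   = apply(history, 2, min),    # shown only if add_min_max = TRUE
    max   = apply(history, 2, max)
  )
}

history <- rbind(
  fold_1 = c(0.9, 0.7, 0.6),
  fold_2 = c(1.1, 0.8, 0.5),
  fold_3 = c(1.0, 0.6, 0.7)
)
s <- summarise_history_sketch(history)
print(s)
```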


Method plot_coding_stream()

Method for requesting a plot of the coding stream. The plot shows how the cases of the different categories/classes are assigned to the available classes/categories. The visualization is helpful for analyzing the consequences of coding errors.

Usage

ClassifiersBasedOnTextEmbeddings$plot_coding_stream(
  label_categories_size = 3,
  key_size = 0.5,
  text_size = 10
)

Arguments

label_categories_size

double determining the size of the label for each true and assigned category within the plot.

key_size

double determining the size of the legend.

text_size

double determining the size of the text within the legend.

Returns

Returns a plot of class ggplot visualizing the coding stream.
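The information behind such a stream can be sketched as a cross-tabulation of true and assigned categories, where each non-zero cell corresponds to one stream whose width is the number of cases. This is an assumption about how the plot is built, not a description of the package's internals; the category labels and data are hypothetical.

```r
# Hypothetical true and assigned categories for eight cases.
true_cat <- factor(c("A", "A", "A", "B", "B", "C", "C", "C"), levels = c("A", "B", "C"))
assigned <- factor(c("A", "A", "B", "B", "C", "C", "C", "A"), levels = c("A", "B", "C"))

# One row per (true, assigned) pair; n is the number of cases in that stream.
flows <- as.data.frame(table(true = true_cat, assigned = assigned), responseName = "n")
flows <- flows[flows$n > 0, ]
print(flows)
```

Off-diagonal rows (for example true "C" assigned to "A") are the coding errors whose consequences the plot is meant to reveal.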


Method clone()

The objects of this class are cloneable with this method.

Usage

ClassifiersBasedOnTextEmbeddings$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.