Abstract class for auto encoders with 'pytorch'.

Value

Objects of this class are used for reducing the number of dimensions of text embeddings created by an object of class TextEmbeddingModel.

Training requires an object of class EmbeddedText or LargeDataSetForTextEmbeddings generated by an object of class TextEmbeddingModel. Passing raw texts is not supported.

Prediction requires an object of class EmbeddedText or LargeDataSetForTextEmbeddings that was generated with the same TextEmbeddingModel as the one used during training. Prediction outputs a new object of class EmbeddedText or LargeDataSetForTextEmbeddings that contains text embeddings with a lower number of dimensions.

All models use tied weights for the encoder and decoder layers (except for method="lstm") and estimate orthogonal weights. In addition, training encourages the model to produce uncorrelated features.
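The ideas of tied and orthogonal weights can be illustrated with a small linear autoencoder in base R. This is a minimal sketch under simplifying assumptions, not the package's 'pytorch' implementation: the weight matrix w, the helpers encode() and decode(), and the simulated data are all hypothetical, and the correlation penalty shown is only one possible way to express the goal of uncorrelated features.

```r
# Illustrative only: a linear autoencoder with tied, orthogonal weights.
set.seed(1)
n_cases <- 200
n_in <- 8        # dimensions of the (simulated) input embedding
n_features <- 3  # target number of dimensions

# Simulated input data standing in for text embeddings.
x <- matrix(rnorm(n_cases * n_in), nrow = n_cases)

# Weight matrix with orthonormal columns, obtained from a QR decomposition.
w <- qr.Q(qr(matrix(rnorm(n_in * n_features), nrow = n_in)))

# Tied weights: the decoder reuses the transposed encoder weights.
encode <- function(x, w) x %*% w
decode <- function(z, w) z %*% t(w)

z <- encode(x, w)      # compressed representation (n_cases x n_features)
x_hat <- decode(z, w)  # reconstruction (n_cases x n_in)

# Deviation of t(w) %*% w from the identity matrix (close to 0 if orthogonal).
orthogonality_gap <- max(abs(crossprod(w) - diag(n_features)))

# One possible penalty for correlated features: the mean squared
# off-diagonal correlation between the compressed dimensions.
feature_cor <- cor(z)
cor_penalty <- mean(feature_cor[upper.tri(feature_cor)]^2)
```

With orthonormal columns in w, encoding a reconstruction yields the same compressed features again, which is one reason orthogonality pairs naturally with tied weights.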

Objects of class TEFeatureExtractor are designed to be used with classifiers such as TEClassifierRegular and TEClassifierProtoNet.

See also

Other Text Embedding: TextEmbeddingModel

Super class

aifeducation::AIFEBaseModel -> TEFeatureExtractor

Methods


Method configure()

Creating a new instance of this class.

Usage

TEFeatureExtractor$configure(
  ml_framework = "pytorch",
  name = NULL,
  label = NULL,
  text_embeddings = NULL,
  features = 128,
  method = "lstm",
  noise_factor = 0.2,
  optimizer = "adam"
)

Arguments

ml_framework

string Framework to use for training and inference. Currently only ml_framework="pytorch" is supported.

name

string Name of the new feature extractor. Please refer to common name conventions. Free text can be used with the parameter label.

label

string Label for the new feature extractor. Here you can use free text.

text_embeddings

An object of class EmbeddedText or LargeDataSetForTextEmbeddings.

features

int determining the number of dimensions to which the dimension of the text embedding should be reduced.

method

string Method to use for the feature extraction. "lstm" for an extractor based on LSTM-layers or "dense" for dense layers.

noise_factor

double Value of at least 0 and lower than 1, indicating how much noise should be added during the training of the feature extractor.

optimizer

string Optimizer to use for training. Either "adam" or "rmsprop".

Returns

Returns an object of class TEFeatureExtractor which is ready for training.
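A configuration call might look like the following. This is a sketch that assumes the usual R6 pattern of creating an object with $new() and then calling $configure() on it. The helper configure_demo_extractor(), its generator argument, and the values for name and label are hypothetical; the remaining arguments repeat the defaults shown under Usage.

```r
# Hypothetical helper: creates and configures a feature extractor.
# `generator` defaults to the class generator from 'aifeducation' and is only
# evaluated when the function is called without an explicit generator.
configure_demo_extractor <- function(text_embeddings,
                                     generator = aifeducation::TEFeatureExtractor) {
  extractor <- generator$new()
  extractor$configure(
    ml_framework = "pytorch",
    name = "feature_extractor_demo",    # hypothetical name
    label = "Demo feature extractor",   # hypothetical free-text label
    text_embeddings = text_embeddings,  # EmbeddedText or LargeDataSetForTextEmbeddings
    features = 128,
    method = "lstm",
    noise_factor = 0.2,
    optimizer = "adam"
  )
  # The configured object is returned explicitly, independent of what
  # configure() itself returns.
  extractor
}
```

Given an existing object of class EmbeddedText, here called embeddings for illustration, the helper would be called as configure_demo_extractor(embeddings).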


Method train()

Method for training a neural net.

Usage

TEFeatureExtractor$train(
  data_embeddings,
  data_val_size = 0.25,
  sustain_track = TRUE,
  sustain_iso_code = NULL,
  sustain_region = NULL,
  sustain_interval = 15,
  epochs = 40,
  batch_size = 32,
  dir_checkpoint,
  trace = TRUE,
  ml_trace = 1,
  log_dir = NULL,
  log_write_interval = 10
)

Arguments

data_embeddings

Object of class EmbeddedText or LargeDataSetForTextEmbeddings.

data_val_size

double between 0 and 1, indicating the proportion of cases which should be used for the validation sample.

sustain_track

bool If TRUE energy consumption is tracked during training via the python library 'codecarbon'.

sustain_iso_code

string ISO code (Alpha-3-Code) for the country. This variable must be set if sustainability should be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes.

sustain_region

Region within a country. Only available for USA and Canada. See the documentation of 'codecarbon' for more information: https://mlco2.github.io/codecarbon/parameters.html

sustain_interval

int Interval in seconds for measuring power usage.

epochs

int Number of training epochs.

batch_size

int Size of batches.

dir_checkpoint

string Path to the directory where checkpoints should be saved during training. If the directory does not exist, it is created.

trace

bool TRUE, if information about the estimation phase should be printed to the console.

ml_trace

int ml_trace=0 does not print any information about the training process from pytorch on the console. ml_trace=1 prints a progress bar.

log_dir

string Path to the directory where the log files should be saved. If no logging is desired set this argument to NULL.

log_write_interval

int Time in seconds determining the interval in which the logger should try to update the log files. Only relevant if log_dir is not NULL.

Returns

The method does not return a value. It changes the object into a trained feature extractor.
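A training call could be wrapped as follows. This is a sketch, not a recommended configuration: the helper train_demo_extractor() and the checkpoint directory name are hypothetical, and sustainability tracking is switched off on the assumption that no ISO code is required in that case. All other values repeat the defaults shown under Usage.

```r
# Hypothetical helper: trains an already configured feature extractor.
train_demo_extractor <- function(extractor,
                                 data_embeddings,
                                 dir_checkpoint = file.path(tempdir(), "fe_checkpoints")) {
  extractor$train(
    data_embeddings = data_embeddings,  # EmbeddedText or LargeDataSetForTextEmbeddings
    data_val_size = 0.25,               # 25 % of the cases form the validation sample
    sustain_track = FALSE,              # assumption: no ISO code needed when tracking is off
    epochs = 40,
    batch_size = 32,
    dir_checkpoint = dir_checkpoint,    # created if it does not exist
    trace = TRUE,
    ml_trace = 1,                       # progress bar
    log_dir = NULL                      # no log files
  )
  # train() does not return a value, so the trained object is returned invisibly.
  invisible(extractor)
}
```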


Method load_from_disk()

Loads an object from disk and updates the object to the current version of the package.

Usage

TEFeatureExtractor$load_from_disk(dir_path)

Arguments

dir_path

Path to the directory where the object is stored.

Returns

Method does not return anything. It loads an object from disk.
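Because the method does not return anything, a small wrapper can make loading more convenient. The following is a sketch that assumes a fresh object created with $new() can be filled via load_from_disk(). The helper load_demo_extractor() and its generator argument are hypothetical.

```r
# Hypothetical helper: loads a stored feature extractor from a directory.
load_demo_extractor <- function(dir_path,
                                generator = aifeducation::TEFeatureExtractor) {
  # Fail early if the directory does not exist.
  stopifnot(dir.exists(dir_path))
  extractor <- generator$new()
  extractor$load_from_disk(dir_path = dir_path)
  # load_from_disk() returns nothing, so the object is returned explicitly.
  extractor
}
```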


Method extract_features()

Method for extracting features. Applying this method reduces the number of dimensions of the text embeddings. Please note that this method should only be used if a small number of cases should be compressed since the data is loaded completely into memory. For a high number of cases please use the method extract_features_large.

Usage

TEFeatureExtractor$extract_features(data_embeddings, batch_size)

Arguments

data_embeddings

Object of class EmbeddedText, LargeDataSetForTextEmbeddings, datasets.arrow_dataset.Dataset or array containing the text embeddings which should be reduced in their dimensions.

batch_size

int Batch size.

Returns

Returns an object of class EmbeddedText containing the compressed embeddings.
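Since feature extraction presupposes a trained model, a call can be guarded with is_trained(). The following is a minimal sketch; the helper extract_if_trained() and its default batch size of 32 are hypothetical choices.

```r
# Hypothetical helper: compresses embeddings only if the extractor is trained.
extract_if_trained <- function(extractor, data_embeddings, batch_size = 32) {
  if (!isTRUE(extractor$is_trained())) {
    stop("The feature extractor must be trained before extracting features.")
  }
  # All data is loaded into memory, so this suits a small number of cases.
  extractor$extract_features(
    data_embeddings = data_embeddings,
    batch_size = batch_size
  )
}
```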


Method extract_features_large()

Method for extracting features from a large number of cases. Applying this method reduces the number of dimensions of the text embeddings.

Usage

TEFeatureExtractor$extract_features_large(
  data_embeddings,
  batch_size,
  trace = FALSE
)

Arguments

data_embeddings

Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings which should be reduced in their dimensions.

batch_size

int Batch size.

trace

bool If TRUE information about the progress is printed to the console.

Returns

Returns an object of class LargeDataSetForTextEmbeddings containing the compressed embeddings.
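One way to choose between the two extraction methods is to dispatch on the class of the input. This is a sketch of one possible design, assuming that objects of class LargeDataSetForTextEmbeddings typically hold many cases; the helper extract_by_class() is hypothetical.

```r
# Hypothetical helper: picks the extraction method based on the input class.
extract_by_class <- function(extractor, data_embeddings, batch_size = 32) {
  if (inherits(data_embeddings, "LargeDataSetForTextEmbeddings")) {
    # Large data sets: returns a LargeDataSetForTextEmbeddings.
    extractor$extract_features_large(
      data_embeddings = data_embeddings,
      batch_size = batch_size,
      trace = FALSE
    )
  } else {
    # Smaller inputs: returns an EmbeddedText held completely in memory.
    extractor$extract_features(
      data_embeddings = data_embeddings,
      batch_size = batch_size
    )
  }
}
```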


Method is_trained()

Check if the TEFeatureExtractor is trained.

Usage

TEFeatureExtractor$is_trained()

Returns

Returns TRUE if the object is trained and FALSE if not.


Method clone()

The objects of this class are cloneable with this method.

Usage

TEFeatureExtractor$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.