Abstract class for neural nets with 'keras'/'tensorflow' and 'pytorch'.
This object represents an implementation of a prototypical network for few-shot learning as described by Snell, Swersky, and Zemel (2017). The network uses a multi-way contrastive loss as described by Zhang et al. (2019) and learns to scale the metric as described by Oreshkin, Rodriguez, and Lacoste (2018).
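The core classification rule of a prototypical network can be sketched in a few lines of plain R (a conceptual illustration only; the package implements the network in pytorch, and the scaling factor here merely stands in for the learned metric scaling):

# Conceptual sketch of a prototypical network's classification rule.
prototype <- function(support_embeddings) {
  # a class prototype is the mean of the embedded support cases
  colMeans(support_embeddings)
}

classify_query <- function(query, prototypes, scale = 1) {
  # scaled squared euclidean distance to every prototype
  # (metric scaling in the spirit of Oreshkin et al., 2018);
  # the class with the closest prototype wins
  distances <- sapply(prototypes, function(p) scale * sum((query - p)^2))
  names(which.min(distances))
}

# Usage: prototypes is a named list, e.g.
# protos <- list(classA = prototype(emb_A), classB = prototype(emb_B))
# classify_query(query_vec, protos)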
Value
Objects of this class are used for assigning texts to classes/categories. For the creation and training of a classifier, an object of class EmbeddedText or LargeDataSetForTextEmbeddings and a factor are necessary. The object of class EmbeddedText or LargeDataSetForTextEmbeddings contains the numerical text representations (text embeddings) of the raw texts, generated by an object of class TextEmbeddingModel. The factor contains the classes/categories for every text. Missing values (unlabeled cases) are supported. For predictions, an object of class EmbeddedText or LargeDataSetForTextEmbeddings has to be used that was created with the same TextEmbeddingModel as the one used for training.
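A typical workflow might look like the following sketch. The object names (my_embeddings, my_targets, new_embeddings) are placeholders, only a minimal subset of arguments is shown, and the exact signature of predict() is documented with TEClassifierRegular:

library(aifeducation)

# my_embeddings: EmbeddedText or LargeDataSetForTextEmbeddings created
# with a TextEmbeddingModel; my_targets: named factor whose names match
# the case ids in my_embeddings (NA = unlabeled case).
classifier <- TEClassifierProtoNet$new()
classifier$configure(
  name = "my_protonet",
  label = "Example ProtoNet classifier",
  text_embeddings = my_embeddings,
  target_levels = levels(my_targets)
)
classifier$train(
  data_embeddings = my_embeddings,
  data_targets = my_targets,
  dir_checkpoint = "checkpoints"
)

# Predictions require embeddings created with the same
# TextEmbeddingModel as used for training.
predictions <- classifier$predict(new_embeddings)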
References
Oreshkin, B. N., Rodriguez, P. & Lacoste, A. (2018). TADAM: Task dependent adaptive metric for improved few-shot learning. https://doi.org/10.48550/arXiv.1805.10123
Snell, J., Swersky, K. & Zemel, R. S. (2017). Prototypical Networks for Few-shot Learning. https://doi.org/10.48550/arXiv.1703.05175
Zhang, X., Nie, J., Zong, L., Yu, H. & Liang, W. (2019). One Shot Learning with Margin. In Q. Yang, Z.-H. Zhou, Z. Gong, M.-L. Zhang & S.-J. Huang (Eds.), Lecture Notes in Computer Science. Advances in Knowledge Discovery and Data Mining (Vol. 11440, pp. 305–317). Springer International Publishing. https://doi.org/10.1007/978-3-030-16145-3_24
See also
Other Classification:
TEClassifierRegular
Super classes
aifeducation::AIFEBaseModel
-> aifeducation::TEClassifierRegular
-> TEClassifierProtoNet
Methods
Inherited methods
aifeducation::AIFEBaseModel$count_parameter()
aifeducation::AIFEBaseModel$get_all_fields()
aifeducation::AIFEBaseModel$get_documentation_license()
aifeducation::AIFEBaseModel$get_ml_framework()
aifeducation::AIFEBaseModel$get_model_description()
aifeducation::AIFEBaseModel$get_model_info()
aifeducation::AIFEBaseModel$get_model_license()
aifeducation::AIFEBaseModel$get_package_versions()
aifeducation::AIFEBaseModel$get_private()
aifeducation::AIFEBaseModel$get_publication_info()
aifeducation::AIFEBaseModel$get_sustainability_data()
aifeducation::AIFEBaseModel$get_text_embedding_model()
aifeducation::AIFEBaseModel$get_text_embedding_model_name()
aifeducation::AIFEBaseModel$is_configured()
aifeducation::AIFEBaseModel$load()
aifeducation::AIFEBaseModel$set_documentation_license()
aifeducation::AIFEBaseModel$set_model_description()
aifeducation::AIFEBaseModel$set_model_license()
aifeducation::AIFEBaseModel$set_publication_info()
aifeducation::TEClassifierRegular$check_embedding_model()
aifeducation::TEClassifierRegular$check_feature_extractor_object_type()
aifeducation::TEClassifierRegular$load_from_disk()
aifeducation::TEClassifierRegular$predict()
aifeducation::TEClassifierRegular$requires_compression()
aifeducation::TEClassifierRegular$save()
Method configure()
Creating a new instance of this class.
Usage
TEClassifierProtoNet$configure(
ml_framework = "pytorch",
name = NULL,
label = NULL,
text_embeddings = NULL,
feature_extractor = NULL,
target_levels = NULL,
dense_size = 4,
dense_layers = 0,
rec_size = 4,
rec_layers = 2,
rec_type = "gru",
rec_bidirectional = FALSE,
embedding_dim = 2,
self_attention_heads = 0,
intermediate_size = NULL,
attention_type = "fourier",
add_pos_embedding = TRUE,
rec_dropout = 0.1,
repeat_encoder = 1,
dense_dropout = 0.4,
recurrent_dropout = 0.4,
encoder_dropout = 0.1,
optimizer = "adam"
)
Arguments
ml_framework
string
Currently only pytorch is supported (ml_framework = "pytorch").

name
string
Name of the new classifier. Please refer to common name conventions. Free text can be used with parameter label.

label
string
Label for the new classifier. Here you can use free text.

text_embeddings
An object of class TextEmbeddingModel or LargeDataSetForTextEmbeddings.

feature_extractor
Object of class TEFeatureExtractor which should be used in order to reduce the number of dimensions of the text embeddings. If no feature extractor should be applied, set this argument to NULL.

target_levels
vector
containing the levels (categories or classes) within the target data. Please note that order matters. For ordinal data, please ensure that the levels are sorted correctly, with later levels indicating a higher category/class. For nominal data the order does not matter.

dense_size
int
Number of neurons for each dense layer.

dense_layers
int
Number of dense layers.

rec_size
int
Number of neurons for each recurrent layer.

rec_layers
int
Number of recurrent layers.

rec_type
string
Type of the recurrent layers. rec_type = "gru" for Gated Recurrent Unit and rec_type = "lstm" for Long Short-Term Memory.

rec_bidirectional
bool
If TRUE, a bidirectional version of the recurrent layers is used.

embedding_dim
int
determining the number of dimensions for the text embedding.

self_attention_heads
int
determining the number of attention heads for a self-attention layer. Only relevant if attention_type = "multihead".

intermediate_size
int
determining the size of the projection layer within each transformer encoder.

attention_type
string
Choose the relevant attention type. Possible values are "fourier" and "multihead". Please note that you may see different values for a case for different input orders if you choose "fourier" on Linux.

add_pos_embedding
bool
TRUE if positional embedding should be used.

rec_dropout
double
ranging between 0 and less than 1, determining the dropout between bidirectional recurrent layers.

repeat_encoder
int
determining how many times the encoder should be added to the network.

dense_dropout
double
ranging between 0 and less than 1, determining the dropout between dense layers.

recurrent_dropout
double
ranging between 0 and less than 1, determining the recurrent dropout for each recurrent layer. Only relevant for keras models.

encoder_dropout
double
ranging between 0 and less than 1, determining the dropout for the dense projection within the encoder layers.

optimizer
string
"adam" or "rmsprop".
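For illustration, a configuration using multi-head attention might look as follows (a sketch; my_embeddings is a placeholder for an EmbeddedText or LargeDataSetForTextEmbeddings object, and the target levels are example values):

classifier$configure(
  ml_framework = "pytorch",
  name = "protonet_demo",
  label = "Demo ProtoNet classifier",
  text_embeddings = my_embeddings,
  feature_extractor = NULL,
  target_levels = c("low", "medium", "high"),  # order matters for ordinal data
  dense_layers = 1,
  dense_size = 8,
  rec_layers = 2,
  rec_size = 4,
  rec_type = "gru",
  rec_bidirectional = FALSE,
  embedding_dim = 2,            # two dimensions can be plotted directly
  attention_type = "multihead",
  self_attention_heads = 2,
  intermediate_size = 64,
  repeat_encoder = 1,
  optimizer = "adam"
)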
Method train()
Method for training a neural net.
Training includes a routine for early stopping: if loss < 0.0001, Accuracy = 1.00, and Average Iota = 1.00, training stops. The history uses the values of the last trained epoch for the remaining epochs.
After training, the model with the best values for Average Iota, Accuracy, and Loss on the validation data set is used as the final model.
Usage
TEClassifierProtoNet$train(
data_embeddings,
data_targets,
data_folds = 5,
data_val_size = 0.25,
use_sc = TRUE,
sc_method = "dbsmote",
sc_min_k = 1,
sc_max_k = 10,
use_pl = TRUE,
pl_max_steps = 3,
pl_max = 1,
pl_anchor = 1,
pl_min = 0,
sustain_track = TRUE,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15,
epochs = 40,
batch_size = 35,
Ns = 5,
Nq = 3,
loss_alpha = 0.5,
loss_margin = 0.5,
sampling_separate = FALSE,
sampling_shuffle = TRUE,
dir_checkpoint,
trace = TRUE,
ml_trace = 1,
log_dir = NULL,
log_write_interval = 10,
n_cores = auto_n_cores()
)
Arguments
data_embeddings
Object of class EmbeddedText or LargeDataSetForTextEmbeddings.

data_targets
factor
containing the labels for cases stored in data_embeddings. Factor must be named and has to use the same names used in data_embeddings.

data_folds
int
determining the number of cross-fold samples.

data_val_size
double
between 0 and 1, indicating the proportion of cases of each class which should be used for the validation sample during the estimation of the model. The remaining cases are part of the training data.

use_sc
bool
TRUE if the estimation should integrate synthetic cases. FALSE if not.

sc_method
vector
containing the method for generating synthetic cases. Possible are sc_method = "adas", sc_method = "smote", and sc_method = "dbsmote".

sc_min_k
int
determining the minimal number of k which is used for creating synthetic units.

sc_max_k
int
determining the maximal number of k which is used for creating synthetic units.

use_pl
bool
TRUE if the estimation should integrate pseudo-labeling. FALSE if not.

pl_max_steps
int
determining the maximum number of steps during pseudo-labeling.

pl_max
double
between 0 and 1, setting the maximal level of confidence for considering a case for pseudo-labeling.

pl_anchor
double
between 0 and 1, indicating the reference point for sorting the new cases of every label. See notes for more details.

pl_min
double
between 0 and 1, setting the minimal level of confidence for considering a case for pseudo-labeling.

sustain_track
bool
If TRUE, energy consumption is tracked during training via the python library 'codecarbon'.

sustain_iso_code
string
ISO code (Alpha-3-Code) for the country. This variable must be set if sustainability should be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes.

sustain_region
Region within a country. Only available for USA and Canada. See the documentation of codecarbon for more information: https://mlco2.github.io/codecarbon/parameters.html

sustain_interval
int
Interval in seconds for measuring power usage.

epochs
int
Number of training epochs.

batch_size
int
Size of the batches for training.

Ns
int
Number of cases for every class in the sample.

Nq
int
Number of cases for every class in the query.

loss_alpha
double
Value between 0 and 1 indicating how strongly the loss should focus on pulling cases to their corresponding prototypes or pushing cases away from other prototypes. The higher the value, the more the loss concentrates on pulling cases to their corresponding prototypes. A conceptual sketch follows the Details section below.

loss_margin
double
Value greater than 0 indicating the minimal distance of every case from prototypes of other classes.

sampling_separate
bool
If TRUE, the cases for every class are divided into a data set for sample and a data set for query. These are never mixed. If FALSE, sample and query cases are drawn from the same data pool. That is, a case can be part of the sample in one epoch and part of the query in another epoch. It is ensured that a case is never part of sample and query at the same time. In addition, it is ensured that every case exists only once during a training step.

sampling_shuffle
bool
If TRUE, cases are randomly drawn from the data during every step. If FALSE, the cases are not shuffled.

dir_checkpoint
string
Path to the directory where the checkpoint during training should be saved. If the directory does not exist, it is created.

trace
bool
TRUE if information about the estimation phase should be printed to the console.

ml_trace
int
ml_trace = 0 does not print any information about the training process from pytorch on the console.

log_dir
string
Path to the directory where the log files should be saved. If no logging is desired, set this argument to NULL.

log_write_interval
int
Time in seconds determining the interval in which the logger should try to update the log files. Only relevant if log_dir is not NULL.

n_cores
int
Number of cores which should be used during the calculation of synthetic cases. Only relevant if use_sc = TRUE.

balance_class_weights
bool
If TRUE, class weights are generated based on the frequencies of the training data with the method 'Inverse Class Frequency'. If FALSE, each class has the weight 1.

balance_sequence_length
bool
If TRUE, sample weights are generated for the length of sequences based on the frequencies of the training data with the method 'Inverse Class Frequency'. If FALSE, each sequence length has the weight 1.
Details
sc_max_k: All values from sc_min_k up to sc_max_k are successively used. If the value of sc_max_k is too high, it is reduced to a number that allows the calculation of synthetic units.

pl_anchor: With the help of this value, the new cases are sorted. For this aim, the distance from the anchor is calculated and all cases are arranged in ascending order.
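One plausible reading of how loss_alpha and loss_margin interact, following the margin-based one-shot loss of Zhang et al. (2019), is the following conceptual sketch in plain R (the actual pytorch loss may differ in detail):

# Conceptual per-case loss: d_own is the distance to the prototype of
# the case's own class, d_other a vector of distances to all other
# class prototypes.
proto_loss <- function(d_own, d_other, alpha = 0.5, margin = 0.5) {
  pull <- d_own                           # pull case towards its prototype
  push <- sum(pmax(0, margin - d_other))  # penalize cases closer than the
                                          # margin to foreign prototypes
  alpha * pull + (1 - alpha) * push
}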
Method embed()
Method for embedding documents. Please do not confuse this type of embeddings with the embeddings of texts created by an object of class TextEmbeddingModel. These embeddings embed documents according to their similarity to specific classes.
Arguments
embeddings_q
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all cases which should be embedded into the classification space.
batch_size
int
Batch size.
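A call might look like this (sketch; classifier is a trained TEClassifierProtoNet and test_embeddings is a placeholder for an EmbeddedText or LargeDataSetForTextEmbeddings object):

class_space <- classifier$embed(
  embeddings_q = test_embeddings,
  batch_size = 32
)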
Method plot_embeddings()
Method for creating a plot to visualize embeddings and their corresponding centers (prototypes).
Usage
TEClassifierProtoNet$plot_embeddings(
embeddings_q,
classes_q = NULL,
batch_size = 12,
alpha = 0.5,
size_points = 3,
size_points_prototypes = 8,
inc_unlabeled = TRUE
)
Arguments
embeddings_q
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all cases which should be embedded into the classification space.
classes_q
factor
Named factor containing the true classes for every case. Please note that the names must match the names/ids in embeddings_q.

batch_size
int
Batch size.

alpha
float
Value indicating how transparent the points should be (important if many points overlap). Does not apply to points representing prototypes.

size_points
int
Size of the points excluding the points for prototypes.

size_points_prototypes
int
Size of points representing prototypes.

inc_unlabeled
bool
If TRUE, the plot includes unlabeled cases as data points.
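A call using the documented defaults might look like this (sketch; test_embeddings and true_classes are placeholders):

classifier$plot_embeddings(
  embeddings_q = test_embeddings,
  classes_q = true_classes,   # named factor; names match the case ids
  batch_size = 12,
  alpha = 0.5,
  size_points = 3,
  size_points_prototypes = 8,
  inc_unlabeled = TRUE
)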