Text embedding classifier with a ProtoNet
Source: R/obj_TEClassifierProtoNet.R
TEClassifierProtoNet.Rd

Abstract class for neural nets with 'pytorch'.
This class is deprecated. Please use an object of class TEClassifierSequentialPrototype instead.
This object represents an implementation of a prototypical network for few-shot learning as described by Snell, Swersky, and Zemel (2017). The network uses a multi-way contrastive loss as described by Zhang et al. (2019) and learns to scale the metric as described by Oreshkin, Rodriguez, and Lacoste (2018).
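The core idea behind these references can be sketched in a few lines of base R (a conceptual illustration only, not the package's internal implementation): each class prototype is the mean embedding of that class's support examples, and a query receives class probabilities via a softmax over its negative scaled squared Euclidean distances to the prototypes, where the scale factor corresponds to the learned metric scaling of Oreshkin et al. (2018).

```r
# Conceptual sketch of prototypical classification (not the package API).
# support: numeric matrix (cases x dims); support_classes: factor, one level per row;
# query: numeric vector; scale: metric scale factor (learned in the actual network).
proto_classify <- function(query, support, support_classes, scale = 1) {
  # One prototype per class: the mean of that class's support embeddings.
  prototypes <- t(sapply(levels(support_classes), function(cl) {
    colMeans(support[support_classes == cl, , drop = FALSE])
  }))
  # Squared euclidean distance from the query to each prototype.
  d2 <- apply(prototypes, 1, function(p) sum((query - p)^2))
  # Softmax over negative scaled distances yields class probabilities.
  exp(-scale * d2) / sum(exp(-scale * d2))
}
```

The returned vector is named by class level and sums to 1; larger `scale` values sharpen the distribution.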
Value
Objects of this class are used for assigning texts to classes/categories. For the creation and training of a
classifier an object of class EmbeddedText or LargeDataSetForTextEmbeddings and a factor are necessary. The
object of class EmbeddedText or LargeDataSetForTextEmbeddings contains the numerical text representations (text
embeddings) of the raw texts generated by an object of class TextEmbeddingModel. The factor contains the
classes/categories for every text. Missing values (unlabeled cases) are supported. For predictions an object of
class EmbeddedText or LargeDataSetForTextEmbeddings has to be used which was created with the same
TextEmbeddingModel as for training.
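The workflow described above can be sketched as follows. This is a hedged illustration: `embeddings`, `labels`, and `new_embeddings` are placeholder objects, and the argument names passed to the inherited `train()` and `predict()` methods are assumptions — consult the package reference for the exact signatures.

```r
library(aifeducation)

# 'embeddings' is an EmbeddedText or LargeDataSetForTextEmbeddings produced by
# a TextEmbeddingModel; 'labels' is a factor (NA marks unlabeled cases).
classifier <- TEClassifierProtoNet$new()
classifier$configure(
  name = NULL,                     # NULL: a unique ID is generated automatically
  label = "Example classifier",
  text_embeddings = embeddings,
  target_levels = levels(labels)
)

# Argument names below are illustrative assumptions:
classifier$train(data_embeddings = embeddings, data_targets = labels)

# Prediction requires embeddings created with the same TextEmbeddingModel:
preds <- classifier$predict(new_embeddings)
```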
Note
This model requires pad_value = 0. If this condition is not met, the padding value is switched automatically.
References
Oreshkin, B. N., Rodriguez, P. & Lacoste, A. (2018). TADAM: Task dependent adaptive metric for improved few-shot learning. https://doi.org/10.48550/arXiv.1805.10123
Snell, J., Swersky, K. & Zemel, R. S. (2017). Prototypical Networks for Few-shot Learning. https://doi.org/10.48550/arXiv.1703.05175
Zhang, X., Nie, J., Zong, L., Yu, H. & Liang, W. (2019). One Shot Learning with Margin. In Q. Yang, Z.-H. Zhou, Z. Gong, M.-L. Zhang & S.-J. Huang (Eds.), Lecture Notes in Computer Science. Advances in Knowledge Discovery and Data Mining (Vol. 11440, pp. 305–317). Springer International Publishing. https://doi.org/10.1007/978-3-030-16145-3_24
Super classes
aifeducation::AIFEMaster -> aifeducation::AIFEBaseModel -> aifeducation::ModelsBasedOnTextEmbeddings -> aifeducation::ClassifiersBasedOnTextEmbeddings -> aifeducation::TEClassifiersBasedOnProtoNet -> TEClassifierProtoNet
Methods
Inherited methods
aifeducation::AIFEMaster$get_all_fields()
aifeducation::AIFEMaster$get_documentation_license()
aifeducation::AIFEMaster$get_ml_framework()
aifeducation::AIFEMaster$get_model_config()
aifeducation::AIFEMaster$get_model_description()
aifeducation::AIFEMaster$get_model_info()
aifeducation::AIFEMaster$get_model_license()
aifeducation::AIFEMaster$get_package_versions()
aifeducation::AIFEMaster$get_private()
aifeducation::AIFEMaster$get_publication_info()
aifeducation::AIFEMaster$get_sustainability_data()
aifeducation::AIFEMaster$is_configured()
aifeducation::AIFEMaster$is_trained()
aifeducation::AIFEMaster$set_documentation_license()
aifeducation::AIFEMaster$set_model_description()
aifeducation::AIFEMaster$set_model_license()
aifeducation::AIFEMaster$set_publication_info()
aifeducation::AIFEBaseModel$count_parameter()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model_name()
aifeducation::ClassifiersBasedOnTextEmbeddings$adjust_target_levels()
aifeducation::ClassifiersBasedOnTextEmbeddings$check_embedding_model()
aifeducation::ClassifiersBasedOnTextEmbeddings$check_feature_extractor_object_type()
aifeducation::ClassifiersBasedOnTextEmbeddings$load_from_disk()
aifeducation::ClassifiersBasedOnTextEmbeddings$plot_coding_stream()
aifeducation::ClassifiersBasedOnTextEmbeddings$plot_training_history()
aifeducation::ClassifiersBasedOnTextEmbeddings$predict()
aifeducation::ClassifiersBasedOnTextEmbeddings$requires_compression()
aifeducation::ClassifiersBasedOnTextEmbeddings$save()
aifeducation::TEClassifiersBasedOnProtoNet$get_metric_scale_factor()
aifeducation::TEClassifiersBasedOnProtoNet$predict_with_samples()
aifeducation::TEClassifiersBasedOnProtoNet$train()
Method configure()
Creating a new instance of this class.
Usage
TEClassifierProtoNet$configure(
name = NULL,
label = NULL,
text_embeddings = NULL,
feature_extractor = NULL,
target_levels = NULL,
dense_size = 4L,
dense_layers = 0L,
rec_size = 4L,
rec_layers = 2L,
rec_type = "GRU",
rec_bidirectional = FALSE,
embedding_dim = 2L,
self_attention_heads = 0L,
intermediate_size = NULL,
attention_type = "Fourier",
add_pos_embedding = TRUE,
act_fct = "ELU",
parametrizations = "None",
rec_dropout = 0.1,
repeat_encoder = 1L,
dense_dropout = 0.4,
encoder_dropout = 0.1
)

Arguments
name
string Name of the new model. Please refer to common name conventions. Free text can be used with parameter label. If set to NULL, a unique ID is generated automatically. Allowed values: any

label
string Label for the new model. Here you can use free text. Allowed values: any

text_embeddings
Object of class EmbeddedText or LargeDataSetForTextEmbeddings.

feature_extractor
Object of class TEFeatureExtractor which should be used in order to reduce the number of dimensions of the text embeddings. If no feature extractor should be applied, set NULL.

target_levels
vector containing the levels (categories or classes) within the target data. Please note that order matters. For ordinal data, please ensure that the levels are sorted correctly, with later levels indicating a higher category/class. For nominal data the order does not matter.

dense_size
int Number of neurons for each dense layer. Allowed values: 1 <= x

dense_layers
int Number of dense layers. Allowed values: 0 <= x

rec_size
int Number of neurons for each recurrent layer. Allowed values: 1 <= x

rec_layers
int Number of recurrent layers. Allowed values: 0 <= x

rec_type
string Type of the recurrent layers. rec_type = 'GRU' for Gated Recurrent Unit and rec_type = 'LSTM' for Long Short-Term Memory. Allowed values: 'GRU', 'LSTM'

rec_bidirectional
bool If TRUE, a bidirectional version of the recurrent layers is used.

embedding_dim
int determining the number of dimensions for the embedding. Allowed values: 2 <= x

self_attention_heads
int determining the number of attention heads for a self-attention layer. Only relevant if attention_type = 'MultiHead'. Allowed values: 0 <= x

intermediate_size
int determining the size of the projection layer within each transformer encoder. Allowed values: 1 <= x

attention_type
string Choose the attention type. Allowed values: 'Fourier', 'MultiHead'

add_pos_embedding
bool TRUE if positional embedding should be used.

act_fct
string Activation function for all layers. Allowed values: 'ELU', 'LeakyReLU', 'ReLU', 'GELU', 'Sigmoid', 'Tanh', 'PReLU'

parametrizations
string Re-parametrization of the weights of layers. Allowed values: 'None', 'OrthogonalWeights', 'WeightNorm', 'SpectralNorm'

rec_dropout
double determining the dropout between recurrent layers. Allowed values: 0 <= x <= 0.6

repeat_encoder
int determining how many times the encoder should be added to the network. Allowed values: 0 <= x

dense_dropout
double determining the dropout between dense layers. Allowed values: 0 <= x <= 0.6

encoder_dropout
double determining the dropout for the dense projection within the transformer encoder layers. Allowed values: 0 <= x <= 0.6

bias
bool If TRUE, a bias term is added to all layers. If FALSE, no bias term is added to the layers.
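A hedged configure() call matching the usage above. All values are illustrative; `embeddings` stands in for a previously created object of class EmbeddedText or LargeDataSetForTextEmbeddings.

```r
classifier <- TEClassifierProtoNet$new()
classifier$configure(
  name = "sentiment_protonet",   # illustrative name
  label = "Sentiment ProtoNet",  # free-text label
  text_embeddings = embeddings,
  feature_extractor = NULL,      # NULL: no dimensionality reduction
  target_levels = c("negative", "neutral", "positive"),
  embedding_dim = 2L,            # 2D classification space, easy to plot
  rec_type = "GRU",
  rec_layers = 2L,
  attention_type = "Fourier"     # parameter-free alternative to 'MultiHead'
)
```

Parameters left out fall back to the defaults shown in the Usage block.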
Method embed()
Method for embedding documents. Please do not confuse this type of embeddings with the embeddings of texts created by an object of class TextEmbeddingModel. These embeddings embed documents according to their similarity to specific classes.
Arguments
embeddings_q
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all cases which should be embedded into the classification space.

batch_size
int Batch size.
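For a trained classifier, a hedged call might look as follows; `new_embeddings` is an illustrative object of class EmbeddedText or LargeDataSetForTextEmbeddings created with the same TextEmbeddingModel as used for training.

```r
# Embed documents into the classification space of a trained classifier,
# i.e. represent each document by its similarity to the class prototypes.
doc_embeddings <- classifier$embed(
  embeddings_q = new_embeddings,
  batch_size = 32L
)
```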
Method plot_embeddings()
Method for creating a plot to visualize embeddings and their corresponding centers (prototypes).
Usage
TEClassifierProtoNet$plot_embeddings(
embeddings_q,
classes_q = NULL,
batch_size = 12L,
alpha = 0.5,
size_points = 3L,
size_points_prototypes = 8L,
inc_unlabeled = TRUE
)

Arguments
embeddings_q
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all cases which should be embedded into the classification space.

classes_q
Named factor containing the true classes for every case. Please note that the names must match the names/ids in embeddings_q.

batch_size
int Batch size.

alpha
float Value indicating how transparent the points should be (important if many points overlap). Does not apply to points representing prototypes.

size_points
int Size of the points excluding the points for prototypes.

size_points_prototypes
int Size of points representing prototypes.

inc_unlabeled
bool If TRUE, the plot includes unlabeled cases as data points.
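A hedged call using the defaults above; `new_embeddings` and `true_classes` are illustrative objects, and the factor must be named with the case ids from `new_embeddings`.

```r
classifier$plot_embeddings(
  embeddings_q = new_embeddings,
  classes_q = true_classes,   # named factor; names match ids in new_embeddings
  batch_size = 12L,
  alpha = 0.5,                # lower alpha helps when many points overlap
  inc_unlabeled = TRUE        # also show cases without a known class
)
```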