
Text embedding classifier with a ProtoNet
Source: R/obj_TEClassifierProtoNet.R
TEClassifierProtoNet.Rd
Abstract class for neural nets with 'pytorch'.
This class is deprecated. Please use an object of class TEClassifierSequentialPrototype instead.
This object represents an implementation of a prototypical network for few-shot learning as described by Snell, Swersky, and Zemel (2017). The network uses a multi-way contrastive loss described by Zhang et al. (2019) and learns to scale the metric as described by Oreshkin, Rodriguez, and Lacoste (2018).
Value
Objects of this class are used for assigning texts to classes/categories. For the creation and training of a classifier, an object of class EmbeddedText or LargeDataSetForTextEmbeddings and a factor are necessary. The object of class EmbeddedText or LargeDataSetForTextEmbeddings contains the numerical text representations (text embeddings) of the raw texts generated by an object of class TextEmbeddingModel. The factor contains the classes/categories for every text. Missing values (unlabeled cases) are supported. For predictions, an object of class EmbeddedText or LargeDataSetForTextEmbeddings has to be used that was created with the same TextEmbeddingModel as for training.
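A minimal sketch of this workflow, assuming an existing embeddings object (my_embeddings) and a named factor (my_labels); the argument names passed to train() and predict() are assumptions and may differ from the actual signatures:

```r
library(aifeducation)

# 'my_embeddings': EmbeddedText or LargeDataSetForTextEmbeddings created
# with a TextEmbeddingModel; 'my_labels': named factor with one
# class/category per text (NA marks unlabeled cases).
classifier <- TEClassifierProtoNet$new()
classifier$configure(
  name = "demo_protonet",           # hypothetical name
  label = "Demo ProtoNet classifier",
  text_embeddings = my_embeddings,
  target_levels = levels(my_labels)
)

# Train with labeled (and optionally unlabeled) cases; argument names
# here are assumptions.
classifier$train(
  data_embeddings = my_embeddings,
  data_targets = my_labels
)

# Predictions require embeddings created with the same TextEmbeddingModel.
predictions <- classifier$predict(newdata = my_embeddings)
```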
Note
This model requires pad_value = 0. If this condition is not met, the padding value is switched automatically.
References
Oreshkin, B. N., Rodriguez, P. & Lacoste, A. (2018). TADAM: Task dependent adaptive metric for improved few-shot learning. https://doi.org/10.48550/arXiv.1805.10123
Snell, J., Swersky, K. & Zemel, R. S. (2017). Prototypical Networks for Few-shot Learning. https://doi.org/10.48550/arXiv.1703.05175
Zhang, X., Nie, J., Zong, L., Yu, H. & Liang, W. (2019). One Shot Learning with Margin. In Q. Yang, Z.-H. Zhou, Z. Gong, M.-L. Zhang & S.-J. Huang (Eds.), Lecture Notes in Computer Science. Advances in Knowledge Discovery and Data Mining (Vol. 11440, pp. 305–317). Springer International Publishing. https://doi.org/10.1007/978-3-030-16145-3_24
Super classes
aifeducation::AIFEBaseModel
-> aifeducation::ModelsBasedOnTextEmbeddings
-> aifeducation::ClassifiersBasedOnTextEmbeddings
-> aifeducation::TEClassifiersBasedOnProtoNet
-> TEClassifierProtoNet
Methods
Inherited methods
aifeducation::AIFEBaseModel$count_parameter()
aifeducation::AIFEBaseModel$get_all_fields()
aifeducation::AIFEBaseModel$get_documentation_license()
aifeducation::AIFEBaseModel$get_ml_framework()
aifeducation::AIFEBaseModel$get_model_description()
aifeducation::AIFEBaseModel$get_model_info()
aifeducation::AIFEBaseModel$get_model_license()
aifeducation::AIFEBaseModel$get_package_versions()
aifeducation::AIFEBaseModel$get_private()
aifeducation::AIFEBaseModel$get_publication_info()
aifeducation::AIFEBaseModel$get_sustainability_data()
aifeducation::AIFEBaseModel$is_configured()
aifeducation::AIFEBaseModel$is_trained()
aifeducation::AIFEBaseModel$load()
aifeducation::AIFEBaseModel$set_documentation_license()
aifeducation::AIFEBaseModel$set_model_description()
aifeducation::AIFEBaseModel$set_model_license()
aifeducation::AIFEBaseModel$set_publication_info()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model()
aifeducation::ModelsBasedOnTextEmbeddings$get_text_embedding_model_name()
aifeducation::ClassifiersBasedOnTextEmbeddings$adjust_target_levels()
aifeducation::ClassifiersBasedOnTextEmbeddings$check_embedding_model()
aifeducation::ClassifiersBasedOnTextEmbeddings$check_feature_extractor_object_type()
aifeducation::ClassifiersBasedOnTextEmbeddings$load_from_disk()
aifeducation::ClassifiersBasedOnTextEmbeddings$plot_coding_stream()
aifeducation::ClassifiersBasedOnTextEmbeddings$plot_training_history()
aifeducation::ClassifiersBasedOnTextEmbeddings$predict()
aifeducation::ClassifiersBasedOnTextEmbeddings$requires_compression()
aifeducation::ClassifiersBasedOnTextEmbeddings$save()
aifeducation::TEClassifiersBasedOnProtoNet$get_metric_scale_factor()
aifeducation::TEClassifiersBasedOnProtoNet$predict_with_samples()
aifeducation::TEClassifiersBasedOnProtoNet$train()
Method configure()
Creating a new instance of this class.
Usage
TEClassifierProtoNet$configure(
name = NULL,
label = NULL,
text_embeddings = NULL,
feature_extractor = NULL,
target_levels = NULL,
dense_size = 4,
dense_layers = 0,
rec_size = 4,
rec_layers = 2,
rec_type = "GRU",
rec_bidirectional = FALSE,
embedding_dim = 2,
self_attention_heads = 0,
intermediate_size = NULL,
attention_type = "Fourier",
add_pos_embedding = TRUE,
act_fct = "ELU",
parametrizations = "None",
rec_dropout = 0.1,
repeat_encoder = 1,
dense_dropout = 0.4,
encoder_dropout = 0.1
)
Arguments
name
string
Name of the new model. Please refer to common name conventions. Free text can be used with parameter label. If set to NULL, a unique ID is generated automatically. Allowed values: any
label
string
Label for the new model. Here you can use free text. Allowed values: any
text_embeddings
EmbeddedText, LargeDataSetForTextEmbeddings
Object of class EmbeddedText or LargeDataSetForTextEmbeddings.
feature_extractor
TEFeatureExtractor
Object of class TEFeatureExtractor which should be used to reduce the number of dimensions of the text embeddings. If no feature extractor should be applied, set NULL.
target_levels
vector
containing the levels (categories or classes) within the target data. Please note that order matters. For ordinal data, please ensure that the levels are sorted correctly, with later levels indicating a higher category/class. For nominal data the order does not matter.
dense_size
int
Number of neurons for each dense layer. Allowed values: 1 <= x
dense_layers
int
Number of dense layers. Allowed values: 0 <= x
rec_size
int
Number of neurons for each recurrent layer. Allowed values: 1 <= x
rec_layers
int
Number of recurrent layers. Allowed values: 0 <= x
rec_type
string
Type of the recurrent layers. rec_type='GRU' for Gated Recurrent Unit and rec_type='LSTM' for Long Short-Term Memory. Allowed values: 'GRU', 'LSTM'
rec_bidirectional
bool
If TRUE, a bidirectional version of the recurrent layers is used.
embedding_dim
int
determining the number of dimensions for the embedding. Allowed values: 2 <= x
self_attention_heads
int
determining the number of attention heads for a self-attention layer. Only relevant if attention_type='MultiHead'. Allowed values: 0 <= x
intermediate_size
int
determining the size of the projection layer within each transformer encoder. Allowed values: 1 <= x
attention_type
string
Choose the attention type. Allowed values: 'Fourier', 'MultiHead'
add_pos_embedding
bool
TRUE if positional embedding should be used.
act_fct
string
Activation function for all layers. Allowed values: 'ELU', 'LeakyReLU', 'ReLU', 'GELU', 'Sigmoid', 'Tanh', 'PReLU'
parametrizations
string
Re-parametrization of the weights of layers. Allowed values: 'None', 'OrthogonalWeights', 'WeightNorm', 'SpectralNorm'
rec_dropout
double
determining the dropout between recurrent layers. Allowed values: 0 <= x <= 0.6
repeat_encoder
int
determining how many times the encoder should be added to the network. Allowed values: 0 <= x
dense_dropout
double
determining the dropout between dense layers. Allowed values: 0 <= x <= 0.6
encoder_dropout
double
determining the dropout for the dense projection within the transformer encoder layers. Allowed values: 0 <= x <= 0.6
bias
bool
If TRUE, a bias term is added to all layers. If FALSE, no bias term is added to the layers.
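An illustrative configure() call using the parameters documented above; 'my_embeddings' is assumed to exist, and the target levels and chosen hyperparameter values are examples only:

```r
# Sketch of configuring a new classifier; values are illustrative.
classifier <- TEClassifierProtoNet$new()
classifier$configure(
  name = NULL,                      # NULL -> a unique ID is generated
  label = "Demo ProtoNet classifier",
  text_embeddings = my_embeddings,  # EmbeddedText or LargeDataSetForTextEmbeddings
  target_levels = c("low", "medium", "high"),  # order matters for ordinal data
  dense_layers = 1,
  dense_size = 8,
  rec_layers = 2,
  rec_size = 8,
  rec_type = "GRU",
  embedding_dim = 2,                # 2-dimensional classification space
  attention_type = "Fourier",
  repeat_encoder = 1,
  rec_dropout = 0.1,
  dense_dropout = 0.4,
  encoder_dropout = 0.1
)
```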
Method embed()
Method for embedding documents. Please do not confuse this type of embeddings with the embeddings of texts created by an object of class TextEmbeddingModel. These embeddings embed documents according to their similarity to specific classes.
Arguments
embeddings_q
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all cases which should be embedded into the classification space.
batch_size
int
Batch size.
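Since no Usage block is shown for embed(), the call shape below is an assumption based only on the arguments listed above:

```r
# Sketch: embed query cases into the classification space.
# 'my_embeddings' is assumed to be an EmbeddedText or
# LargeDataSetForTextEmbeddings object; the returned value represents
# the cases by their similarity to the class prototypes.
doc_embeddings <- classifier$embed(
  embeddings_q = my_embeddings,
  batch_size = 32
)
```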
Method plot_embeddings()
Method for creating a plot to visualize embeddings and their corresponding centers (prototypes).
Usage
TEClassifierProtoNet$plot_embeddings(
embeddings_q,
classes_q = NULL,
batch_size = 12,
alpha = 0.5,
size_points = 3,
size_points_prototypes = 8,
inc_unlabeled = TRUE
)
Arguments
embeddings_q
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings for all cases which should be embedded into the classification space.
classes_q
Named factor
containing the true classes for every case. Please note that the names must match the names/ids in embeddings_q.
batch_size
int
Batch size.
alpha
float
Value indicating how transparent the points should be (important if many points overlap). Does not apply to points representing prototypes.
size_points
int
Size of the points excluding the points for prototypes.
size_points_prototypes
int
Size of the points representing prototypes.
inc_unlabeled
bool
If TRUE, the plot includes unlabeled cases as data points.
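Following the Usage block above, a call might look like this; 'my_embeddings' and 'my_labels' are assumed to exist:

```r
# Plot cases and their class prototypes in the classification space.
classifier$plot_embeddings(
  embeddings_q = my_embeddings,
  classes_q = my_labels,       # named factor; names must match ids in embeddings_q
  batch_size = 12,
  alpha = 0.5,                 # transparency of case points
  size_points = 3,
  size_points_prototypes = 8,  # prototypes drawn larger
  inc_unlabeled = TRUE         # also show unlabeled cases
)
```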