Source code for shorttext.classifiers.embed.sumvec.SumEmbedVecClassification


import pickle
from collections import defaultdict
from typing import Optional, Annotated, Self

import numpy as np
import numpy.typing as npt
from gensim.models.keyedvectors import KeyedVectors
from deprecation import deprecated

from ....utils.classification_exceptions import ModelNotTrainedException
from ....utils import shorttext_to_avgvec
from ....utils.compactmodel_io import CompactIOMachine
from ....utils.compute import cosine_similarity


class SumEmbeddedVecClassifier(CompactIOMachine):
    """Classifier using summed word embeddings.

    Each class is represented by the sum of the embedding vectors of its
    training texts, scaled to unit length. Scoring compares the embedding
    vector of the input text with every class vector using a similarity
    function (``cosine_similarity`` unless another one is supplied).

    Reference:
        Pre-trained Word2Vec: https://code.google.com/archive/p/word2vec/
    """

    def __init__(
            self,
            wvmodel: KeyedVectors,
            vecsize: Optional[int] = None,
            simfcn: Optional[callable] = None
    ):
        """Initialize the classifier.

        Args:
            wvmodel: Word embedding model (e.g., Word2Vec).
            vecsize: Vector size. Default: None, in which case
                ``wvmodel.vector_size`` is used.
            simfcn: Similarity function called as ``simfcn(vec, classvec)``.
                Default: None, in which case ``cosine_similarity`` is used.
        """
        super().__init__(
            {'classifier': 'sumvec'}, 'sumvec', ['_embedvecdict.pkl']
        )
        self.wvmodel = wvmodel
        self.vecsize = self.wvmodel.vector_size if vecsize is None else vecsize
        self.simfcn = simfcn if simfcn is not None else cosine_similarity
        self.trained = False

    def train(self, classdict: dict[str, list[str]]) -> None:
        """Train the classifier.

        For every class label, the embedding vectors of its texts are summed
        and the sum is divided by its Euclidean norm. The result is stored in
        ``self.addvec`` and the classifier is marked as trained.

        A class whose summed vector has zero norm (including a class with no
        texts) is stored as a zero vector instead of being divided by zero.

        Args:
            classdict: Training data with class labels as keys and lists of
                texts as values.
        """
        addvec = {}
        for classtype, shorttexts in classdict.items():
            vectors = [
                self.shorttext_to_embedvec(shorttext)
                for shorttext in shorttexts
            ]
            # np.sum over an empty list would collapse to a scalar, so fall
            # back to an explicit zero vector of the configured size.
            if vectors:
                summed = np.sum(vectors, axis=0)
            else:
                summed = np.zeros(self.vecsize)

            # Normalize only when possible; 0/0 would fill the vector with NaN.
            norm = np.linalg.norm(summed)
            if norm > 0:
                summed = summed / norm
            addvec[classtype] = summed

        self.addvec = addvec
        self.trained = True

    def savemodel(self, nameprefix: str) -> None:
        """Save the trained model.

        The class vectors are pickled to ``<nameprefix>_embedvecdict.pkl``.

        Args:
            nameprefix: Prefix for output files.

        Raises:
            ModelNotTrainedException: If not trained.
        """
        if not self.trained:
            raise ModelNotTrainedException()
        # Context manager guarantees the file is flushed and closed.
        with open(nameprefix + '_embedvecdict.pkl', 'wb') as outfile:
            pickle.dump(self.addvec, outfile)

    def loadmodel(self, nameprefix: str) -> None:
        """Load a trained model.

        Reads the class vectors from ``<nameprefix>_embedvecdict.pkl`` and
        marks the classifier as trained.

        Args:
            nameprefix: Prefix for input files.
        """
        # SECURITY NOTE: unpickling can execute arbitrary code. Only load
        # model files that come from a trusted source.
        with open(nameprefix + '_embedvecdict.pkl', 'rb') as infile:
            self.addvec = pickle.load(infile)
        self.trained = True

    def shorttext_to_embedvec(
            self,
            shorttext: str
    ) -> Annotated[npt.NDArray[np.float64], "1D Array"]:
        """Convert short text to embedding vector.

        Delegates to ``shorttext_to_avgvec`` with this classifier's word
        embedding model.

        Args:
            shorttext: Input text.

        Returns:
            The vector returned by ``shorttext_to_avgvec`` (presumably a
            normalized average of the word vectors; confirm against that
            helper).
        """
        return shorttext_to_avgvec(shorttext, self.wvmodel)

    def score(self, shorttext: str) -> dict[str, float]:
        """Calculate classification scores for all class labels.

        Args:
            shorttext: Input text.

        Returns:
            Dictionary mapping class labels to the similarity between the
            input vector and the class vector. A label maps to ``np.nan`` when
            the similarity function raises ``ValueError``.

        Raises:
            ModelNotTrainedException: If not trained.
        """
        if not self.trained:
            raise ModelNotTrainedException()

        vec = self.shorttext_to_embedvec(shorttext)

        scoredict = {}
        for classtype, addvec in self.addvec.items():
            try:
                scoredict[classtype] = self.simfcn(vec, addvec)
            except ValueError:
                # Best effort: keep the label, but mark its score as undefined.
                scoredict[classtype] = np.nan
        return scoredict

    @classmethod
    def from_pretrained(
            cls,
            wvmodel: KeyedVectors,
            name: str,
            compact: bool = True,
            vecsize: Optional[int] = None
    ) -> Self:
        """Load a classifier from file.

        Args:
            wvmodel: Word embedding model.
            name: Model name (compact) or prefix (non-compact).
            compact: Whether to load a compact model via
                ``load_compact_model`` (not defined in this class; presumably
                inherited from ``CompactIOMachine``). Default: True.
            vecsize: Vector size. Default: None.

        Returns:
            An instance of the class this method is called on.
        """
        # Instantiate via ``cls`` so subclasses get an instance of their own
        # type, as the ``Self`` return annotation promises.
        classifier = cls(wvmodel, vecsize=vecsize)
        if compact:
            classifier.load_compact_model(name)
        else:
            classifier.loadmodel(name)
        return classifier
@deprecated(deprecated_in="4.0.1", removed_in="5.0.0")
def load_sumword2vec_classifier(
        wvmodel: KeyedVectors,
        name: str,
        compact: bool = True,
        vecsize: Optional[int] = None
) -> SumEmbeddedVecClassifier:
    """Deprecated loader for a SumEmbeddedVecClassifier.

    Marked deprecated in 4.0.1 with removal set for 5.0.0. Use
    `~SumEmbeddedVecClassifier.from_pretrained` instead; this function simply
    forwards its arguments to it.

    Args:
        wvmodel: Word embedding model.
        name: Model name (compact) or prefix (non-compact).
        compact: Whether to load a compact model. Default: True.
        vecsize: Vector size. Default: None.

    Returns:
        The classifier returned by ``SumEmbeddedVecClassifier.from_pretrained``.
    """
    load = SumEmbeddedVecClassifier.from_pretrained
    return load(wvmodel, name, compact=compact, vecsize=vecsize)