Source code for shorttext.cli.categorization


import os
from functools import partial
from argparse import ArgumentParser
from operator import itemgetter

from loguru import logger

from ..utils.compactmodel_io import get_model_classifier_name
from ..utils.classification_exceptions import AlgorithmNotExistException, WordEmbeddingModelNotExistException
from ..utils import load_word2vec_model, load_fasttext_model, load_poincare_model
from ..smartload import smartload_compact_model
from ..classifiers import TopicVectorCosineDistanceClassifier


# classifier-name configuration
# topic-model classifiers; main() wraps these in TopicVectorCosineDistanceClassifier
topicmodels = ['ldatopic', 'lsitopic', 'rptopic', 'kerasautoencoder']
# classifiers for which main() loads a word-embedding model first
needembedded_classifiers = ['nnlibvec', 'sumvec']
# every classifier name accepted by the CLI (order kept as originally listed)
allowed_classifiers = [
    *topicmodels,
    'topic_sklearn',
    *needembedded_classifiers,
    'maxent',
]


# Word-embedding loaders with the `binary` flag pre-bound, so that every
# entry of `typedict` can be invoked with the model path alone.
load_poincare_binary_model = partial(load_poincare_model, binary=True)
load_word2vec_nonbinary_model = partial(load_word2vec_model, binary=False)

# Maps each value accepted by the --type option to its loader callable.
typedict = {
    'word2vec': load_word2vec_model,
    'word2vec_nonbinary': load_word2vec_nonbinary_model,
    'fasttext': load_fasttext_model,
    'poincare': load_poincare_model,
    'poincare_binary': load_poincare_binary_model,
}


def get_argparser() -> ArgumentParser:
    """Build the command-line parser of the short text categorization CLI.

    The parser takes one positional argument (the path of the trained
    compact model) and the options ``--wv``, ``--vecsize``, ``--topn``,
    ``--inputtext`` and ``--type``.

    Returns:
        The configured ArgumentParser.
    """
    argparser = ArgumentParser(
        description='Perform prediction on short text with a given trained model.'
    )

    # positional: the trained model to load
    argparser.add_argument(
        'model_filepath',
        help='Path of the trained (compact) model.',
    )

    # word-embedding related options
    argparser.add_argument(
        '--wv',
        default='',
        help='Path of the pre-trained Word2Vec model.',
    )
    argparser.add_argument(
        '--vecsize',
        default=300,
        type=int,
        help='Vector dimensions. (Default: 300)',
    )

    # output / input options
    argparser.add_argument(
        '--topn',
        type=int,
        default=10,
        help='Number of top results to show.',
    )
    argparser.add_argument(
        '--inputtext',
        default=None,
        help='Single input text for classification. If omitted, will enter console mode.',
    )

    # the accepted values are exactly the keys of the loader table
    argparser.add_argument(
        '--type',
        default='word2vec',
        choices=typedict.keys(),
        help='Type of word-embedding model (default: word2vec)',
    )
    return argparser
# main block
def main():
    """Run the short text categorization command-line tool.

    Parses the command-line arguments, loads the trained model (and a
    word-embedding model when the classifier needs one), then either scores
    the text given by ``--inputtext`` or enters an interactive console that
    scores each entered line until an empty line or end-of-input is reached.

    Raises:
        IOError: if the model file does not exist.
        AlgorithmNotExistException: if the classifier name stored in the
            model is not in ``allowed_classifiers``.
        WordEmbeddingModelNotExistException: if the classifier needs a
            word-embedding model and the ``--wv`` path does not exist.
    """
    # argument parsing
    args = get_argparser().parse_args()

    # check if the model file is given
    if not os.path.exists(args.model_filepath):
        raise IOError(f'Model file "{args.model_filepath}" not found!')

    # get the name of the classifier
    logger.info('Retrieving classifier name...')
    classifier_name = get_model_classifier_name(args.model_filepath)
    if classifier_name not in allowed_classifiers:
        raise AlgorithmNotExistException(classifier_name)

    # load the word-embedding model only for classifiers that need it
    wvmodel = None
    if classifier_name in needembedded_classifiers:
        if not os.path.exists(args.wv):
            raise WordEmbeddingModelNotExistException(args.wv)
        logger.info(f'Loading word-embedding model from {args.wv}...')
        wvmodel = typedict[args.type](args.wv)

    # load the classifier; topic models are wrapped so they expose score()
    logger.info('Initializing the classifier...')
    classifier = smartload_compact_model(args.model_filepath, wvmodel, vecsize=args.vecsize)
    if classifier_name in topicmodels:
        classifier = TopicVectorCosineDistanceClassifier(classifier)

    # single-input mode
    if args.inputtext is not None:
        if not args.inputtext.strip():
            print('No input text provided.')
            return
        _print_top_scores(classifier, args.inputtext, args.topn)
        return

    # console mode
    print('Enter text to classify (empty input to quit):')
    while True:
        try:
            shorttext = input('text> ').strip()
        except EOFError:
            # stdin was closed (piped input exhausted or Ctrl-D): quit the
            # console cleanly instead of crashing with a traceback
            print()
            break
        if not shorttext:
            break
        _print_top_scores(classifier, shorttext, args.topn)
    print('Done.')


def _print_top_scores(classifier, text, topn):
    """Score ``text`` with ``classifier`` and print its ``topn`` best labels.

    One ``label : score`` line is printed per label, highest score first.
    """
    scoredict = classifier.score(text)
    ranked = sorted(scoredict.items(), key=itemgetter(1), reverse=True)
    for label, score in ranked[:topn]:
        print(f'{label} : {score:.4f}')