embeddings.py

from typing import List
from abc import ABC, abstractmethod


class Embeddings(ABC):
    """Base class for embedding models built from pre-tokenized texts."""

    def __init__(self, tokens: List[List[str]]):
        self.tokens = tokens

    @abstractmethod
    def model(self):
        """Train and return an embedding model."""

    @abstractmethod
    def recover(self, model):
        """Map the stored tokens to vectors using a trained model."""


class GensimWord2Vec(Embeddings):
    def __init__(self, tokens, **kwargs):
        super().__init__(tokens)

    def model(
        self,
        vector_size: int = 128,
        window: int = 20,
        min_count: int = 10,
        workers: int = 4,
        **kwargs,
    ):
        # Import lazily so the module can be used without gensim installed.
        from gensim.models import word2vec

        model = word2vec.Word2Vec(
            self.tokens,
            vector_size=vector_size,
            window=window,
            min_count=min_count,
            workers=workers,
            **kwargs,
        )
        return model

    def recover(self, model):
        # Map each tokenized text to its word vectors, skipping tokens that
        # fell below min_count and are absent from the trained vocabulary.
        embeddings = []
        for text in self.tokens:
            embeddings.append(
                [model.wv[token] for token in text if token in model.wv]
            )
        return embeddings
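

# A minimal usage sketch, not part of the original file: it assumes gensim
# is installed, and the tiny corpus below is purely illustrative.
if __name__ == "__main__":
    corpus = [
        ["the", "quick", "brown", "fox"],
        ["jumps", "over", "the", "lazy", "dog"],
    ]
    embedder = GensimWord2Vec(corpus)
    # min_count=1 keeps every token of the toy corpus in the vocabulary.
    w2v = embedder.model(vector_size=16, window=2, min_count=1, workers=1)
    vectors = embedder.recover(w2v)
    print(len(vectors), len(vectors[0]))  # 2 texts, 4 vectors in the first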