embeddings.py

from abc import ABC, abstractmethod
from typing import List


class Embeddings(ABC):
    """Base class for models that learn embeddings from tokenised texts."""

    def __init__(self, tokens: List[List[str]]):
        # Each inner list is one tokenised text (a sentence or document).
        self.tokens = tokens

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def recover(self, model):
        pass


class GensimWord2Vec(Embeddings):
    """Thin wrapper around gensim's Word2Vec model."""

    def __init__(self, tokens, **kwargs):
        super().__init__(tokens)

    def train(
        self,
        vector_size: int = 128,
        window: int = 20,
        min_count: int = 10,
        threads: int = 4,
        **kwargs,
    ):
        # Imported lazily so the module can be loaded without gensim installed.
        from gensim.models import word2vec

        model = word2vec.Word2Vec(
            self.tokens,
            vector_size=vector_size,
            window=window,
            min_count=min_count,
            workers=threads,
            **kwargs,
        )
        return model

    def recover(self, model):
        # Map each tokenised text to the vectors of its in-vocabulary tokens;
        # words pruned by min_count have no vector and are skipped.
        embeddings = []
        for text in self.tokens:
            embeddings.append([model.wv[token] for token in text if token in model.wv])
        return embeddings
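
A minimal usage sketch, assuming gensim is installed; the toy texts, parameter values, and whitespace tokenisation below are illustrative assumptions, not part of the original file:

# Usage sketch (illustrative, not from the original module).
if __name__ == "__main__":
    # Hypothetical toy corpus: one pre-tokenised text per inner list.
    texts = [
        "the quick brown fox jumps over the lazy dog".split(),
        "the lazy dog sleeps all day".split(),
    ]
    w2v = GensimWord2Vec(texts)
    # min_count=1 so the tiny corpus is not pruned away entirely.
    model = w2v.train(vector_size=32, window=5, min_count=1, threads=2)
    vectors = w2v.recover(model)
    # One list of 32-dimensional vectors per input text.
    print(len(vectors), len(vectors[0]), vectors[0][0].shape)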