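# Build an embedded topic model (ETM) of article abstracts:
# filter the corpus, extract n-grams, optionally pre-train word2vec
# embeddings, build a bag-of-words dataset, and fit the ETM.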
from AbstractSemantics.terms import TermExtractor
from AbstractSemantics.embeddings import GensimWord2Vec
from gensim.models import KeyedVectors
import nltk
import pandas as pd
import numpy as np
from scipy.sparse import csr_matrix
from os.path import join as opj
from os.path import exists
import itertools
from functools import partial
from collections import defaultdict
import re
import multiprocessing as mp
# from matplotlib import pyplot as plt
import argparse
import yaml
import sys
import pickle
from gensim.models.callbacks import CallbackAny2Vec
class MonitorCallback(CallbackAny2Vec):
    """Print the word2vec training loss after each epoch, together with the
    nearest neighbours of a few probe words, to monitor convergence."""

    def __init__(self, test_words):
        self._test_words = test_words
        self.epoch = 0

    def on_epoch_end(self, model):
        # gensim reports a cumulative loss, so print the per-epoch difference
        loss = model.get_latest_training_loss()
        if self.epoch == 0:
            print("Loss after epoch {}: {}".format(self.epoch, loss))
        else:
            print(
                "Loss after epoch {}: {}".format(
                    self.epoch, loss - self.loss_previous_step
                )
            )
        self.epoch += 1
        self.loss_previous_step = loss
        for word in self._test_words:  # show how the word neighbourhoods evolve
            print(f"{word}: {model.wv.most_similar(word)}")
def filter_ngrams(l, wl):
    """Keep only the n-grams that appear in the whitelist ``wl``."""
    return [ngram for ngram in l if ngram in wl]


def construct_bow(l, n):
    """Collapse a list of tokens into (unique tokens, counts) arrays (``n`` is unused)."""
    items = list(set(l))
    return np.array(items), np.array([l.count(i) for i in items])


def ngram_inclusion(i, js):
    """For each n-gram ``j`` in ``js``, tell whether ``i`` is nested in ``j``
    as a contiguous sub-n-gram with exactly one token less."""
    return [
        j.find(i) >= 0  # i occurs as a substring of j
        and bool(re.search(rf"(^|_){re.escape(i)}($|_)", j))  # on token boundaries
        and (j.count("_") == i.count("_") + 1)  # j has exactly one extra token
        and (
            (i.count("_") >= 1)
            or bool(re.search(rf"(^|_){re.escape(i)}$", j))
            or bool(re.search(rf"^{re.escape(i)}($|_)", j))
        )
        for j in js
    ]
if __name__ == "__main__":
    parser = argparse.ArgumentParser("CT Model")
    parser.add_argument("location", help="model directory")
    parser.add_argument(
        "filter", choices=["categories", "keywords", "no-filter"], help="filter type"
    )
    parser.add_argument("--values", nargs="+", default=[], help="filter allowed values")
    parser.add_argument("--dataset", default="inspire-harvest/database")
    # sample size
    parser.add_argument("--samples", type=int, default=50000)
    parser.add_argument("--constant-sampling", type=int, default=0)
    # text pre-processing
    parser.add_argument(
        "--add-title", default=False, action="store_true", help="include title"
    )
    parser.add_argument(
        "--remove-latex", default=False, action="store_true", help="remove latex"
    )
    parser.add_argument(
        "--lemmatize", default=False, action="store_true", help="lemmatize"
    )
    parser.add_argument(
        "--limit-redundancy",
        default=False,
        action="store_true",
        help="limit redundancy",
    )
    parser.add_argument("--blacklist", default=None, help="blacklist")
    # embeddings
    parser.add_argument("--dimensions", type=int, default=50)
    parser.add_argument("--pre-trained-embeddings", default=False, action="store_true")
    parser.add_argument("--use-saved-embeddings", default=False, action="store_true")
    # topic model parameters
    parser.add_argument("--topics", type=int, default=25)
    parser.add_argument("--min-df", type=float, default=0.001)
    parser.add_argument("--max-df", type=float, default=0.15)
    parser.add_argument("--threads", type=int, default=4)
    args = parser.parse_args(
        [
            "output/etm_25_pretrained",
            "categories",
            "--values",
            "Theory-HEP",
            "Phenomenology-HEP",
            "--dataset",
            "../inspire-harvest/database",
            "--constant-sampling",
            "30000",
            "--samples",
            "300000",
            "--threads",
            "24",
            "--add-title",
            "--remove-latex",
            "--dimensions",
            "50",
            "--topics",
            "25",
            "--min-df",
            "0.00075",
            "--lemmatize",
            "--pre-trained-embeddings",
            # "--limit-redundancy",
            "--use-saved-embeddings",
            # "--blacklist",
            # "output/medialab/blacklist",
        ]
    )
    with open(opj(args.location, "params.yml"), "w+") as fp:
        yaml.dump(args, fp)
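    # Load article metadata and keep only the fields used downstream.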
    articles = pd.read_parquet(
        opj(args.dataset, "articles.parquet")
    )[["title", "abstract", "article_id", "date_created", "categories"]]

    if args.add_title:
        articles["abstract"] = articles["abstract"].str.cat(articles["title"], sep=". ")

    articles.drop(columns=["title"], inplace=True)

    if args.remove_latex:
        # strip inline math ($...$), LaTeX commands, and remaining special characters
        articles["abstract"] = articles["abstract"].apply(
            lambda s: re.sub(r"\$[^$]+\$", "", s)
        )
        articles["abstract"] = articles["abstract"].apply(
            lambda s: re.sub(r"\b\\\w+", "", s)
        )
        articles["abstract"] = articles["abstract"].apply(
            lambda s: re.sub(r"[^0-9a-zA-Z\- .]+", "", s)
        )
        # articles["abstract"] = articles["abstract"].str.replace("-", " ")  # NEW

    articles = articles[articles["abstract"].map(len) >= 100]
    articles["abstract"] = articles["abstract"].str.lower()

    articles = articles[articles["date_created"].str.len() >= 4]
    if "year" not in articles.columns:
        # years are stored as offsets from 2000; keep the 2000-2020 window
        articles["year"] = articles["date_created"].str[:4].astype(int) - 2000
        articles = articles[(articles["year"] >= 0) & (articles["year"] <= 20)]
    else:
        articles["year"] = articles["year"].astype(int)
        articles = articles[articles["year"] >= 2000]

    articles["year_group"] = articles["year"] // 5
    keep = pd.Series([False] * len(articles), index=articles.index)

    print("Applying filter...")
    if args.filter == "keywords":
        for value in args.values:
            keep |= articles["abstract"].str.contains(value)
    elif args.filter == "categories":
        for value in args.values:
            keep |= articles["categories"].apply(lambda l: value in l)
    elif args.filter == "no-filter":
        keep |= True

    articles = articles[keep == True].sample(frac=1)

    if args.constant_sampling > 0:
        articles = articles.groupby("year").head(args.constant_sampling)

    articles = articles.sample(frac=1).head(args.samples)
    articles.reset_index(inplace=True)
    articles[["article_id"]].to_csv(opj(args.location, "articles.csv"))

    print(articles)
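    # Extract candidate terms with part-of-speech patterns (adjectives and nouns).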
- print("Extracting n-grams...")
- extractor = TermExtractor(
- articles["abstract"].tolist(),
- limit_redundancy=args.limit_redundancy,
- patterns=[
- ["JJ.*"],
- ["NN.*"],
- ["JJ.*", "NN.*"],
- ["JJ.*", "NN.*", "NN.*"],
- # ["JJ.*", "NN", "CC", "NN.*"],
- # ["JJ.*", "NN.*", "JJ.*", "NN.*"],
- # ["RB.*", "JJ.*", "NN.*", "NN.*"],
- ],
- )
- ngrams = extractor.ngrams(
- threads=args.threads,
- lemmatize=args.lemmatize,
- lemmatize_ngrams=args.lemmatize,
- split_sentences=args.pre_trained_embeddings and not args.use_saved_embeddings,
- )
- del extractor
- del articles["abstract"]
    if args.pre_trained_embeddings and not args.use_saved_embeddings:
        # n-grams are nested per sentence here; join each n-gram into a single token
        ngrams = map(
            lambda l: [
                [
                    ("_".join(n))
                    .strip()
                    .replace("-", "_")
                    .replace(".._", "")
                    .replace("_..", "")
                    for n in sent
                ]
                for sent in l
            ],
            ngrams,
        )
        ngrams = list(ngrams)

        print("Pre-training embeddings...")
        emb = GensimWord2Vec(
            [sentence for sentences in ngrams for sentence in sentences]
        )
        model = emb.create_model(
            vector_size=args.dimensions,
            window=5,
            workers=args.threads,
            compute_loss=True,
            # epochs=90,
            # min_count=30,
            epochs=80,
            min_count=15,
            sg=1,
            callbacks=[
                MonitorCallback(
                    [
                        "transformer",
                        "embedding",
                        "syntax",
                        "grammar",
                    ]
                )
            ],
        )
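        # Save the vectors in word2vec binary format so the ETM can load them later.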
        model.wv.save_word2vec_format(opj(args.location, "embeddings.bin"), binary=True)
        del model

        # flatten the per-sentence n-grams back into one list per article
        ngrams = [
            list(itertools.chain.from_iterable(article_sentences))
            for article_sentences in ngrams
        ]
    else:
        ngrams = map(
            lambda l: [
                "_".join(n)
                .strip()
                .replace("-", "_")
                .replace(".._", "")
                .replace("_..", "")
                for n in l
            ],
            ngrams,
        )
        ngrams = list(ngrams)
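    # Compute each n-gram's document frequency and apply min/max thresholds,
    # a minimum length, stop-word removal, and an optional blacklist.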
- print("Deriving vocabulary...")
- voc = defaultdict(int)
- for article_ngrams in ngrams:
- _ngrams = set(article_ngrams)
- for ngram in _ngrams:
- voc[ngram] += 1
- voc = pd.DataFrame({"ngram": voc.keys(), "count": voc.values()})
- voc["df"] = voc["count"] / len(articles)
- voc.set_index("ngram", inplace=True)
- if args.min_df < 1:
- voc = voc[voc["df"] >= args.min_df]
- else:
- voc = voc[voc["count"] >= args.min_df]
- if args.max_df < 1:
- voc = voc[voc["df"] <= args.max_df]
- else:
- voc = voc[voc["count"] <= args.max_df]
- voc["len"] = voc.index.map(len)
- voc = voc[voc["len"] >= 2]
- stop_words = nltk.corpus.stopwords.words("english")
- voc = voc[~voc.index.isin(stop_words)]
- if args.blacklist is not None:
- print("Filtering black-listed keywords...")
- blacklist = pd.read_csv(args.blacklist)["ngram"].tolist()
- voc = voc[
- voc.index.map(lambda s: not any([ngram in s for ngram in blacklist]))
- == True
- ]
- print("Filtering completed.")
- voc = voc.sort_values("df", ascending=False)
- voc.to_csv(opj(args.location, "ngrams.csv"))
- voc = pd.read_csv(opj(args.location, "ngrams.csv"), keep_default_na=False)[
- "ngram"
- ].tolist()
- vocabulary = {n: i for i, n in enumerate(voc)}
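    # Restrict each article to vocabulary n-grams and map them to integer indices.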
- print("Filtering n-grams...")
- with mp.Pool(processes=args.threads) as pool:
- ngrams = pool.map(partial(filter_ngrams, wl=voc), ngrams)
- print("Constructing bag-of-words...")
- bow = [[vocabulary[ngram] for ngram in _ngrams] for _ngrams in ngrams]
- # if args.limit_redundancy:
- # print("Building 'within' matrix...")
- # with mp.Pool(processes=args.threads) as pool:
- # within = pool.map(partial(ngram_inclusion, js=voc), voc)
- # within = np.array(within).astype(int)
- # print("Removing double-counting...")
- # bow = csr_matrix(bow)
- # double_counting = bow.dot(csr_matrix(within.T))
- # bow = bow - double_counting
- # print(double_counting.sum(), "redundant keywords removed")
- # del double_counting
- # bow = bow.todense()
- # print((bow <= -1).sum(), "keywords had negative counts after removal")
- del ngrams
- with mp.Pool(processes=args.threads) as pool:
- bow = pool.map(partial(construct_bow, n=len(voc)), bow)
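    # Drop articles whose bag-of-words ended up empty and re-save the article list.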
    keep = [i for i in range(len(bow)) if len(bow[i][0]) > 0]
    articles = articles.iloc[keep]
    articles[["article_id"]].to_csv(opj(args.location, "articles.csv"))
    bow = [bow[i] for i in keep]

    dataset = {
        "tokens": [bow[i][0] for i in range(len(bow))],
        "counts": [bow[i][1] for i in range(len(bow))],
        "article_id": articles["article_id"],
    }
    del bow

    with open(opj(args.location, "dataset.pickle"), "wb") as handle:
        pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
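    # Fit the embedded topic model, reusing the saved word2vec embeddings when
    # --pre-trained-embeddings is set (otherwise the ETM trains its own).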
- print("Training...")
- from embedded_topic_model.models.etm import ETM
- etm_instance = ETM(
- voc,
- num_topics=args.topics,
- rho_size=args.dimensions,
- emb_size=args.dimensions,
- epochs=25,
- debug_mode=True,
- train_embeddings=not args.pre_trained_embeddings,
- model_path=opj(args.location, "model"),
- embeddings=opj(args.location, "embeddings.bin")
- if args.pre_trained_embeddings
- else None,
- use_c_format_w2vec=True,
- )
- etm_instance.fit(dataset)
- with open(opj(args.location, "etm_instance.pickle"), "wb") as handle:
- pickle.dump(etm_instance, handle, protocol=pickle.HIGHEST_PROTOCOL)
- topics = etm_instance.get_topics(20)
- print(topics)
- topic_coherence = etm_instance.get_topic_coherence()
- print(topic_coherence)
- topic_diversity = etm_instance.get_topic_diversity()
- print(topic_diversity)