12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
import textwrap
from os.path import exists
from pathlib import Path

import numpy as np
import pandas as pd
import tomotopy as tp
# Characters whose presence makes an expression be typeset in LaTeX math mode.
latex_chars = "^+=_"


def _latexify(word):
    """Wrap *word* in ``$...$`` when it contains one of ``latex_chars``.

    A trailing ``_`` or ``^`` (a sub/superscript marker with nothing after
    it) is dropped from the wrapped form. Words containing none of the
    characters are returned unchanged.
    """
    if not any(c in word for c in latex_chars):
        return word
    wrapped = f"${word}$"
    if wrapped.endswith(("_$", "^$")):
        wrapped = wrapped[:-2] + "$"
    return wrapped


# Build the per-topic table of top expressions, or reuse the cached CSV.
if not exists("output/top_words.csv"):
    mdl = tp.CTModel.load("output/hep-ct-75-0.1-0.001-130000-20/model")

    top_words = []
    for topic in range(mdl.k):
        # The 100 highest-probability words of the topic, as (word, p) pairs.
        top_words.extend(
            {
                "topic": topic,
                "word": word,
                # log(2 + number of spaces): grows with the number of
                # space-separated tokens, so longer expressions weigh more.
                "unithood": np.log(2 + word.count(" ")),
                "p": p,
            }
            for word, p in mdl.get_topic_words(topic, 100)
        )
    top_words = pd.DataFrame(top_words)

    # Discard expressions containing a backslash. Copy the filtered frame so
    # the column assignments below do not operate on a view of the original.
    has_backslash = top_words["word"].str.contains("\\", regex=False)
    top_words = top_words[~has_backslash].copy()

    top_words["word"] = top_words["word"].apply(_latexify)

    # Rank by probability weighted by unithood; keep the 15 best per topic.
    top_words["x"] = top_words["p"] * top_words["unithood"]
    top_words = (
        top_words.sort_values(["topic", "x"], ascending=[True, False])
        .groupby("topic")
        .head(15)
    )
    # The index column is still written so the cache file layout is unchanged.
    top_words.to_csv("output/top_words.csv")
else:
    # keep_default_na=False: without it pandas turns expressions such as
    # "null", "nan" or "NA" into float NaN, which later breaks the string
    # join over the "word" column.
    top_words = pd.read_csv("output/top_words.csv", keep_default_na=False)
# Collapse to one row per topic: its expressions joined into a single
# comma-separated string, in their current row order.
expressions_per_topic = top_words.groupby("topic").agg(
    word=("word", lambda words: ", ".join(words.tolist()))
)
top_words = expressions_per_topic.reset_index()

# Attach the description of each topic (inner join on the shared columns of
# the two frames).
topic_descriptions = pd.read_csv(
    "output/hep-ct-75-0.1-0.001-130000-20/descriptions.csv"
)
top_words = top_words.merge(topic_descriptions[["topic", "description"]])

# Use the table's display headers as column names and order rows by
# description.
display_names = {
    "word": "Most frequent expressions",
    "description": "Topic (context)",
}
top_words = top_words.rename(columns=display_names)
top_words = top_words.sort_values("Topic (context)")
# Render the topic / expressions table as a LaTeX longtable and write it to
# tables/top_words.tex.

# Lift pandas' column-width limit. NOTE(review): presumably needed so that
# long expression lists are not truncated by to_latex on older pandas
# versions; confirm against the pandas version in use.
pd.set_option("display.max_colwidth", None)

# Index by the topic description, leaving the expressions as the only column.
table = top_words.set_index("Topic (context)")[["Most frequent expressions"]]
latex = table.to_latex(
    longtable=True,
    sparsify=True,
    multirow=True,
    multicolumn=True,
    position="H",
    column_format="p{0.2\\textwidth}|p{0.8\\textwidth}",
    # Cell contents already contain LaTeX markup ($...$) and must not be
    # escaped.
    escape=False,
    caption="Most frequent terms for each topic.",
    label="table:top_words",
)
# Insert a horizontal rule after every line that ends with a LaTeX row
# terminator ("\\").
latex = latex.replace("\\\\\n", "\\\\ \\midrule\n")

# Create the output directory if needed, and write with an explicit encoding
# so the result does not depend on the platform's locale.
output_path = Path("tables/top_words.tex")
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(latex, encoding="utf-8")
|