topics_topwords.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import pandas as pd
  2. import numpy as np
  3. import tomotopy as tp
  4. from os.path import exists
  5. import textwrap
  6. latex_chars = "^+=_"
  7. if not exists("output/top_words.csv"):
  8. mdl = tp.CTModel.load("output/hep-ct-75-0.1-0.001-130000-20/model")
  9. top_words = []
  10. for topic in range(mdl.k):
  11. words = mdl.get_topic_words(topic, 100)
  12. words = [
  13. {
  14. 'topic': topic,
  15. 'word': word,
  16. 'unithood': np.log(2+word.count(' ')),
  17. 'p': p
  18. }
  19. for word, p in words
  20. ]
  21. top_words += words
  22. top_words = pd.DataFrame(top_words)
  23. top_words = top_words[~top_words["word"].str.contains("\\", regex=False)]
  24. top_words["word"] = top_words["word"].apply(
  25. lambda w: (
  26. f"${w}$" if any([c in w for c in latex_chars]) else w
  27. )
  28. )
  29. top_words["word"] = top_words["word"].apply(
  30. lambda w: (
  31. w[:-2] + '$' if w[-2:] == '_$' or w[-2:] == '^$' else w
  32. )
  33. )
  34. top_words['x'] = top_words['p']*top_words['unithood']
  35. top_words = top_words.sort_values(["topic", "x"], ascending=[True, False]).groupby("topic").head(15)
  36. top_words.to_csv("output/top_words.csv")
  37. else:
  38. top_words = pd.read_csv("output/top_words.csv")
  39. top_words = top_words.groupby("topic").agg(
  40. word = ('word', lambda x: ", ".join(x.tolist()))
  41. ).reset_index()
  42. top_words = top_words.merge(pd.read_csv("output/hep-ct-75-0.1-0.001-130000-20/descriptions.csv")[["topic", "description"]])
  43. top_words.rename(columns = {
  44. 'word': 'Most frequent expressions',
  45. "description": "Topic (context)"
  46. }, inplace = True)
  47. top_words.sort_values("Topic (context)", inplace=True)
  48. # top_words["Sujet"] = top_words["Sujet"].apply(lambda s: "\\\\ ".join(textwrap.wrap(s, width=15)))
  49. pd.set_option('display.max_colwidth', None)
  50. latex = top_words.reset_index()[["Topic (context)", "Most frequent expressions"]].set_index(["Topic (context)"]).to_latex(
  51. longtable=True,
  52. sparsify=True,
  53. multirow=True,
  54. multicolumn=True,
  55. position='H',
  56. column_format='p{0.2\\textwidth}|p{0.8\\textwidth}',
  57. escape=False,
  58. caption="Most frequent terms for each topic.",
  59. label="table:top_words"
  60. )
  61. latex = latex.replace('\\\\\n', '\\\\ \\midrule\n')
  62. with open("tables/top_words.tex", "w+") as fp:
  63. fp.write(latex)