Scheduled service maintenance on November 22


On Friday, November 22, 2024, between 06:00 and 18:00 CET, GIN services will undergo planned maintenance. Extended service interruptions should be expected. We will try to keep downtime to a minimum, but we recommend that users avoid critical tasks, large data uploads, or DOI requests during this time.

We apologize for any inconvenience.

etm_ei.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import pandas as pd
  2. import numpy as np
  3. from cmdstanpy import CmdStanModel
  4. import argparse
  5. import pickle
  6. from os.path import join as opj, basename
  7. import seaborn as sns
  8. from matplotlib import pyplot as plt
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument("--input", default="output/etm_20_pretrained")
  11. parser.add_argument("--portfolios", default=None)
  12. parser.add_argument("--resources", default=None)
  13. parser.add_argument("--n", type=int, default=4000)
  14. parser.add_argument("--min-pop", type=int, default=100)
  15. parser.add_argument("--stack-rows", type=int, default=1)
  16. parser.add_argument("--chains", type=int, default=1)
  17. parser.add_argument("--threads-per-chain", type=int, default=4)
  18. parser.add_argument("--samples", type=int, default=500)
  19. parser.add_argument("--warmup", type=int, default=300)
  20. parser.add_argument("--model", default="control_nu")
  21. parser.add_argument("--mle", action="store_true", default=False)
  22. args = parser.parse_args()
  23. portfolios = opj(args.input, "aggregate.csv") if args.portfolios is None else args.portfolios
  24. n_topics = len(pd.read_csv(opj(args.input, "topics.csv")))
  25. df = pd.read_csv(portfolios)
  26. df = df[df[[f"start_{k+1}" for k in range(n_topics)]].sum(axis=1) >= args.min_pop]
  27. if args.n < len(df):
  28. df = df.head(n=args.n)
  29. n_authors = len(df)
  30. print(opj(args.input, f"ei_samples_{args.model}_{basename(portfolios)}.npz"))
  31. resources = pd.read_parquet(
  32. opj(args.input, "pooled_resources.parquet") if args.resources is None else args.resources
  33. )
  34. print(len(resources))
  35. df = df.merge(resources, left_on="bai", right_on="bai")
  36. assert len(df)==n_authors
  37. print(n_authors)
  38. data = {
  39. "NR": np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int),
  40. "NC": np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int),
  41. "expertise": np.stack(df[[f"expertise_{k+1}" for k in range(n_topics)]].values),
  42. "R": n_topics,
  43. "C": n_topics,
  44. "n_units": len(df),
  45. "threads": args.threads_per_chain
  46. }
  47. data["cov"] = np.stack(df["pooled_resources"])
  48. junk = np.sum(data["NR"] + data["NC"], axis=0) == 0
  49. for col in ["NR", "NC", "cov", "expertise"]:
  50. data[col] = data[col][:, ~junk]
  51. data["R"] -= junk.sum()
  52. data["C"] -= junk.sum()
  53. data["cov"] = data["cov"]# / np.maximum(data["cov"].sum(axis=1)[:, np.newaxis], 1)
  54. fig, ax = plt.subplots()
  55. sns.heatmap(
  56. np.corrcoef(data["NC"].T, data["cov"].T), vmin=-0.5, vmax=0.5, cmap="RdBu", ax=ax
  57. )
  58. plt.show()
  59. fig, ax = plt.subplots()
  60. for i in range(data["R"]):
  61. ax.scatter(data["cov"][:,i], data["NR"][:,i]/data["NR"].sum(axis=1))
  62. plt.show()
  63. expertise = data["expertise"]
  64. data["nu"] = np.array([
  65. [((expertise[:,i]>expertise[:,i].mean())&(expertise[:,j]>expertise[:,j].mean())).mean()/(expertise[:,i]>expertise[:,i].mean()).mean() for j in range(data["R"])]
  66. for i in range(data["R"])
  67. ])
  68. model = CmdStanModel(
  69. stan_file=f"code/ei_cov_softmax_{args.model}.stan",
  70. cpp_options={"STAN_THREADS": "TRUE"},
  71. )
  72. if args.mle:
  73. fit = model.optimize(
  74. data=data,
  75. show_console=True,
  76. iter=5000,
  77. )
  78. else:
  79. fit = model.sample(
  80. data=data,
  81. chains=args.chains,
  82. threads_per_chain=args.threads_per_chain,
  83. iter_sampling=args.samples,
  84. iter_warmup=args.warmup,
  85. show_console=True,
  86. max_treedepth=11,
  87. # step_size=0.1
  88. )
  89. vars = fit.stan_variables()
  90. samples = {}
  91. for (k, v) in vars.items():
  92. samples[k] = v
  93. format = "mle" if args.mle else "samples"
  94. if args.portfolios is None:
  95. np.savez_compressed(opj(args.input, f"ei_{format}_{args.model}.npz"), **samples)
  96. else:
  97. np.savez_compressed(opj(args.input, f"ei_{format}_{args.model}_{basename(portfolios)}.npz"), **samples)