
Update 'code/etm_ei.py'

Lucas Gautheron, 2 months ago
commit 22026717fe
1 file changed with 30 additions and 11 deletions

+ 30 - 11
code/etm_ei.py

@@ -1,4 +1,3 @@
-from socketserver import ThreadingUnixStreamServer
 import pandas as pd
 import numpy as np
 
@@ -13,15 +12,15 @@ import seaborn as sns
 from matplotlib import pyplot as plt
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--input")
+parser.add_argument("--input", default="output/etm_20_pretrained")
 parser.add_argument("--n", type=int, default=200)
 parser.add_argument("--min-pop", type=int, default=100)
 parser.add_argument("--stack-rows", type=int, default=1)
-
 parser.add_argument("--chains", type=int, default=1)
 parser.add_argument("--threads-per-chain", type=int, default=4)
 parser.add_argument("--samples", type=int, default=500)
-parser.add_argument("--warmup", type=int, default=1000)
+parser.add_argument("--warmup", type=int, default=300)
+parser.add_argument("--model", default="control")
 
 args = parser.parse_args()
 
@@ -29,7 +28,11 @@ n_topics = len(pd.read_csv(opj(args.input, "topics.csv")))
 
 df = pd.read_csv(opj(args.input, "aggregate.csv"))
 df = df[df[[f"start_{k+1}" for k in range(n_topics)]].sum(axis=1) >= args.min_pop]
-df = df.sample(n=args.n)
+
+if args.n < len(df):
+    df = df.head(n=args.n)
+
+print(df)
 
 resources = pd.read_parquet(opj(args.input, "pooled_resources.parquet"))
 df = df.merge(resources, left_on="bai", right_on="bai")
@@ -37,6 +40,7 @@ df = df.merge(resources, left_on="bai", right_on="bai")
 data = {
     "NR": np.stack(df[[f"start_{k+1}" for k in range(n_topics)]].values).astype(int),
     "NC": np.stack(df[[f"end_{k+1}" for k in range(n_topics)]].values).astype(int),
+    "expertise": np.stack(df[[f"expertise_{k+1}" for k in range(n_topics)]].values),
     "R": n_topics,
     "C": n_topics,
     "n_units": len(df),
@@ -47,24 +51,36 @@ data["cov"] = np.stack(df["pooled_resources"])
 
 junk = np.sum(data["NR"] + data["NC"], axis=0) == 0
 
-for col in ["NR", "NC", "cov"]:
+for col in ["NR", "NC", "cov", "expertise"]:
     data[col] = data[col][:, ~junk]
 
 data["R"] -= junk.sum()
 data["C"] -= junk.sum()
 
-data["cov"] = data["cov"] / np.maximum(data["cov"].sum(axis=1)[:, np.newaxis], 1)
+data["cov"] = data["cov"]# / np.maximum(data["cov"].sum(axis=1)[:, np.newaxis], 1)
+
+fig, ax = plt.subplots()
 sns.heatmap(
-    np.corrcoef(data["NC"].T, data["cov"].T), vmin=-0.5, vmax=0.5, cmap="RdBu"
+    np.corrcoef(data["NC"].T, data["cov"].T), vmin=-0.5, vmax=0.5, cmap="RdBu", ax=ax
 )
 plt.show()
 
+fig, ax = plt.subplots()
+for i in range(data["R"]):
+    ax.scatter(data["cov"][:,i], data["NR"][:,i]/data["NR"].sum(axis=1))
+
+plt.show()
 
-print(data["cov"].shape)
+expertise = data["expertise"]
+data["nu"] = np.array([
+    [((expertise[:,i]>expertise[:,i].mean())&(expertise[:,j]>expertise[:,j].mean())).mean()/(expertise[:,i]>expertise[:,i].mean()).mean() for j in range(data["R"])]
+    for i in range(data["R"])
+])
 
+print(data)
 
 model = CmdStanModel(
-    stan_file=f"code/ei_cov_softmax_control.stan",
+    stan_file=f"code/ei_cov_softmax_{args.model}.stan",
     cpp_options={"STAN_THREADS": "TRUE"},
     compile="force",
 )
@@ -75,6 +91,9 @@ fit = model.sample(
     threads_per_chain=args.threads_per_chain,
     iter_sampling=args.samples,
     iter_warmup=args.warmup,
+    show_console=True,
+    max_treedepth=11,
+    # step_size=0.1
 )
 
 vars = fit.stan_variables()
@@ -82,4 +101,4 @@ samples = {}
 for (k, v) in vars.items():
     samples[k] = v
 
-np.savez_compressed(opj(args.input, "ei_samples.npz"), **samples)
+np.savez_compressed(opj(args.input, f"ei_samples_{args.model}.npz"), **samples)
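
Note on the added "nu" matrix: the new nested list comprehension computes, for each pair of topics (i, j), the fraction of units with above-mean expertise in topic j among the units with above-mean expertise in topic i, i.e. a conditional co-expertise probability nu[i, j] = P(above-mean in j | above-mean in i). Below is a minimal vectorized sketch of the same quantity; the helper name co_expertise_matrix is hypothetical and not part of the repository, and it assumes every topic has at least one above-mean unit.

import numpy as np

def co_expertise_matrix(expertise):
    # expertise: (n_units, R) array of per-topic expertise scores.
    above = (expertise > expertise.mean(axis=0)).astype(float)  # above-mean indicator, shape (n_units, R)
    joint = above.T @ above / len(expertise)                    # joint[i, j] = P(above in i and above in j)
    marginal = above.mean(axis=0)                               # marginal[i] = P(above in i)
    # Divide each row i by P(above in i); assumes marginal[i] > 0 for all topics.
    return joint / marginal[:, None]                            # nu[i, j] = P(above in j | above in i)

# Example on random data:
rng = np.random.default_rng(0)
print(co_expertise_matrix(rng.random((500, 5))).round(2))

This reproduces the per-pair loop in the diff in one matrix product, which may be convenient to sanity-check the nu values passed to the Stan model.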