# ei_map.py

import argparse
from multiprocessing import Pool
from os.path import join as opj, basename

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split, KFold

import matplotlib
matplotlib.use("pgf")  # select the backend before importing pyplot
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)
from matplotlib import pyplot as plt

# LaTeX preamble for the math used in figure labels
plt.rcParams["text.latex.preamble"] = "\n".join([
    r"\usepackage{amsmath}",
    r"\usepackage{amssymb}",
])
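
# Model, as implemented below: each author a moves from start-topic shares
# x_a to end-topic counts Y_a through a row-stochastic matrix
# B_a = softmax_rows(s_a), with logits
#   s_a[i, j] = mu~[i, j] + gamma[i, j] * expertise_a[j]
#             + delta[i, j] * cov_a[j] + eps_a[i, j],
# where mu~ = pad(mu) + mu_nu * nu. The predicted end shares are
# y_a^pred[j] = sum_i B_a[i, j] * x_a[i], scored by a multinomial likelihood.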
class TrajectoryModel(torch.nn.Module):
    """Maps authors' start-topic shares to end-topic shares through per-author
    softmax transition matrices built from expertise and resource covariates."""

    def __init__(self, N, R, C, nu):
        super().__init__()
        self.N = N  # number of authors
        self.R = R  # number of start topics (rows)
        self.C = C  # number of end topics (columns)
        self.dtype = torch.float
        # Fixed co-expertise matrix (not learned); registered as a buffer so
        # it follows the module across devices
        self.register_buffer("nu", nu.to(self.dtype))
        self.init_weights()
        torch.autograd.set_detect_anomaly(True)  # debugging aid; slows training
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
            print("GPU is not available, using CPU instead")
        self.to(self.device)

    def init_weights(self):
        # Per-author transition parameters (declared but unused in the loss)
        self.beta = torch.nn.Parameter(torch.zeros((self.N, self.R, self.C)))
        # Baseline logits; the last column is fixed at zero for identifiability
        self.mu = torch.nn.Parameter(torch.zeros((self.R, self.C - 1)))
        self.gamma = torch.nn.Parameter(torch.zeros((self.R, self.C)))  # expertise effects
        self.delta = torch.nn.Parameter(torch.zeros((self.R, self.C)))  # resource effects
        # Per-author random effects and their log-scale
        self.eps = torch.nn.Parameter(torch.zeros((self.N, self.R, self.C - 1)))
        self.sigma = torch.nn.Parameter(torch.zeros((self.R, self.C - 1)))
        self.mu_nu = torch.nn.Parameter(torch.zeros(1))
    def fit(self, train, validation, epoch=50, autostop=False, printouts=1000):
        # Called fit() rather than train() so as not to override
        # torch.nn.Module.train(), which switches training mode
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        # Move the data onto the same device as the parameters
        train = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in train.items()}
        validation = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in validation.items()}
        epochs = []
        train_loss = []
        validation_loss = []
        reference_loss = []
        for t in range(epoch):
            loss = self.loss(train)
            print(t, loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (t % 2) == 0:
                with torch.no_grad():
                    y_pred_train = self.predict(train)
                    y_pred_validation = self.predict(validation)
                epochs.append(t)
                # Total variation distance between observed and predicted shares
                train_loss.append(((train["y"] - y_pred_train).abs().sum(dim=1) / 2).mean().item())
                validation_loss.append(((validation["y"] - y_pred_validation).abs().sum(dim=1) / 2).mean().item())
                reference_loss.append(((validation["y"] - validation["x"]).abs().sum(dim=1) / 2).mean().item())
            if (t % printouts) == 0:
                fig, ax = plt.subplots(figsize=[6.4 * 0.75, 4.8 * 0.75])
                ax.plot(np.array(epochs), train_loss, label=r"$d_{\mathrm{TV}}(\vec{y}_a,\vec{y}_a^{\mathrm{pred}})$ -- training set")
                ax.plot(np.array(epochs), validation_loss, label=r"$d_{\mathrm{TV}}(\vec{y}_a,\vec{y}_a^{\mathrm{pred}})$ -- test set")
                ax.plot(np.array(epochs), reference_loss, label=r"$d_{\mathrm{TV}}(\vec{y}_a,\vec{x}_a)$ -- test set")
                ax.set_xlabel("Epochs")
                ax.set_ylabel("Performance (total variation distance)")
                ax.legend(loc="upper right", bbox_to_anchor=(1, 1.2))
                ax.set_ylim(0.3, 0.6)
                # The pgf backend cannot write EPS, so save as PDF
                # (the filename relies on the module-level args)
                fig.savefig(f"status_{basename(args.input)}.pdf", bbox_inches="tight")
                plt.close(fig)
            # Stop once the validation loss has increased twice in a row
            if autostop and len(validation_loss) > 2 and validation_loss[-1] > validation_loss[-2] > validation_loss[-3]:
                break
        return train_loss, validation_loss, reference_loss
    def predict(self, data, eps=None):
        # Pad the (R, C-1) baseline logits with a zero reference column and
        # add the co-expertise shift (assumes R == C so that nu broadcasts)
        mu = torch.nn.functional.pad(self.mu, (0, 1))
        mu = mu + self.nu * self.mu_nu
        # Per-author transition logits from expertise and resource covariates
        s = torch.einsum("ij,aj->aij", self.gamma, data["expertise"])
        s = s + torch.einsum("ij,aj->aij", self.delta, data["cov"])
        s = s + mu
        if eps is not None:
            s = s + torch.nn.functional.pad(eps, (0, 1))
        # Row-stochastic transition matrices, mixed by the start-topic shares
        b = torch.softmax(s, dim=2)
        p = torch.einsum("aij,ai->aj", b, data["x"])
        return p
    def loss(self, data):
        Y = data["Y"]
        loss = 0
        p = self.predict(data, self.eps)
        # Multinomial likelihood of the end-period counts; one distribution
        # per author, since total counts differ between authors
        for a in range(p.shape[0]):
            multi = torch.distributions.multinomial.Multinomial(
                total_count=int(Y[a, :].sum().item()),
                probs=p[a, :],
            )
            loss -= multi.log_prob(Y[a, :]).sum()
        print("evidence loss:", loss / data["N"])
        # Hierarchical priors: eps ~ Normal(0, exp(sigma)) with
        # exp(sigma) ~ Exponential(1); all other coefficients ~ Normal(0, 1)
        eps_prior = torch.distributions.normal.Normal(0, self.sigma.exp())
        sigma_prior = torch.distributions.exponential.Exponential(1)
        normal_prior = torch.distributions.normal.Normal(0, 1)
        priors_loss = 0
        priors_loss -= eps_prior.log_prob(self.eps).sum()
        priors_loss -= sigma_prior.log_prob(self.sigma.exp()).sum()
        priors_loss -= normal_prior.log_prob(self.mu).sum()
        priors_loss -= normal_prior.log_prob(self.delta).sum()
        priors_loss -= normal_prior.log_prob(self.gamma).sum()
        priors_loss -= normal_prior.log_prob(self.mu_nu).sum()
        print("priors loss:", priors_loss / data["N"])
        loss += priors_loss
        loss /= data["N"]
        return loss
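
# ---------------------------------------------------------------------------
# Script section: load aggregated author trajectories and fit the model.
# --input must contain topics.csv, aggregate.csv and pooled_resources.parquet
# (as read below), e.g.:
#   python ei_map.py --input <data_dir> --folds 5
#
# Minimal smoke test for the class (hypothetical shapes, CPU only, not part
# of the pipeline):
#   mdl = TrajectoryModel(N=4, R=3, C=3, nu=torch.eye(3))
#   demo = {"x": torch.full((4, 3), 1 / 3),
#           "expertise": torch.zeros((4, 3)),
#           "cov": torch.zeros((4, 3))}
#   p = mdl.predict(demo)  # each row of p sums to 1
# ---------------------------------------------------------------------------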
parser = argparse.ArgumentParser()
parser.add_argument("--input")
parser.add_argument("--folds", default=0, type=int)
args = parser.parse_args()

n_topics = len(pd.read_csv(opj(args.input, "topics.csv")))
df = pd.read_csv(opj(args.input, "aggregate.csv"))
# Keep authors with at least 100 start-period counts
df = df[df[[f"start_{k+1}" for k in range(n_topics)]].sum(axis=1) >= 100]
resources = pd.read_parquet(opj(args.input, "pooled_resources.parquet"))
df = df.merge(resources, on="bai")

data = {
    "NR": df[[f"start_{k+1}" for k in range(n_topics)]].values.astype(int),
    "NC": df[[f"end_{k+1}" for k in range(n_topics)]].values.astype(int),
    "expertise": df[[f"expertise_{k+1}" for k in range(n_topics)]].fillna(0).values,
}
data["cov"] = np.stack(df["pooled_resources"])

# Drop topics that never occur in either the start or the end period
junk = np.sum(data["NR"] + data["NC"], axis=0) == 0
for col in ["NR", "NC", "cov", "expertise"]:
    data[col] = data[col][:, ~junk]
R = C = n_topics - junk.sum()

data["cov"] = np.nan_to_num(data["cov"])  # / np.maximum(data["cov"].sum(axis=1)[:, np.newaxis], 1)
data["expertise"] = np.nan_to_num(data["expertise"])  # / np.maximum(data["cov"].sum(axis=1)[:, np.newaxis], 1)
# nu[i, j]: share of authors with above-average expertise in topic i who also
# have above-average expertise in topic j
expertise = data["expertise"]
high = expertise > expertise.mean(axis=0)
nu = np.array([
    [(high[:, i] & high[:, j]).mean() / high[:, i].mean() for j in range(R)]
    for i in range(R)
])
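# Worked example: with two topics and high-expertise indicators
# [[1, 1], [1, 0], [0, 0]] across three authors, nu = [[1, 0.5], [1, 1]]:
# every author highly expert in topic 2 is also highly expert in topic 1,
# but only half of those highly expert in topic 1 are highly expert in topic 2.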
data["Y"] = data["NC"]  # end-period counts (multinomial observations)
data["x"] = data["NR"] / data["NR"].sum(axis=1)[:, np.newaxis]  # start shares
data["y"] = data["NC"] / data["NC"].sum(axis=1)[:, np.newaxis]  # end shares
N = data["x"].shape[0]
def split_train_validation(data, test_size):
    """Random train/validation split, returned as dicts of torch tensors."""
    train_ind, test_ind = train_test_split(np.arange(N), test_size=test_size)
    train, validation = {}, {}
    for k in data:
        # Cast to the model's float dtype to avoid mixed-precision einsum errors
        train[k] = torch.from_numpy(data[k][train_ind]).to(torch.float)
        validation[k] = torch.from_numpy(data[k][test_ind]).to(torch.float)
    return train, validation
def folds(data, n_splits):
    """K-fold cross-validation splits as (train, test) pairs of tensor dicts."""
    f = []
    kf = KFold(n_splits=n_splits, shuffle=True)
    for train_ind, test_ind in kf.split(np.arange(N)):
        fold_train, fold_test = {}, {}
        for k in data:
            fold_train[k] = torch.from_numpy(data[k][train_ind]).to(torch.float)
            fold_test[k] = torch.from_numpy(data[k][test_ind]).to(torch.float)
        f.append((fold_train, fold_test))
    return f
def run_model(data):
    train, validation = data
    train["N"] = train["x"].shape[0]
    validation["N"] = validation["x"].shape[0]
    mdl = TrajectoryModel(train["N"], R, C, torch.from_numpy(nu))
    train_loss, validation_loss, reference_loss = mdl.fit(
        train, validation,
        epoch=1000,
        autostop=True,
        printouts=50,
    )
    # fit() returns plain floats, so the final values can be stored directly
    scores = [train_loss[-1], validation_loss[-1], reference_loss[-1]]
    print(scores)
    return scores
if args.folds > 0:
    # Cross-validated run: one worker process per fold. This assumes a
    # fork start method; CUDA and forked workers do not mix well, so this
    # path is best run on CPU.
    f = folds(data, args.folds)
    with Pool(processes=args.folds) as pool:
        scores = pool.map(run_model, f)
    scores = np.array(scores)
    np.save(opj(args.input, "scores.npy"), scores)
else:
    # Single 80/20 split
    train, validation = split_train_validation(data, test_size=0.2)
    train["N"] = train["x"].shape[0]
    validation["N"] = validation["x"].shape[0]
    mdl = TrajectoryModel(train["N"], R, C, torch.from_numpy(nu))
    mdl.fit(train, validation, epoch=800)