corpus_bias.py

#!/usr/bin/env python3
from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_annotation
import argparse
import datalad.api
from os.path import join as opj
import multiprocessing as mp
import numpy as np
import pandas as pd
import pickle
from pyannote.core import Annotation, Segment, Timeline
import stan

parser = argparse.ArgumentParser(
    description="main model described throughout the notes."
)
# parser.add_argument("--group", default="child", choices=["corpus", "child"])
parser.add_argument("--apply-bias-from", type=str, default="")
parser.add_argument("--chains", default=4, type=int)
parser.add_argument("--samples", default=2000, type=int)
parser.add_argument("--validation", default=0, type=float)
parser.add_argument("--simulated-children", default=40, type=int)
parser.add_argument("--output", default="corpus_bias")
args = parser.parse_args()
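
# Example invocation (illustrative; the corpus name passed to
# --apply-bias-from is hypothetical, other values are the defaults above):
#
#   python corpus_bias.py --chains 4 --samples 2000 --validation 0.5 \
#       --apply-bias-from some_corpus --output corpus_bias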


def extrude(self, removed, mode: str = "intersection"):
    """Remove the ``removed`` regions from a pyannote Timeline and return the rest."""
    if isinstance(removed, Segment):
        removed = Timeline([removed])

    truncating_support = removed.gaps(support=self.extent())
    # loose for truncate means strict for crop and vice-versa
    if mode == "loose":
        mode = "strict"
    elif mode == "strict":
        mode = "loose"
    return self.crop(truncating_support, mode=mode)
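
# A toy sketch of what extrude() computes (segments are made up):
#
#   detected = Timeline([Segment(0, 10)])
#   explained = Timeline([Segment(2, 4)])
#   extrude(detected, explained)  # -> Timeline([Segment(0, 2), Segment(4, 10)])
#
# i.e. the parts of `detected` that remain once `explained` is carved out.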


def compute_counts(parameters):
    """Count detections and cross-speaker confusions between the VTC and a
    human annotator for one corpus, returning one row per child."""
    corpus = parameters["corpus"]
    annotator = parameters["annotator"]
    speakers = ["CHI", "OCH", "FEM", "MAL"]

    project = ChildProject(parameters["path"])
    am = AnnotationManager(project)
    am.read()

    # restrict to the portions annotated by both the VTC and the human annotator
    intersection = AnnotationManager.intersection(am.annotations, ["vtc", annotator])
    intersection["path"] = intersection.apply(
        lambda r: opj(
            project.path, "annotations", r["set"], "converted", r["annotation_filename"]
        ),
        axis=1,
    )
    datalad.api.get(list(intersection["path"].unique()))

    intersection = intersection.merge(
        project.recordings[["recording_filename", "child_id"]], how="left"
    )
    intersection["child"] = corpus + "_" + intersection["child_id"].astype(str)
    intersection["duration"] = (
        intersection["range_offset"] - intersection["range_onset"]
    )
    # total annotated audio in hours (each range appears once per set, hence the /2)
    print(corpus, annotator, (intersection["duration"] / 1000 / 2).sum() / 3600)

    data = []
    for child, ann in intersection.groupby("child"):
        # print(corpus, child)
        segments = am.get_collapsed_segments(ann)
        if "speaker_type" not in segments.columns:
            continue

        segments = segments[segments["speaker_type"].isin(speakers)]

        # one timeline per speaker class, for the VTC and for the human annotator
        vtc = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == "vtc") & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            ).get_timeline()
            for speaker in speakers
        }

        truth = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == annotator)
                    & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            ).get_timeline()
            for speaker in speakers
        }

        for speaker_A in speakers:
            # detections confirmed by the annotator, false positives, and misses
            vtc[f"{speaker_A}_vocs_explained"] = vtc[speaker_A].crop(
                truth[speaker_A], mode="loose"
            )
            vtc[f"{speaker_A}_vocs_fp"] = extrude(
                vtc[speaker_A], vtc[f"{speaker_A}_vocs_explained"]
            )
            vtc[f"{speaker_A}_vocs_fn"] = extrude(
                truth[speaker_A], truth[speaker_A].crop(vtc[speaker_A], mode="loose")
            )

            for speaker_B in speakers:
                # false positives of speaker_A explained by true speech of speaker_B...
                vtc[f"{speaker_A}_vocs_fp_{speaker_B}"] = vtc[
                    f"{speaker_A}_vocs_fp"
                ].crop(truth[speaker_B], mode="loose")

                # ...excluding those that also overlap a third speaker
                for speaker_C in speakers:
                    if speaker_C != speaker_B and speaker_C != speaker_A:
                        vtc[f"{speaker_A}_vocs_fp_{speaker_B}"] = extrude(
                            vtc[f"{speaker_A}_vocs_fp_{speaker_B}"],
                            vtc[f"{speaker_A}_vocs_fp_{speaker_B}"].crop(
                                truth[speaker_C], mode="loose"
                            ),
                        )

        d = {}
        keep_child = True
        for i, speaker_A in enumerate(speakers):
            for j, speaker_B in enumerate(speakers):
                if i != j:
                    z = len(vtc[f"{speaker_A}_vocs_fp_{speaker_B}"])
                else:
                    z = min(
                        len(vtc[f"{speaker_A}_vocs_explained"]), len(truth[speaker_A])
                    )
                d[f"vtc_{i}_{j}"] = z
                # a confusion count cannot exceed the true count of the source class
                if z > len(truth[speaker_B]):
                    keep_child = False

            d[f"truth_{i}"] = len(truth[speaker_A])

        d["child"] = child
        d["duration"] = ann["duration"].sum() / 2 / 1000

        if keep_child:
            data.append(d)

    return pd.DataFrame(data).assign(
        corpus=corpus,
    )
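
# Sketch of the resulting rows (values are made up): with speakers indexed
# CHI=0, OCH=1, FEM=2, MAL=3, each kept child yields
#
#   vtc_i_j : VTC detections of class i attributable to true class j
#             (i == j counts confirmed detections)
#   truth_i : vocalizations of class i according to the human annotator
#
#   {"vtc_0_0": 210, "vtc_0_2": 13, ..., "truth_0": 245, ...,
#    "child": "corpus_1", "duration": 3600.0}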


stan_code = """
data {
    int<lower=1> n_clips;   // number of clips
    int<lower=1> n_groups;  // number of groups (children)
    int<lower=1> n_corpora;
    int<lower=1> n_classes; // number of classes (speaker types)
    int group[n_clips];     // child of each clip
    int corpus[n_clips];    // corpus of each clip
    int vtc[n_clips,n_classes,n_classes]; // detections of class i due to true class j
    int truth[n_clips,n_classes];         // true counts per class
    int<lower=1> n_validation;
    int<lower=1> n_sim;
    int<lower=0> selected_corpus;
    real<lower=0> rates_alphas[n_classes];
    real<lower=0> rates_betas[n_classes];
}
parameters {
    matrix<lower=0,upper=1>[n_classes,n_classes] mus;
    matrix<lower=1>[n_classes,n_classes] etas;
    matrix<lower=0,upper=1>[n_classes,n_classes] group_confusion[n_groups];
    matrix[n_classes,n_classes] corpus_bias[n_corpora];
    matrix<lower=0>[n_classes,n_classes] corpus_sigma;
}
transformed parameters {
    matrix<lower=0>[n_classes,n_classes] alphas;
    matrix<lower=0>[n_classes,n_classes] betas;
    alphas = mus .* etas;    // elementwise, not a matrix product
    betas = (1-mus) .* etas;
}
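// Mean/concentration parameterization of the Beta prior: with mu in (0,1)
// and eta > 1, alpha = mu * eta and beta = (1-mu) * eta give a
// Beta(alpha, beta) with mean mu; larger eta concentrates the child-level
// confusion rates around that population mean.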
model {
    // observation model: detection counts are binomial in the true counts,
    // with a child-level confusion probability shifted by a corpus bias on
    // the logit scale; clips before n_validation are held out
    for (k in n_validation:n_clips) {
        for (i in 1:n_classes) {
            for (j in 1:n_classes) {
                vtc[k,i,j] ~ binomial(
                    truth[k,j], inv_logit(logit(group_confusion[group[k],j,i]) + corpus_bias[corpus[k],j,i])
                );
            }
        }
    }

    for (i in 1:n_classes) {
        for (j in 1:n_classes) {
            mus[i,j] ~ beta(1,1);
            etas[i,j] ~ pareto(1,1.5);
        }
    }

    for (c in 1:n_groups) {
        for (i in 1:n_classes) {
            for (j in 1:n_classes) {
                group_confusion[c,i,j] ~ beta(alphas[i,j], betas[i,j]);
            }
        }
    }

    for (i in 1:n_classes) {
        for (j in 1:n_classes) {
            for (c in 1:n_corpora) {
                corpus_bias[c,j,i] ~ normal(0, corpus_sigma[j,i]);
            }
            corpus_sigma[j,i] ~ normal(0, 1);
        }
    }
}
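// Generated quantities: posterior-predictive counts and log-likelihoods for
// each clip, plus n_sim simulated children whose true vocalization counts are
// drawn from gamma-Poisson rates and passed through the fitted confusion
// model (with either the selected corpus's bias or a freshly sampled one).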
generated quantities {
    int pred[n_clips,n_classes,n_classes];
    matrix[n_classes,n_classes] probs[n_groups];
    matrix[n_classes,n_classes] log_lik[n_clips];
    matrix[n_classes,n_classes] random_bias;
    matrix[n_classes,n_classes] fixed_bias;
    int sim_truth[n_sim,n_classes];
    int sim_vtc[n_sim,n_classes];
    vector[n_classes] lambdas;
    real chi_adu_coef = 0; // null-hypothesis

    for (i in 1:n_classes) {
        for (j in 1:n_classes) {
            if (selected_corpus != 0) {
                fixed_bias[j,i] = corpus_bias[selected_corpus,j,i];
            }
            else {
                fixed_bias[j,i] = 0;
            }
            random_bias[j,i] = normal_rng(0, corpus_sigma[j,i]);
        }
    }

    for (c in 1:n_groups) {
        for (i in 1:n_classes) {
            for (j in 1:n_classes) {
                probs[c,i,j] = beta_rng(alphas[i,j], betas[i,j]);
            }
        }
    }

    for (k in 1:n_clips) {
        for (i in 1:n_classes) {
            for (j in 1:n_classes) {
                if (k >= n_validation) {
                    // training clips: use the child's fitted confusion rates
                    pred[k,i,j] = binomial_rng(
                        truth[k,j], inv_logit(logit(group_confusion[group[k],j,i]) + corpus_bias[corpus[k],j,i])
                    );
                    log_lik[k,i,j] = binomial_lpmf(
                        vtc[k,i,j] | truth[k,j], inv_logit(logit(group_confusion[group[k],j,i]) + corpus_bias[corpus[k],j,i])
                    );
                }
                else {
                    // held-out clips: draw confusion rates from the population prior
                    pred[k,i,j] = binomial_rng(
                        truth[k,j], inv_logit(logit(probs[group[k],j,i]) + corpus_bias[corpus[k],j,i])
                    );
                    log_lik[k,i,j] = beta_lpdf(probs[group[k],j,i] | alphas[j,i], betas[j,i]);
                    log_lik[k,i,j] += binomial_lpmf(
                        vtc[k,i,j] | truth[k,j], inv_logit(logit(probs[group[k],j,i]) + corpus_bias[corpus[k],j,i])
                    );
                }
            }
        }
    }

    real lambda;
    for (k in 1:n_sim) {
        for (i in 2:n_classes) {
            lambda = gamma_rng(rates_alphas[i], rates_betas[i]);
            sim_truth[k,i] = poisson_rng(lambda);
        }
        // key child (class 1): rate may depend on adult speech under the alternative
        lambda = gamma_rng(rates_alphas[1], rates_betas[1]);
        sim_truth[k,1] = poisson_rng(lambda + chi_adu_coef*(sim_truth[k,3]+sim_truth[k,4]));
    }

    for (k in 1:n_sim) {
        for (i in 1:n_classes) {
            sim_vtc[k,i] = 0;
            for (j in 1:n_classes) {
                real p = logit(beta_rng(alphas[j,i], betas[j,i]));
                if (selected_corpus != 0) {
                    p += fixed_bias[j,i];
                }
                else {
                    p += random_bias[j,i];
                }
                p = inv_logit(p);
                sim_vtc[k,i] += binomial_rng(sim_truth[k,j], p);
            }
        }
    }
}
"""

if __name__ == "__main__":
    annotators = pd.read_csv("input/annotators.csv")
    annotators["path"] = annotators["corpus"].apply(lambda c: opj("input", c))

    with mp.Pool(processes=8) as pool:
        data = pd.concat(pool.map(compute_counts, annotators.to_dict(orient="records")))

    # shuffle children so that the validation subset is random
    data = data.sample(frac=1)
    duration = data["duration"].sum()

    # reshape counts into (n_clips, n_classes, n_classes) and (n_clips, n_classes)
    vtc = np.moveaxis(
        [[data[f"vtc_{j}_{i}"].values for i in range(4)] for j in range(4)], -1, 0
    )
    truth = np.transpose([data[f"truth_{i}"].values for i in range(4)])
    print(vtc.shape)

    rates = pd.read_csv("output/speech_dist.csv")

    training_set = data.groupby("corpus").agg(
        duration=("duration", "sum"), children=("child", lambda x: x.nunique())
    )
    training_set["duration"] /= 3600
    training_set.to_csv("output/training_set.csv")

    data["corpus"] = data["corpus"].astype("category")
    corpora = data["corpus"].cat.codes.values
    corpora_codes = dict(enumerate(data["corpus"].cat.categories))
    corpora_codes = {v: k for k, v in corpora_codes.items()}

    data = {
        "n_clips": truth.shape[0],
        "n_classes": truth.shape[1],
        "n_groups": data["child"].nunique(),
        "n_corpora": data["corpus"].nunique(),
        "n_validation": max(1, int(truth.shape[0] * args.validation)),
        "n_sim": args.simulated_children,
        "group": 1 + data["child"].astype("category").cat.codes.values,
        "corpus": 1 + corpora,
        "selected_corpus": (
            1 + corpora_codes[args.apply_bias_from]
            if args.apply_bias_from in corpora_codes
            else 0
        ),
        "truth": truth.astype(int),
        "vtc": vtc.astype(int),
        "rates_alphas": rates["alpha"].values,
        "rates_betas": rates["beta"].values,
    }

    print(f"clips: {data['n_clips']}")
    print(f"groups: {data['n_groups']}")
    print("true vocs: {}".format(np.sum(data["truth"])))
    print("vtc vocs: {}".format(np.sum(data["vtc"])))
    print("duration: {}".format(duration))
    print("selected corpus: {}".format(data["selected_corpus"]))

    with open(f"output/samples/data_{args.output}.pickle", "wb") as fp:
        pickle.dump(data, fp, pickle.HIGHEST_PROTOCOL)

    posterior = stan.build(stan_code, data=data)
    fit = posterior.sample(num_chains=args.chains, num_samples=args.samples)
    df = fit.to_frame()
    df.to_parquet(f"output/samples/fit_{args.output}.parquet")