simple.py

#!/usr/bin/env python3

from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_annotation

import argparse
import datalad.api
from os.path import join as opj
from os.path import basename, exists
import multiprocessing as mp
import numpy as np
import pandas as pd
import pickle
from pyannote.core import Annotation, Segment, Timeline
import stan

parser = argparse.ArgumentParser(
    description="main model described throughout the notes."
)
parser.add_argument("--group", default="child", choices=["corpus", "child"])
parser.add_argument("--chains", default=4, type=int)
parser.add_argument("--samples", default=2000, type=int)
parser.add_argument("--validation", default=0, type=float)
parser.add_argument("--output", default="model3")
args = parser.parse_args()
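
# Example invocation (assumes the input/ and output/samples/ layout used below):
#   python simple.py --group child --chains 4 --samples 2000 --output model3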


def extrude(self, removed, mode: str = "intersection"):
    # Remove the parts of `self` (a pyannote Timeline) that overlap `removed`:
    # cropping against the gaps of `removed` keeps only the support outside it.
    if isinstance(removed, Segment):
        removed = Timeline([removed])

    truncating_support = removed.gaps(support=self.extent())
    # loose for truncate means strict for crop and vice-versa
    if mode == "loose":
        mode = "strict"
    elif mode == "strict":
        mode = "loose"

    return self.crop(truncating_support, mode=mode)
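
# The function below builds the per-child confusion counts for one (corpus, annotator)
# pair: vtc_i_j counts VTC detections of class i attributable to human-annotated speech
# of class j (diagonal = detections explained by the same class), and truth_i counts
# the human-annotated vocalizations of class i.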


def compute_counts(parameters):
    corpus = parameters["corpus"]
    annotator = parameters["annotator"]
    speakers = ["CHI", "OCH", "FEM", "MAL"]

    project = ChildProject(parameters["path"])
    am = AnnotationManager(project)
    am.read()

    intersection = AnnotationManager.intersection(am.annotations, ["vtc", annotator])
    intersection["path"] = intersection.apply(
        lambda r: opj(
            project.path, "annotations", r["set"], "converted", r["annotation_filename"]
        ),
        axis=1,
    )
    datalad.api.get(list(intersection["path"].unique()))

    intersection = intersection.merge(
        project.recordings[["recording_filename", "child_id"]], how="left"
    )
    intersection["child"] = corpus + "_" + intersection["child_id"].astype(str)
    intersection["duration"] = (
        intersection["range_offset"] - intersection["range_onset"]
    )
    # each overlapping range appears once per annotation set (vtc + human annotator),
    # hence the division by 2; onsets/offsets are in milliseconds
    print(corpus, annotator, (intersection["duration"] / 1000 / 2).sum() / 3600)

    data = []
    for child, ann in intersection.groupby("child"):
        # print(corpus, child)
        segments = am.get_collapsed_segments(ann)
        if "speaker_type" not in segments.columns:
            continue

        segments = segments[segments["speaker_type"].isin(speakers)]

        vtc = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == "vtc") & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            ).get_timeline()
            for speaker in speakers
        }

        truth = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == annotator)
                    & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            ).get_timeline()
            for speaker in speakers
        }
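
        # For each detected class A: split VTC detections into those explained by
        # true speech of A, false positives, and misses; then attribute each false
        # positive to the single other class B whose true speech it overlaps
        # (detections overlapping several true classes are excluded from all of them).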

        for speaker_A in speakers:
            vtc[f"{speaker_A}_vocs_explained"] = vtc[speaker_A].crop(
                truth[speaker_A], mode="loose"
            )
            vtc[f"{speaker_A}_vocs_fp"] = extrude(
                vtc[speaker_A], vtc[f"{speaker_A}_vocs_explained"]
            )
            vtc[f"{speaker_A}_vocs_fn"] = extrude(
                truth[speaker_A], truth[speaker_A].crop(vtc[speaker_A], mode="loose")
            )

            for speaker_B in speakers:
                vtc[f"{speaker_A}_vocs_fp_{speaker_B}"] = vtc[
                    f"{speaker_A}_vocs_fp"
                ].crop(truth[speaker_B], mode="loose")

                for speaker_C in speakers:
                    if speaker_C != speaker_B and speaker_C != speaker_A:
                        vtc[f"{speaker_A}_vocs_fp_{speaker_B}"] = extrude(
                            vtc[f"{speaker_A}_vocs_fp_{speaker_B}"],
                            vtc[f"{speaker_A}_vocs_fp_{speaker_B}"].crop(
                                truth[speaker_C], mode="loose"
                            ),
                        )

        d = {}
        for i, speaker_A in enumerate(speakers):
            for j, speaker_B in enumerate(speakers):
                if i != j:
                    z = len(vtc[f"{speaker_A}_vocs_fp_{speaker_B}"])
                else:
                    z = min(
                        len(vtc[f"{speaker_A}_vocs_explained"]), len(truth[speaker_A])
                    )
                d[f"vtc_{i}_{j}"] = z

            d[f"truth_{i}"] = len(truth[speaker_A])

        d["child"] = child
        d["duration"] = ann["duration"].sum() / 2 / 1000
        data.append(d)

    return pd.DataFrame(data).assign(
        corpus=corpus,
    )
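
# Stan model: within each clip, the count of VTC detections of class i attributable
# to true class j is modelled as a binomial draw from truth[k,j], with confusion
# probabilities that vary by group (child or corpus) under a hierarchical
# Beta(alphas, betas) prior derived from mus and etas. Clips before index
# n_validation are excluded from the likelihood (held-out set).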


stan_code = """
data {
    int<lower=1> n_clips;   // number of clips
    int<lower=1> n_groups;  // number of groups
    int<lower=1> n_classes; // number of classes
    int group[n_clips];
    int vtc[n_clips,n_classes,n_classes];
    int truth[n_clips,n_classes];

    int<lower=1> n_validation;
    int<lower=1> n_sim;

    real<lower=0> rates_alphas[n_classes];
    real<lower=0> rates_betas[n_classes];
}

parameters {
    matrix<lower=0,upper=1>[n_classes,n_classes] mus;
    matrix<lower=1>[n_classes,n_classes] etas;
    matrix<lower=0,upper=1>[n_classes,n_classes] group_confusion[n_groups];
}

transformed parameters {
    matrix<lower=0>[n_classes,n_classes] alphas;
    matrix<lower=0>[n_classes,n_classes] betas;

    alphas = mus * etas;
    betas = (1-mus) * etas;
}

model {
    for (k in n_validation:n_clips) {
        for (i in 1:n_classes) {
            for (j in 1:n_classes) {
                vtc[k,i,j] ~ binomial(truth[k,j], group_confusion[group[k],j,i]);
            }
        }
    }

    for (i in 1:n_classes) {
        for (j in 1:n_classes) {
            mus[i,j] ~ beta(1,1);
            etas[i,j] ~ pareto(1,1.5);
        }
    }

    for (c in 1:n_groups) {
        for (i in 1:n_classes) {
            for (j in 1:n_classes) {
                group_confusion[c,i,j] ~ beta(alphas[i,j], betas[i,j]);
            }
        }
    }
}
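
// Generated quantities: posterior-predictive counts (pred), pointwise log-likelihood
// (held-out clips are scored with probabilities drawn from the population-level Beta),
// and simulated truth/vtc counts under the null hypothesis chi_adu_coef = 0.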
generated quantities {
    int pred[n_clips,n_classes,n_classes];
    matrix[n_classes,n_classes] probs[n_groups];
    matrix[n_classes,n_classes] log_lik[n_clips];

    int sim_truth[n_sim,n_classes];
    int sim_vtc[n_sim,n_classes];
    vector[n_classes] lambdas;

    real chi_adu_coef = 0; // null-hypothesis

    for (c in 1:n_groups) {
        for (i in 1:n_classes) {
            for (j in 1:n_classes) {
                probs[c,i,j] = beta_rng(alphas[i,j], betas[i,j]);
            }
        }
    }

    for (k in 1:n_clips) {
        for (i in 1:n_classes) {
            for (j in 1:n_classes) {
                if (k >= n_validation) {
                    pred[k,i,j] = binomial_rng(truth[k,j], group_confusion[group[k],i,j]);
                    log_lik[k,i,j] = binomial_lpmf(vtc[k,i,j] | truth[k,j], group_confusion[group[k],j,i]);
                }
                else {
                    pred[k,i,j] = binomial_rng(truth[k,j], probs[group[k],j,i]);
                    log_lik[k,i,j] = beta_lpdf(probs[group[k],j,i] | alphas[j,i], betas[j,i]);
                    log_lik[k,i,j] += binomial_lpmf(vtc[k,i,j] | truth[k,j], probs[group[k],j,i]);
                }
            }
        }
    }

    real lambda;
    for (k in 1:n_sim) {
        for (i in 2:n_classes) {
            lambda = gamma_rng(rates_alphas[i], rates_betas[i]);
            sim_truth[k,i] = poisson_rng(lambda);
        }
        lambda = gamma_rng(rates_alphas[1], rates_betas[1]);
        sim_truth[k,1] = poisson_rng(lambda + chi_adu_coef*(sim_truth[k,3]+sim_truth[k,4]));
    }

    for (k in 1:n_sim) {
        for (i in 1:n_classes) {
            sim_vtc[k,i] = 0;
            for (j in 1:n_classes) {
                real p = beta_rng(alphas[j,i], betas[j,i]);
                sim_vtc[k,i] += binomial_rng(sim_truth[k,j], p);
            }
        }
    }
}
"""
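
# Driver: build the per-child counts for each (corpus, annotator) pair in parallel,
# pack them into the Stan data dictionary, and sample the posterior with pystan.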
if __name__ == "__main__":
    annotators = pd.read_csv("input/annotators.csv")
    annotators["path"] = annotators["corpus"].apply(lambda c: opj("input", c))

    with mp.Pool(processes=8) as pool:
        data = pd.concat(pool.map(compute_counts, annotators.to_dict(orient="records")))

    # shuffle clips so that the held-out prefix (the first n_validation rows) is random
    data = data.sample(frac=1)
    duration = data["duration"].sum()

    vtc = np.moveaxis(
        [[data[f"vtc_{j}_{i}"].values for i in range(4)] for j in range(4)], -1, 0
    )
    truth = np.transpose([data[f"truth_{i}"].values for i in range(4)])
    print(vtc.shape)

    rates = pd.read_csv("output/speech_dist.csv")

    training_set = data.groupby("corpus").agg(
        duration=("duration", "sum"), children=("child", lambda x: x.nunique())
    )
    training_set["duration"] /= 3600
    training_set.to_csv("output/training_set.csv")

    data = {
        "n_clips": truth.shape[0],
        "n_classes": truth.shape[1],
        "n_groups": data[args.group].nunique(),
        "n_validation": max(1, int(truth.shape[0] * args.validation)),
        "n_sim": 40,
        "group": 1 + data[args.group].astype("category").cat.codes.values,
        "truth": truth.astype(int),
        "vtc": vtc.astype(int),
        "rates_alphas": rates["alpha"].values,
        "rates_betas": rates["beta"].values,
    }

    print(f"clips: {data['n_clips']}")
    print(f"groups: {data['n_groups']}")
    print("true vocs: {}".format(np.sum(data["truth"])))
    print("vtc vocs: {}".format(np.sum(data["vtc"])))
    print("duration: {}".format(duration))

    with open(f"output/samples/data_{args.output}.pickle", "wb") as fp:
        pickle.dump(data, fp, pickle.HIGHEST_PROTOCOL)

    posterior = stan.build(stan_code, data=data)
    fit = posterior.sample(num_chains=args.chains, num_samples=args.samples)
    df = fit.to_frame()
    df.to_parquet(f"output/samples/fit_{args.output}.parquet")