#!/usr/bin/env python
"""Relate VTC (automatic) speaker counts to human annotations.

The script compiles per-child confusion counts between the VTC and human
annotators, per-child speech rates, and per-recording vocalization counts,
then fits one or more Stan models on the combined data via cmdstanpy.
"""
import pandas as pd
import numpy as np
from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_grid, segments_to_annotation
from ChildProject.pipelines.metrics import AclewMetrics
from ChildProject.utils import TimeInterval
from cmdstanpy import CmdStanModel
import datetime
import matplotlib
from matplotlib import pyplot as plt

matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)

import pickle
import datalad.api
from os.path import join as opj
from os.path import basename, exists
import multiprocessing as mp
from pyannote.core import Annotation, Segment, Timeline
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--corpora", default=["input/bergelson"], nargs="+")
parser.add_argument("--duration", type=int, help="duration in hours", default=8)
parser.add_argument("--run")
parser.add_argument("--chains", type=int, default=1)
parser.add_argument("--warmup", type=int, default=250)
parser.add_argument("--samples", type=int, default=1000)
parser.add_argument("--threads-per-chain", type=int, default=4)
parser.add_argument("--models", default=["analysis"], nargs="+")
parser.add_argument("--validation", type=float, default=0)
args = parser.parse_args()

speakers = ["CHI", "OCH", "FEM", "MAL"]

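# This helper appears to be adapted from pyannote's Timeline.extrude(): it
# removes from the timeline passed as `self` every segment portion that
# overlaps `removed`, by cropping against the gaps of `removed`.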
def extrude(self, removed, mode: str = "intersection"):
    if isinstance(removed, Segment):
        removed = Timeline([removed])

    truncating_support = removed.gaps(support=self.extent())
    # loose for truncate means strict for crop and vice-versa
    if mode == "loose":
        mode = "strict"
    elif mode == "strict":
        mode = "loose"

    return self.crop(truncating_support, mode=mode)

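# compute_counts(): for one (corpus, annotator) pair, builds a per-child
# confusion table between the VTC and the human annotator. Roughly speaking,
# vtc_i_j counts VTC vocalizations labeled speaker i that overlap a
# human-annotated vocalization of speaker j (the diagonal counts correct
# detections), and truth_i counts the human-annotated vocalizations of
# speaker i.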
def compute_counts(parameters):
    corpus = parameters["corpus"]
    annotator = parameters["annotator"]

    project = ChildProject(parameters["path"])
    am = AnnotationManager(project)
    am.read()

    # portions of the recordings covered by both the VTC and the human annotator
    intersection = AnnotationManager.intersection(am.annotations, ["vtc", annotator])
    intersection["path"] = intersection.apply(
        lambda r: opj(
            project.path, "annotations", r["set"], "converted", r["annotation_filename"]
        ),
        axis=1,
    )
    datalad.api.get(list(intersection["path"].unique()))

    intersection = intersection.merge(
        project.recordings[["recording_filename", "child_id"]], how="left"
    )
    intersection["child"] = corpus + "_" + intersection["child_id"].astype(str)
    intersection["duration"] = (
        intersection["range_offset"] - intersection["range_onset"]
    )
    # /2 because the intersection lists both sets; /1000 (ms -> s), /3600 (s -> h)
    print(corpus, annotator, (intersection["duration"] / 1000 / 2).sum() / 3600)

    data = []
    for child, ann in intersection.groupby("child"):
        # print(corpus, child)

        segments = am.get_collapsed_segments(ann)
        if "speaker_type" not in segments.columns:
            continue

        segments = segments[segments["speaker_type"].isin(speakers)]

        # per-speaker timelines of VTC segments
        vtc = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == "vtc") & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            ).get_timeline()
            for speaker in speakers
        }

        # per-speaker timelines of human-annotated segments
        truth = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == annotator)
                    & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            ).get_timeline()
            for speaker in speakers
        }

        for speaker_A in speakers:
            # VTC segments overlapping a human segment of the same speaker
            vtc[f"{speaker_A}_vocs_explained"] = vtc[speaker_A].crop(
                truth[speaker_A], mode="loose"
            )
            # remaining VTC segments are false positives...
            vtc[f"{speaker_A}_vocs_fp"] = extrude(
                vtc[speaker_A], vtc[f"{speaker_A}_vocs_explained"]
            )
            # ...and human segments missed by the VTC are false negatives
            vtc[f"{speaker_A}_vocs_fn"] = extrude(
                truth[speaker_A], truth[speaker_A].crop(vtc[speaker_A], mode="loose")
            )

            for speaker_B in speakers:
                # false positives of A that overlap a human segment of B
                vtc[f"{speaker_A}_vocs_fp_{speaker_B}"] = vtc[
                    f"{speaker_A}_vocs_fp"
                ].crop(truth[speaker_B], mode="loose")

                for speaker_C in speakers:
                    if speaker_C != speaker_B and speaker_C != speaker_A:
                        # discard confusions that also overlap a third speaker
                        vtc[f"{speaker_A}_vocs_fp_{speaker_B}"] = extrude(
                            vtc[f"{speaker_A}_vocs_fp_{speaker_B}"],
                            vtc[f"{speaker_A}_vocs_fp_{speaker_B}"].crop(
                                truth[speaker_C], mode="loose"
                            ),
                        )

        d = {}
        keep_child = True
        for i, speaker_A in enumerate(speakers):
            for j, speaker_B in enumerate(speakers):
                if i != j:
                    z = len(vtc[f"{speaker_A}_vocs_fp_{speaker_B}"])
                else:
                    z = min(
                        len(vtc[f"{speaker_A}_vocs_explained"]), len(truth[speaker_A])
                    )
                d[f"vtc_{i}_{j}"] = z

                if z > len(truth[speaker_B]):
                    keep_child = False

            d[f"truth_{i}"] = len(truth[speaker_A])

        d["child"] = child
        d["duration"] = ann["duration"].sum() / 2 / 1000

        if keep_child:
            data.append(d)

    return pd.DataFrame(data).assign(
        corpus=corpus,
    )

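# rates(): uses ChildProject's AclewMetrics pipeline to obtain vocalizations
# per hour for each speaker type at the child level, then converts them back
# into integer counts over the annotated duration (speech_rate_0..3), which
# the Stan models presumably consume as count data.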
def rates(parameters):
    corpus = parameters["corpus"]
    annotator = parameters["annotator"]
    speakers = ["CHI", "OCH", "FEM", "MAL"]

    project = ChildProject(parameters["path"])
    am = AnnotationManager(project)
    am.read()

    pipeline = AclewMetrics(
        project,
        vtc=annotator,
        alice=None,
        vcm=None,
        from_time="10:00:00",
        to_time="18:00:00",
        by="child_id",
    )
    metrics = pipeline.extract()
    metrics = pd.DataFrame(metrics).assign(corpus=corpus, annotator=annotator)
    metrics["duration"] = metrics[f"duration_{annotator}"] / 1000 / 3600
    metrics = metrics[metrics["duration"] > 0.01]

    speakers = ["CHI", "OCH", "FEM", "MAL"]
    # metrics.dropna(subset={f"voc_{speaker.lower()}_ph" for speaker in speakers} & set(metrics.columns), inplace=True)

    for i, speaker in enumerate(speakers):
        # if f"voc_{speaker.lower()}_ph" not in metrics.columns:
        #     metrics[f"speech_rate_{i}"] = pd.NA
        # else:
        # convert vocalizations per hour into total counts over the annotated duration
        metrics[f"speech_rate_{i}"] = (
            (metrics[f"voc_{speaker.lower()}_ph"] * metrics["duration"])
            .fillna(0)
            .astype(int)
        )

    return metrics

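# run_model(): compiles code/models/<model_name>.stan with threading enabled,
# samples with the chain/warmup/sampling settings taken from the command line,
# and dumps all posterior draws to output/aggregates_<run>_<model_name>.npz
# (plus a pickle of the input data). R2, evidence and priors are assumed to be
# quantities defined by the Stan models themselves.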
def run_model(data, run, model_name):
    model = CmdStanModel(
        stan_file=f"code/models/{model_name}.stan",
        cpp_options={"STAN_THREADS": "TRUE"},
        compile="force",
    )
    fit = model.sample(
        data=data,
        chains=args.chains,
        threads_per_chain=args.threads_per_chain,
        iter_sampling=args.samples,
        iter_warmup=args.warmup,
        step_size=0.1,
        # save_profile=True,
        # show_console=True,
    )

    vars = fit.stan_variables()
    samples = {}
    for (k, v) in vars.items():
        samples[k] = v

    np.savez_compressed(f"output/aggregates_{run}_{model_name}.npz", **samples)
    samples = np.load(f"output/aggregates_{run}_{model_name}.npz")

    with open(f"output/aggregates_{run}_{model_name}.pickle", "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(samples["R2"].mean(axis=0))

    return {
        "evidence": samples["evidence"].mean(),
        "priors": samples["priors"].mean(),
        "total_evidence": (samples["evidence"] + samples["priors"]).mean(),
        "samples": samples,
    }

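# compile_recordings(): for each recording of a corpus, counts VTC
# vocalizations per speaker type (vtc_0..3) within the first args.duration
# hours starting at 10:00, skipping recordings whose annotations do not cover
# that full window.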
def compile_recordings(corpus):
    project = ChildProject(corpus)
    am = AnnotationManager(project)
    am.read()

    project.recordings["age"] = project.compute_ages()

    annotations = am.annotations[am.annotations["set"] == "vtc"]
    annotations = annotations.merge(
        project.recordings,
        left_on="recording_filename",
        right_on="recording_filename",
        how="inner",
    )

    recs = []
    for recording_filename, _annotations in annotations.groupby("recording_filename"):
        # keep only the annotated portion between 10:00 and 10:00 + args.duration
        _annotations = am.get_within_time_range(
            _annotations,
            TimeInterval(
                datetime.datetime(1900, 1, 1, 10, 0),
                datetime.datetime(1900, 1, 1, 10 + args.duration, 0),
            ),
        )
        child_id = _annotations["child_id"].max()
        age = _annotations["age"].max()
        duration = (_annotations["range_offset"] - _annotations["range_onset"]).sum()

        # skip recordings that do not cover the full time window
        if duration < args.duration * 3600 * 1000:
            continue

        duration = args.duration * 3600 * 1000

        _annotations["path"] = _annotations.apply(
            lambda r: opj(
                project.path, "annotations", r["set"], "converted", r["annotation_filename"]
            ),
            axis=1,
        )
        missing_annotations = _annotations[~_annotations["path"].map(exists)]
        if len(missing_annotations):
            datalad.api.get(list(missing_annotations["path"].unique()))

        segments = am.get_segments(_annotations)
        segments["segment_onset"] -= segments["segment_onset"].min()
        segments = segments[segments["segment_onset"] >= 0]
        segments = segments[segments["segment_onset"] < duration]

        if len(segments) == 0:
            continue

        segments = segments[segments["speaker_type"].isin(["CHI", "OCH", "FEM", "MAL"])]

        # vocalization counts per speaker type for this recording
        rec = {
            f"vtc_{i}": len(segments[segments["speaker_type"] == speaker])
            for i, speaker in enumerate(speakers)
        }
        rec["recording"] = recording_filename
        rec["children"] = f"{corpus}_{child_id}"
        rec["corpus"] = basename(corpus)
        rec["age"] = age
        recs.append(rec)

    recs = pd.DataFrame(recs)
    return recs

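# Main entry point: assemble the recording-level counts, the confusion counts,
# and the child-level speech rates into a single data dictionary and fit each
# requested Stan model on it.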
if __name__ == "__main__":
    recs = pd.concat([compile_recordings(corpus) for corpus in args.corpora])
    recs["children"] = recs["children"].astype("category").cat.codes.astype(int) + 1

    annotators = pd.read_csv("input/annotators.csv")
    annotators["path"] = annotators["corpus"].apply(lambda c: opj("input", c))

    # confusion counts between the VTC and each human annotator
    with mp.Pool(processes=args.chains * args.threads_per_chain) as pool:
        data = pd.concat(pool.map(compute_counts, annotators.to_dict(orient="records")))

    data = data.sample(frac=1)
    duration = data["duration"].sum()

    vtc = np.moveaxis(
        [[data[f"vtc_{j}_{i}"].values for i in range(4)] for j in range(4)], -1, 0
    )
    truth = np.transpose([data[f"truth_{i}"].values for i in range(4)])

    # speech rates at the child level
    annotators = annotators[~annotators["annotator"].str.startswith("eaf_2021")]
    with mp.Pool(processes=args.chains * args.threads_per_chain) as pool:
        speech_rates = pd.concat(pool.map(rates, annotators.to_dict(orient="records")))

    speech_rates.reset_index(inplace=True)
    speech_rates = speech_rates.groupby(["corpus", "child_id"]).sample(1)
    speech_rate_matrix = np.transpose(
        [speech_rates[f"speech_rate_{i}"].values for i in range(4)]
    )
    speech_rates.to_csv("rates.csv")

    print(vtc.shape)

    data["corpus"] = data["corpus"].astype("category")
    corpora = data["corpus"].cat.codes.values
    corpora_codes = dict(enumerate(data["corpus"].cat.categories))
    corpora_codes = {v: k for k, v in corpora_codes.items()}

    # Stan uses 1-based indexing, hence the "1 +" on categorical codes
    confusion_data = {
        "n_clips": truth.shape[0],
        "n_classes": truth.shape[1],
        "n_groups": data["child"].nunique(),
        "n_corpora": data["corpus"].nunique(),
        "n_validation": max(1, int(truth.shape[0] * args.validation)),
        "group": 1 + data["child"].astype("category").cat.codes.values,
        "conf_corpus": 1 + corpora,
        "truth": truth.astype(int),
        "vtc": vtc.astype(int),
        "speech_rates": speech_rate_matrix.astype(int),
        "group_corpus": 1 + speech_rates["corpus"].map(corpora_codes).astype(int).values,
        "durations": speech_rates["duration"].values,
        "n_rates": len(speech_rates),
    }

    n_recs = len(recs)

    children_corpus = recs.groupby("children").agg(corpus=("corpus", "first")).sort_index()
    children_corpus = 1 + children_corpus.corpus.map(corpora_codes).astype(int).values

    analysis_data = {
        "n_recs": n_recs,
        "n_children": len(recs["children"].unique()),
        "children": recs["children"],
        "vocs": np.transpose([recs[f"vtc_{i}"].values for i in range(4)]),
        "age": recs["age"],
        "corpus": children_corpus,
        "duration": args.duration,
    }

    data = {**analysis_data, **confusion_data}

    output = {}
    for model_name in args.models:
        output[model_name] = run_model(data, args.run, model_name)
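
# Example invocation (illustrative values; corpus paths and the run label
# depend on the local dataset layout):
#   python analysis.py --corpora input/bergelson --duration 8 \
#       --run test --chains 4 --warmup 250 --samples 1000 --models analysis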