enumeration.py

#!/usr/bin/env python

import pandas as pd
import numpy as np

from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_grid, segments_to_annotation
from ChildProject.pipelines.metrics import AclewMetrics
from ChildProject.utils import TimeInterval

from cmdstanpy import CmdStanModel

import datetime

import matplotlib
from matplotlib import pyplot as plt

matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)

from collections import defaultdict
import pickle

import datalad.api
from os.path import join as opj
from os.path import basename, exists

import multiprocessing as mp

from pyannote.core import Annotation, Segment, Timeline

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--corpora", default=["input/bergelson"], nargs="+")
parser.add_argument("--duration", type=int, help="duration in hours", default=8)
parser.add_argument("--run")
parser.add_argument("--chains", type=int, default=1)
parser.add_argument("--warmup", type=int, default=250)
parser.add_argument("--samples", type=int, default=1000)
parser.add_argument("--threads-per-chain", type=int, default=4)
parser.add_argument("--models", default=["causal"], nargs="+")
parser.add_argument("--confusion-clips", choices=["recording", "clip"], default="clip")
parser.add_argument("--confusion-groups", choices=["child", "recording"], default="recording")
parser.add_argument("--time-scale", type=int, default=0)
args = parser.parse_args()
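
# The flags above configure the whole analysis. A hypothetical invocation
# (the corpus path matches the default; the --run label is illustrative):
#
#   python enumeration.py --corpora input/bergelson --duration 8 \
#       --run my_run --chains 4 --threads-per-chain 4 --models causal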

speakers = ["CHI", "OCH", "FEM", "MAL"]
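

# Remove a segment or timeline from a pyannote Timeline by cropping it to the
# gaps left by the removed regions (the mode is flipped because "loose"
# cropping corresponds to strict truncation and vice versa).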
def extrude(self, removed, mode: str = "intersection"):
    if isinstance(removed, Segment):
        removed = Timeline([removed])

    truncating_support = removed.gaps(support=self.extent())
    # loose for truncate means strict for crop and vice-versa
    if mode == "loose":
        mode = "strict"
    elif mode == "strict":
        mode = "loose"

    return self.crop(truncating_support, mode=mode)
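

# Map each child of a given corpus to their number of siblings, read from
# input/siblings.csv; children missing from the file default to -1.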
def children_siblings(corpus):
    siblings = pd.read_csv("input/siblings.csv")
    siblings = siblings[siblings["corpus"] == corpus]
    siblings["child_id"] = siblings["child_id"].astype(str)
    siblings = siblings.set_index("child_id")["n_siblings"].to_dict()

    n = defaultdict(lambda: -1, **siblings)
    return n
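

# For one (corpus, annotator) pair, align the automated VTC annotations with
# the human annotations over their common portions and count, per clip, the
# segments attributed to each speaker type by each source. These counts feed
# the confusion part of the Stan model.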
def compute_counts(parameters):
    corpus = parameters["corpus"]
    annotator = parameters["annotator"]

    project = ChildProject(parameters["path"])
    am = AnnotationManager(project)
    am.read()

    project.recordings["age"] = project.compute_ages()

    intersection = AnnotationManager.intersection(am.annotations, ["vtc", annotator])
    intersection["path"] = intersection.apply(
        lambda r: opj(
            project.path, "annotations", r["set"], "converted", r["annotation_filename"]
        ),
        axis=1,
    )
    datalad.api.get(list(intersection["path"].unique()))

    intersection = intersection.merge(
        project.recordings[["recording_filename", "child_id", "age"]], how="left"
    )
    intersection["child"] = corpus + "_" + intersection["child_id"].astype(str)
    intersection["recording"] = corpus + "_" + intersection["recording_filename"].astype(str)
    intersection["clip_id"] = (
        corpus
        + "_"
        + intersection["recording_filename"].str.cat(intersection["range_onset"].astype(str))
    )

    if args.time_scale:
        intersection["onset"] = intersection.apply(
            lambda r: np.arange(
                r["range_onset"], r["range_offset"], args.time_scale * 1000
            ),
            axis=1,
        )
        intersection = intersection.explode("onset")
        intersection["range_onset"] = intersection["onset"]
        intersection["range_offset"] = (
            intersection["range_onset"] + args.time_scale * 1000
        ).clip(upper=intersection["range_offset"])

    intersection["duration"] = (
        intersection["range_offset"] - intersection["range_onset"]
    )
    print(corpus, annotator, (intersection["duration"] / 1000 / 2).sum() / 3600)

    data = []

    # for child, ann in intersection.groupby("child"):
    groupby = {
        "recording": ["child_id", "recording"],
        "clip": ["child_id", "recording", "range_onset"],
    }[args.confusion_clips]

    for clip, ann in intersection.groupby(groupby):
        child = clip[0]
        recording = clip[1]
        # print(corpus, child)

        segments = am.get_collapsed_segments(ann)
        if "speaker_type" not in segments.columns:
            continue

        segments = segments[segments["speaker_type"].isin(speakers)]

        vtc = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == "vtc") & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            )
            # .support(collar=200)
            .get_timeline()
            for speaker in speakers
        }

        truth = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == annotator)
                    & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            )
            # .support(collar=200)
            .get_timeline()
            for speaker in speakers
        }

        d = {}
        for i, speaker_A in enumerate(speakers):
            d[f"truth_total_{i}"] = len(truth[speaker_A])
            d[f"vtc_total_{i}"] = len(vtc[speaker_A])

        d["child"] = child
        d["recording"] = recording
        d["duration"] = ann["duration"].sum() / 2 / 1000 / 3600
        d["clip_id"] = ann["clip_id"].iloc[0]
        d["clip_age"] = ann["age"].iloc[0]
        data.append(d)

    return pd.DataFrame(data).assign(
        corpus=corpus,
    )
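

# For one (corpus, annotator) pair, extract ACLEW metrics per recording
# between 10:00 and 18:00 and derive, for each speaker type, the number of
# vocalizations produced over the annotated duration (the speech rates fed to
# the model).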
def rates(parameters):
    corpus = parameters["corpus"]
    annotator = parameters["annotator"]
    speakers = ["CHI", "OCH", "FEM", "MAL"]

    project = ChildProject(parameters["path"])
    am = AnnotationManager(project)
    am.read()

    pipeline = AclewMetrics(
        project,
        vtc=annotator,
        alice=None,
        vcm=None,
        from_time="10:00:00",
        to_time="18:00:00",
        by="recording_filename",
    )
    metrics = pipeline.extract()
    metrics = pd.DataFrame(metrics).assign(corpus=corpus, annotator=annotator)

    project.recordings["age"] = project.compute_ages()
    project.recordings["siblings"] = project.recordings.child_id.astype(str).map(
        children_siblings(corpus)
    )
    metrics = metrics.merge(
        project.recordings[["recording_filename", "age", "siblings"]]
    )
    metrics["duration"] = metrics[f"duration_{annotator}"] / 1000 / 3600
    metrics = metrics[metrics["duration"] > 0.01]
    metrics["child"] = corpus + "_" + metrics["child_id"].astype(str)

    # metrics.dropna(subset={f"voc_{speaker.lower()}_ph" for speaker in speakers}&set(metrics.columns), inplace=True)
    for i, speaker in enumerate(speakers):
        # if f"voc_{speaker.lower()}_ph" not in metrics.columns:
        #     metrics[f"speech_rate_{i}"] = pd.NA
        # else:
        metrics[f"speech_rate_{i}"] = (
            (metrics[f"voc_{speaker.lower()}_ph"] * (metrics["duration"]))
            .fillna(0)
            .astype(int)
        )

    return metrics
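

# Compile the requested Stan model and sample from it with the CLI settings;
# posterior draws are written to output/aggregates_<run>_<model>.npz and the
# Stan input data are pickled alongside.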
def run_model(data, run, model_name):
    model = CmdStanModel(
        stan_file=f"code/models/{model_name}.stan",
        cpp_options={"STAN_THREADS": "TRUE"},
        compile="force",
    )
    fit = model.sample(
        data=data,
        chains=args.chains,
        threads_per_chain=args.threads_per_chain,
        iter_sampling=args.samples,
        iter_warmup=args.warmup,
        step_size=0.1,
        max_treedepth=12,
        # save_profile=True,
        show_console=True,
    )

    vars = fit.stan_variables()
    samples = {}
    for (k, v) in vars.items():
        samples[k] = v

    np.savez_compressed(f"output/aggregates_{run}_{model_name}.npz", **samples)
    samples = np.load(f"output/aggregates_{run}_{model_name}.npz")

    with open(f"output/aggregates_{run}_{model_name}.pickle", "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

    return {"samples": samples}
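

# Gather VTC segment counts per speaker type for every recording of a corpus
# that covers at least --duration hours starting at 10:00, keeping only
# normative children when that information is available.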
def compile_recordings(corpus):
    project = ChildProject(corpus)
    am = AnnotationManager(project)
    am.read()

    project.recordings["age"] = project.compute_ages()
    project.recordings["siblings"] = project.recordings.child_id.astype(str).map(
        children_siblings(corpus)
    )

    annotations = am.annotations[am.annotations["set"] == "vtc"]
    annotations = annotations.merge(
        project.recordings,
        left_on="recording_filename",
        right_on="recording_filename",
        how="inner",
    )
    annotations = annotations.merge(
        project.children,
        left_on="child_id",
        right_on="child_id",
        how="inner",
    )

    if "normative" in annotations.columns:
        print("filtering out non-normative kids")
        annotations = annotations[annotations["normative"] != "N"]

    recs = []
    for recording_filename, _annotations in annotations.groupby("recording_filename"):
        _annotations = am.get_within_time_range(
            _annotations,
            TimeInterval(
                datetime.datetime(1900, 1, 1, 10, 0),
                datetime.datetime(1900, 1, 1, 10 + args.duration, 0),
            ),
        )
        child_id = _annotations["child_id"].max()
        age = _annotations["age"].max()
        duration = (_annotations["range_offset"] - _annotations["range_onset"]).sum()
        siblings = _annotations["siblings"].max()

        if duration < args.duration * 3600 * 1000:
            continue

        duration = args.duration * 3600 * 1000

        _annotations["path"] = _annotations.apply(
            lambda r: opj(
                project.path,
                "annotations",
                r["set"],
                "converted",
                r["annotation_filename"],
            ),
            axis=1,
        )
        missing_annotations = _annotations[~_annotations["path"].map(exists)]
        if len(missing_annotations):
            datalad.api.get(list(missing_annotations["path"].unique()))

        segments = am.get_segments(_annotations)
        segments["segment_onset"] -= segments["segment_onset"].min()
        segments = segments[segments["segment_onset"] >= 0]
        segments = segments[segments["segment_onset"] < duration]
        if len(segments) == 0:
            continue

        segments = segments[segments["speaker_type"].isin(["CHI", "OCH", "FEM", "MAL"])]

        rec = {
            f"vtc_{i}": len(segments[segments["speaker_type"] == speaker])
            for i, speaker in enumerate(speakers)
        }
        rec["recording"] = recording_filename
        rec["children"] = f"{corpus}_{child_id}"
        rec["corpus"] = basename(corpus)
        rec["age"] = age
        rec["siblings"] = siblings
        recs.append(rec)

    recs = pd.DataFrame(recs)
    return recs
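

# Driver: build the recording-level dataset and the clip-level confusion
# counts, assemble them into a single dictionary of Stan inputs, and fit each
# requested model.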
if __name__ == "__main__":
    recs = pd.concat([compile_recordings(corpus) for corpus in args.corpora])
    recs["children"] = recs["children"].astype("category").cat.codes.astype(int) + 1

    annotators = pd.read_csv("input/annotators.csv")
    annotators["path"] = annotators["corpus"].apply(lambda c: opj("input", c))

    with mp.Pool(processes=args.chains * args.threads_per_chain) as pool:
        data = pd.concat(pool.map(compute_counts, annotators.to_dict(orient="records")))

    corpora = list(set(data["corpus"].unique()) | set(recs["corpus"].unique()))
    corpora_map = {corpus: i for i, corpus in enumerate(corpora)}
    print(corpora_map)

    duration = data["duration"].sum()
    clip_duration = data["duration"].values
    clip_age = data["clip_age"].values

    truth_total = np.transpose([data[f"truth_total_{i}"].values for i in range(4)])
    vtc_total = np.transpose([data[f"vtc_total_{i}"].values for i in range(4)])

    # speech rates at the child level
    annotators = annotators[~annotators["annotator"].str.startswith("eaf_2021")]
    with mp.Pool(processes=args.chains * args.threads_per_chain) as pool:
        speech_rates = pd.concat(pool.map(rates, annotators.to_dict(orient="records")))

    speech_rates.reset_index(inplace=True)
    speech_rate_matrix = np.transpose(
        [speech_rates[f"speech_rate_{i}"].values for i in range(4)]
    )
    speech_rate_age = speech_rates["age"].values
    speech_rate_siblings = speech_rates["siblings"].values.astype(int)

    speech_rates.to_csv("rates.csv")
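
    # Inputs to the confusion/detection part of the model: per-clip segment
    # counts from the human annotator ("truth") and from the VTC, together
    # with per-recording speech-rate data and 1-based indices for groups,
    # corpora, clips and children (Stan arrays are 1-indexed).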
    data["corpus"] = data["corpus"].astype("category")

    confusion_data = {
        "n_clips": truth_total.shape[0],
        "n_classes": truth_total.shape[1],
        "n_groups": data[args.confusion_groups].nunique(),
        "n_corpora": len(corpora),
        "n_validation": 0,
        "group": 1 + data[args.confusion_groups].astype("category").cat.codes.values,
        "conf_corpus": 1 + data["corpus"].map(corpora_map).astype(int).values,
        "truth_total": truth_total.astype(int),
        "vtc_total": vtc_total.astype(int),
        "clip_duration": clip_duration,
        "clip_age": clip_age,
        "clip_id": 1 + data["clip_id"].astype("category").cat.codes.values,
        "n_unique_clips": data["clip_id"].nunique(),
        "speech_rates": speech_rate_matrix.astype(int),
        "speech_rate_age": speech_rate_age,
        "speech_rate_siblings": speech_rate_siblings,
        "group_corpus": (
            1 + speech_rates["corpus"].map(corpora_map).astype(int).values
        ),
        "speech_rate_child": 1 + speech_rates["child"].astype("category").cat.codes.values,
        "n_speech_rate_children": speech_rates["child"].nunique(),
        "durations": speech_rates["duration"].values,
        "n_rates": len(speech_rates),
    }
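
    # Inputs to the recording-level part of the model: observed VTC counts per
    # speaker type for each full-duration recording, with child age, sibling
    # count and corpus indices.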
    n_recs = len(recs)

    children_corpus = (
        recs.groupby("children").agg(corpus=("corpus", "first")).sort_index()
    )
    children_corpus = 1 + children_corpus.corpus.map(corpora_map).astype(int).values

    analysis_data = {
        "n_recs": n_recs,
        "n_children": len(recs["children"].unique()),
        "children": recs["children"],
        "vocs": np.transpose([recs[f"vtc_{i}"].values for i in range(4)]),
        "age": recs["age"],
        "siblings": recs["siblings"].astype(int),
        "corpus": children_corpus,
        "recs_duration": args.duration,
    }
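
    # Both dictionaries are merged into a single data block; every model listed
    # in --models receives the same inputs.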
    data = {**analysis_data, **confusion_data}
    data["threads"] = args.threads_per_chain

    print(data)

    output = {}
    for model_name in args.models:
        output[model_name] = run_model(data, args.run, model_name)