#!/usr/bin/env python
# enumeration.py
import pandas as pd
import numpy as np
from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_grid, segments_to_annotation
from ChildProject.pipelines.metrics import AclewMetrics
from ChildProject.utils import TimeInterval
from cmdstanpy import CmdStanModel
import arviz as av
import datetime
import matplotlib
from matplotlib import pyplot as plt

matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)

from collections import defaultdict
import pickle
import datalad.api
from os.path import join as opj
from os.path import basename, exists
import multiprocessing as mp
from pyannote.core import Annotation, Segment, Timeline
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--corpora", default=["input/bergelson"], nargs="+")
parser.add_argument("--duration", type=int, help="duration in hours", default=8)
parser.add_argument("--calibration", default="input/annotators.csv")
parser.add_argument("--run")
parser.add_argument("--algo", default="vtc", help="algorithm annotation set")
parser.add_argument("--chains", type=int, default=1)
parser.add_argument("--warmup", type=int, default=400)
parser.add_argument("--samples", type=int, default=1000)
parser.add_argument("--threads-per-chain", type=int, default=4)
parser.add_argument("--models", default=["causal"], nargs="+")
parser.add_argument("--confusion-clips", choices=["recording", "clip"], default="clip")
parser.add_argument("--confusion-groups", choices=["child", "recording"], default="recording")
parser.add_argument("--time-scale", type=int, default=15)
# note: store_true combined with default=True means this flag is effectively always on
parser.add_argument("--require-clip-age", action="store_true", default=True)
parser.add_argument("--age-threshold", type=int, default=0)
parser.add_argument("--adapt-delta", type=float, default=0.9)
parser.add_argument("--max-treedepth", type=int, default=14)
args = parser.parse_args()
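
# Example invocation (the run label "pilot" is hypothetical; flags are those defined above):
#   python enumeration.py --corpora input/bergelson --run pilot --models causal \
#       --chains 4 --threads-per-chain 4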

speakers = ["CHI", "OCH", "FEM", "MAL"]


def extrude(self, removed, mode: str = "intersection"):
    """Remove a set of segments from a pyannote object by cropping around them."""
    if isinstance(removed, Segment):
        removed = Timeline([removed])
    truncating_support = removed.gaps(support=self.extent())
    # loose for truncate means strict for crop and vice-versa
    if mode == "loose":
        mode = "strict"
    elif mode == "strict":
        mode = "loose"
    return self.crop(truncating_support, mode=mode)
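
# `extrude` takes `self` as its first argument, so it is presumably meant to be bound to a
# pyannote class before use; a sketch of that assumption (not exercised in this script):
#   Annotation.extrude = extrude
#   speech_outside_removed = annotation.extrude(removed_timeline, mode="loose")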

def children_siblings(corpus):
    """Map each child_id of `corpus` to its number of siblings (-1 when unknown)."""
    siblings = pd.read_csv("input/siblings.csv")
    siblings["child_id"] = siblings["child_id"].astype(str)
    siblings = siblings[siblings["corpus"] == basename(corpus)].set_index("child_id")
    siblings = siblings["n_siblings"].to_dict()
    n = defaultdict(lambda: -1, **siblings)
    return n
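
# input/siblings.csv is expected to provide at least the columns read above:
# `corpus`, `child_id`, and `n_siblings`. A minimal sketch with illustrative values:
#   corpus,child_id,n_siblings
#   bergelson,123,2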

def compute_counts(parameters):
    """Count per-clip vocalizations in the annotator's set and in the algorithm's set."""
    corpus = parameters["corpus"]
    annotator = parameters["annotator"]
    rural = ((parameters["type"] == "rural") * 1) if "type" in parameters else 0

    project = ChildProject(parameters["path"])
    am = AnnotationManager(project)
    am.read()
    am.annotations = am.annotations[~am.annotations["annotation_filename"].isna()]

    project.recordings["age"] = project.compute_ages()

    # portions of audio annotated by both the algorithm and the human annotator
    intersection = AnnotationManager.intersection(am.annotations, [args.algo, annotator])
    if len(intersection) == 0:
        print(f"No intersection between '{args.algo}' and '{annotator}' in {corpus}!")
        return pd.DataFrame()

    intersection["path"] = intersection.apply(
        lambda r: opj(
            project.path, "annotations", r["set"], "converted", r["annotation_filename"]
        ),
        axis=1,
    )
    missing_annotations = intersection[~intersection["path"].map(exists)]
    if len(missing_annotations):
        datalad.api.get(list(missing_annotations["path"].unique()))

    intersection = intersection.merge(
        project.recordings[["recording_filename", "child_id", "age"]], how="left"
    )
    intersection["child"] = corpus + "_" + intersection["child_id"].astype(str)
    intersection["recording"] = corpus + "_" + intersection["recording_filename"].astype(str)
    intersection["clip_id"] = (
        corpus
        + "_"
        + intersection["recording_filename"].str.cat(intersection["range_onset"].astype(str))
    )

    # optionally split each annotated range into fixed-length clips of `time_scale` seconds
    if args.time_scale:
        intersection["onset"] = intersection.apply(
            lambda r: np.arange(
                r["range_onset"], r["range_offset"], args.time_scale * 1000
            ),
            axis=1,
        )
        intersection = intersection.explode("onset")
        intersection["range_onset"] = intersection["onset"]
        intersection["range_offset"] = (
            intersection["range_onset"] + args.time_scale * 1000
        ).clip(upper=intersection["range_offset"])

    intersection["duration"] = (
        intersection["range_offset"] - intersection["range_onset"]
    )
    # durations are summed over both annotation sets, hence the division by 2
    print(corpus, annotator, (intersection["duration"] / 1000 / 2).sum() / 3600)

    data = []
    # for child, ann in intersection.groupby("child"):
    groupby = {
        "recording": ["child_id", "recording"],
        "clip": ["child_id", "recording", "range_onset"],
    }[args.confusion_clips]

    for clip, ann in intersection.groupby(groupby):
        child = clip[0]
        recording = clip[1]
        # print(corpus, child)

        segments = am.get_collapsed_segments(ann)
        if "speaker_type" not in segments.columns:
            continue
        segments = segments[segments["speaker_type"].isin(speakers)]

        d = {}
        for i, speaker_A in enumerate(speakers):
            d[f"truth_total_{i}"] = len(
                segments[(segments["set"] == annotator) & (segments["speaker_type"] == speaker_A)]
            )
            d[f"algo_total_{i}"] = len(
                segments[(segments["set"] == args.algo) & (segments["speaker_type"] == speaker_A)]
            )
        d["child"] = child
        d["recording"] = recording
        d["duration"] = ann["duration"].sum() / 2 / 1000 / 3600
        d["clip_id"] = ann["clip_id"].iloc[0]
        d["clip_age"] = ann["age"].iloc[0]
        data.append(d)

    data = pd.DataFrame(data).assign(
        corpus=corpus,
        clip_rural=rural,
    )
    if args.require_clip_age:
        data.dropna(subset=["clip_age"], inplace=True)
    return data
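
# `compute_counts` is called with one row of the calibration spreadsheet per annotator
# (see the __main__ block). A sketch of such a call, with hypothetical values:
#   compute_counts({"corpus": "bergelson", "annotator": "eaf", "path": "input/bergelson"})
# It returns one row per clip with truth_total_i / algo_total_i counts for the four
# speaker classes, plus clip metadata (child, recording, duration, clip_id, clip_age).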

def rates(parameters):
    """Per-recording vocalization counts derived from an annotator's set, for the speech-rate model."""
    corpus = parameters["corpus"]
    annotator = parameters["annotator"]
    speakers = ["CHI", "OCH", "FEM", "MAL"]

    project = ChildProject(parameters["path"])
    am = AnnotationManager(project)
    am.read()

    pipeline = AclewMetrics(
        project,
        vtc=annotator,
        alice=None,
        vcm=None,
        from_time="10:00:00",
        to_time="18:00:00",
        by="recording_filename",
        threads=args.chains * args.threads_per_chain,
    )
    metrics = pipeline.extract()
    metrics = pd.DataFrame(metrics).assign(corpus=corpus, annotator=annotator)

    project.recordings["age"] = project.compute_ages()
    project.recordings["siblings"] = project.recordings.child_id.astype(str).map(
        children_siblings(corpus)
    )
    metrics = metrics.merge(
        project.recordings[["recording_filename", "age", "siblings"]]
    )
    metrics["duration"] = metrics[f"duration_{annotator}"] / 1000 / 3600
    metrics = metrics[metrics["duration"] > 0.01]
    metrics["child"] = corpus + "_" + metrics["child_id"].astype(str)

    # metrics.dropna(subset={f"voc_{speaker.lower()}_ph" for speaker in speakers}&set(metrics.columns), inplace=True)
    for i, speaker in enumerate(speakers):
        # if f"voc_{speaker.lower()}_ph" not in metrics.columns:
        #     metrics[f"speech_rate_{i}"] = pd.NA
        # else:
        # convert hourly rates back into raw counts over the annotated duration
        metrics[f"speech_rate_{i}"] = (
            (metrics[f"voc_{speaker.lower()}_ph"] * (metrics["duration"]))
            .fillna(0)
            .astype(int)
        )
    return metrics
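
# `rates` returns one row per recording with speech_rate_0..speech_rate_3 counts
# (CHI, OCH, FEM, MAL) within the 10:00-18:00 window; these feed the speech-rate
# block of the Stan data assembled in the __main__ section below.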

def run_model(data, run, model_name):
    """Compile and sample the Stan model, saving posterior draws and input data to output/."""
    model = CmdStanModel(
        stan_file=f"code/models/{model_name}.stan",
        cpp_options={"STAN_THREADS": "TRUE"},
    )
    fit = model.sample(
        data=data,
        chains=args.chains,
        threads_per_chain=args.threads_per_chain,
        iter_sampling=args.samples,
        iter_warmup=args.warmup,
        step_size=0.1,
        max_treedepth=args.max_treedepth,
        adapt_delta=args.adapt_delta,
        # save_profile=True,
        show_console=True,
    )

    # save the posterior draws, then reload the compressed archive
    samples = dict(fit.stan_variables())
    np.savez_compressed(f"output/aggregates_{run}_{model_name}.npz", **samples)
    samples = np.load(f"output/aggregates_{run}_{model_name}.npz")

    # keep a copy of the Stan input data alongside the draws
    with open(f"output/aggregates_{run}_{model_name}.pickle", "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(fit.diagnose())

    # fit_av = av.from_cmdstanpy(fit)
    # axes = av.plot_pair(
    #     fit_av,
    #     var_names=["alpha_dev", "sigma_dev", "beta_dev", "alpha_pop", "beta_sib_och", "beta_sib_adu"],
    #     divergences=True,
    # )
    # fig = axes.ravel()[0].figure
    # fig.savefig("output/pairplot.eps", bbox_inches="tight")
    # fig.savefig("output/pairplot.png", bbox_inches="tight", dpi=1080)

    return {"samples": samples}
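
# A sketch for inspecting a saved run afterwards (run label "pilot" is hypothetical,
# "causal" is the default model name):
#   draws = np.load("output/aggregates_pilot_causal.npz")
#   print(draws.files)                  # names of the saved Stan variables
#   print(draws[draws.files[0]].shape)  # draws for the first variable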

def compile_recordings(corpus):
    """Per-recording vocalization counts from the algorithm, for recordings with enough annotated audio."""
    project = ChildProject(corpus)
    am = AnnotationManager(project)
    am.read()

    if "age" in project.children.columns:
        project.children.drop(columns=["age"], inplace=True)
    project.recordings["age"] = project.compute_ages()
    project.recordings["siblings"] = project.recordings.child_id.astype(str).map(
        children_siblings(corpus)
    )

    annotations = am.annotations[am.annotations["set"] == args.algo]
    annotations = annotations.merge(
        project.recordings,
        left_on="recording_filename",
        right_on="recording_filename",
        how="inner",
    )
    annotations = annotations.merge(
        project.children,
        left_on="child_id",
        right_on="child_id",
        how="inner",
    )
    if "normative" in annotations.columns:
        print("filtering out non-normative kids")
        annotations = annotations[annotations["normative"] != "N"]

    recs = []
    for recording_filename, _annotations in annotations.groupby("recording_filename"):
        # keep only annotations within the 10:00 to 10:00 + duration window
        _annotations = am.get_within_time_range(
            _annotations,
            TimeInterval(
                datetime.datetime(1900, 1, 1, 10, 0),
                datetime.datetime(1900, 1, 1, 10 + args.duration, 0),
            ),
        )
        child_id = _annotations["child_id"].max()
        age = _annotations["age"].max()
        duration = (_annotations["range_offset"] - _annotations["range_onset"]).sum()
        siblings = _annotations["siblings"].max()

        # discard recordings that do not cover the full window
        if duration < args.duration * 3600 * 1000:
            continue
        duration = args.duration * 3600 * 1000

        _annotations["path"] = _annotations.apply(
            lambda r: opj(
                project.path,
                "annotations",
                r["set"],
                "converted",
                r["annotation_filename"],
            ),
            axis=1,
        )
        missing_annotations = _annotations[~_annotations["path"].map(exists)]
        if len(missing_annotations):
            datalad.api.get(list(missing_annotations["path"].unique()))

        segments = am.get_segments(_annotations)
        segments["segment_onset"] -= segments["segment_onset"].min()
        segments = segments[segments["segment_onset"] >= 0]
        segments = segments[segments["segment_onset"] < duration]
        if len(segments) == 0:
            continue
        segments = segments[segments["speaker_type"].isin(["CHI", "OCH", "FEM", "MAL"])]

        rec = {
            f"algo_{i}": len(segments[segments["speaker_type"] == speaker])
            for i, speaker in enumerate(speakers)
        }
        rec["recording"] = recording_filename
        rec["children"] = f"{corpus}_{child_id}"
        rec["corpus"] = basename(corpus)
        rec["age"] = age
        rec["siblings"] = siblings
        recs.append(rec)

    recs = pd.DataFrame(recs)
    return recs
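
# `compile_recordings` keeps only recordings with at least --duration hours of algorithm
# annotations within that window and returns per-recording counts algo_0..algo_3
# (CHI, OCH, FEM, MAL) together with the child's age and sibling count.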

if __name__ == "__main__":
    # per-recording counts from the algorithm, for the corpora under study
    recs = pd.concat([compile_recordings(corpus) for corpus in args.corpora])
    recs["children"] = recs["children"].astype("category").cat.codes.astype(int) + 1

    # per-clip confusion counts (annotator vs algorithm), for the calibration corpora
    annotators = pd.read_csv(args.calibration)
    annotators["path"] = annotators["corpus"].apply(lambda c: opj("input", c))
    with mp.Pool(processes=args.chains * args.threads_per_chain) as pool:
        data = pd.concat(pool.map(compute_counts, annotators.to_dict(orient="records")))

    corpora = sorted(list(set(data["corpus"].unique()) | set(recs["corpus"].unique())))
    corpora_map = {corpus: i for i, corpus in enumerate(corpora)}
    print(corpora_map)

    duration = data["duration"].sum()
    clip_duration = data["duration"].values
    clip_age = data["clip_age"].values
    clip_rural = data["clip_rural"].values
    truth_total = np.transpose([data[f"truth_total_{i}"].values for i in range(4)])
    algo_total = np.transpose([data[f"algo_total_{i}"].values for i in range(4)])

    # speech rates at the child level
    annotators = annotators[~annotators["annotator"].str.startswith("eaf_2021")]
    speech_rates = pd.concat([
        rates(annotator) for annotator in annotators.to_dict(orient="records")
    ])
    speech_rates["corpus_code"] = speech_rates["corpus"].map(corpora_map)
    speech_rates.dropna(subset=["corpus_code"], inplace=True)
    speech_rates.reset_index(inplace=True)
    speech_rate_matrix = np.transpose(
        [speech_rates[f"speech_rate_{i}"].values for i in range(4)]
    )
    speech_rate_age = speech_rates["age"].values
    speech_rate_siblings = speech_rates["siblings"].values.astype(int)
    speech_rates.to_csv("rates.csv")

    data["corpus"] = data["corpus"].astype("category")

    # data for the confusion / calibration part of the model
    confusion_data = {
        "n_clips": truth_total.shape[0],
        "n_classes": truth_total.shape[1],
        "n_groups": data[args.confusion_groups].nunique(),
        "n_corpora": len(corpora),
        "n_validation": 0,
        "group": 1 + data[args.confusion_groups].astype("category").cat.codes.values,
        "conf_corpus": 1 + data["corpus"].map(corpora_map).astype(int).values,
        "truth_total": truth_total.astype(int),
        "algo_total": algo_total.astype(int),
        "clip_duration": clip_duration,
        "clip_age": clip_age if args.age_threshold == 0 else np.minimum(clip_age, args.age_threshold),
        "clip_rural": clip_rural,
        "clip_id": 1 + data["clip_id"].astype("category").cat.codes.values,
        "n_unique_clips": data["clip_id"].nunique(),
        "speech_rates": speech_rate_matrix.astype(int),
        "speech_rate_age": speech_rate_age if args.age_threshold == 0 else np.minimum(speech_rate_age, args.age_threshold),
        "speech_rate_siblings": speech_rate_siblings,
        "group_corpus": 1 + speech_rates["corpus_code"].astype(int).values,
        "speech_rate_child": 1 + speech_rates["child"].astype("category").cat.codes.values,
        "n_speech_rate_children": speech_rates["child"].nunique(),
        "durations": speech_rates["duration"].values,
        "n_rates": len(speech_rates),
    }

    # data for the main analysis (per-recording algorithm counts)
    n_recs = len(recs)
    children_corpus = (
        recs.groupby("children").agg(corpus=("corpus", "first")).sort_index()
    )
    children_corpus = 1 + children_corpus.corpus.map(corpora_map).astype(int).values

    analysis_data = {
        "n_recs": n_recs,
        "n_children": len(recs["children"].unique()),
        "children": recs["children"],
        "vocs": np.transpose([recs[f"algo_{i}"].values for i in range(4)]),
        "age": recs["age"] if args.age_threshold == 0 else np.minimum(recs["age"], args.age_threshold),
        "siblings": recs["siblings"].astype(int),
        "corpus": children_corpus,
        "recs_duration": args.duration,
    }

    data = {**analysis_data, **confusion_data}
    data["threads"] = args.threads_per_chain
    print(data)

    output = {}
    for model_name in args.models:
        output[model_name] = run_model(data, args.run, model_name)