enumeration_combined.py

#!/usr/bin/env python
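"""Compile the data for the combined enumeration analysis and fit it with CmdStan.

High-level flow (inferred from the code below):
  * compile_recordings() counts vocalizations per speaker type in full recordings,
    for the two automated annotation sets (--algo1, --algo2);
  * compute_counts() counts vocalizations in the clips annotated both by the
    algorithms and by the human annotators listed in the --calibration file;
  * rates() extracts per-recording speech-rate metrics from the human sets;
  * run_model() fits each Stan model listed in --models to the combined data.
"""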
import pandas as pd
import numpy as np
from ChildProject.projects import ChildProject
from ChildProject.annotations import AnnotationManager
from ChildProject.metrics import segments_to_grid, segments_to_annotation
from ChildProject.pipelines.metrics import AclewMetrics
from ChildProject.utils import TimeInterval
from cmdstanpy import CmdStanModel
import datetime
import matplotlib
from matplotlib import pyplot as plt

matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "xelatex",
        "font.family": "serif",
        "font.serif": "Times New Roman",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
)

from collections import defaultdict
import pickle
import datalad.api
from os.path import join as opj
from os.path import basename, exists
import multiprocessing as mp
from pyannote.core import Annotation, Segment, Timeline
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--corpora", default=["input/bergelson"], nargs="+")
parser.add_argument("--duration", type=int, help="duration in hours", default=8)
parser.add_argument("--calibration", default="input/annotators.csv")
parser.add_argument("--run")
parser.add_argument("--algo1", default="vtc", help="algorithm1 annotation set")
parser.add_argument("--algo2", default="its", help="algorithm2 annotation set")
parser.add_argument("--chains", type=int, default=1)
parser.add_argument("--warmup", type=int, default=250)
parser.add_argument("--samples", type=int, default=1000)
parser.add_argument("--threads-per-chain", type=int, default=4)
parser.add_argument("--models", default=["causal"], nargs="+")
parser.add_argument("--confusion-clips", choices=["recording", "clip"], default="clip")
parser.add_argument("--confusion-groups", choices=["child", "recording"], default="recording")
parser.add_argument("--time-scale", type=int, default=0)
parser.add_argument("--require-clip-age", action="store_true", default=True)
parser.add_argument("--adapt-delta", type=float, default=0.9)
parser.add_argument("--max-treedepth", type=int, default=14)
args = parser.parse_args()

speakers = ["CHI", "OCH", "FEM", "MAL"]
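
# Helper meant to be used as a pyannote Timeline method (note the `self` argument):
# removes the `removed` regions from a timeline by cropping it to their gaps.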
def extrude(self, removed, mode: str = "intersection"):
    if isinstance(removed, Segment):
        removed = Timeline([removed])
    truncating_support = removed.gaps(support=self.extent())
    # loose for truncate means strict for crop and vice-versa
    if mode == "loose":
        mode = "strict"
    elif mode == "strict":
        mode = "loose"
    return self.crop(truncating_support, mode=mode)

def children_siblings(corpus):
    siblings = pd.read_csv("input/siblings.csv")
    siblings["child_id"] = siblings["child_id"].astype(str)
    siblings = siblings[siblings["corpus"] == basename(corpus)].set_index("child_id")
    siblings = siblings["n_siblings"].to_dict()
    n = defaultdict(lambda: -1, **siblings)
    return n
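
# Per-clip confusion counts: for every clip annotated by a human annotator and by
# both algorithm sets, count the vocalizations attributed to each speaker type.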
def compute_counts(parameters):
    corpus = parameters["corpus"]
    annotator = parameters["annotator"]
    rural = ((parameters["type"] == "rural") * 1) if "type" in parameters else 0
    project = ChildProject(parameters["path"])
    am = AnnotationManager(project)
    am.read()
    project.recordings["age"] = project.compute_ages()
    intersection = AnnotationManager.intersection(
        am.annotations, [args.algo1, args.algo2, annotator]
    )
    if len(intersection) == 0:
        print(f"No intersection between '{args.algo1}', '{args.algo2}' and '{annotator}' in {corpus}!")
        return pd.DataFrame()
    intersection["path"] = intersection.apply(
        lambda r: opj(
            project.path, "annotations", r["set"], "converted", r["annotation_filename"]
        ),
        axis=1,
    )
    missing_annotations = intersection[~intersection["path"].map(exists)]
    if len(missing_annotations):
        datalad.api.get(list(missing_annotations["path"].unique()))
    intersection = intersection.merge(
        project.recordings[["recording_filename", "child_id", "age"]], how="left"
    )
    intersection["child"] = corpus + "_" + intersection["child_id"].astype(str)
    intersection["recording"] = corpus + "_" + intersection["recording_filename"].astype(str)
    intersection["clip_id"] = (
        corpus
        + "_"
        + intersection["recording_filename"].str.cat(intersection["range_onset"].astype(str))
    )
    if args.time_scale:
        # optionally re-slice the annotated ranges into windows of --time-scale seconds
        intersection["onset"] = intersection.apply(
            lambda r: np.arange(
                r["range_onset"], r["range_offset"], args.time_scale * 1000
            ),
            axis=1,
        )
        intersection = intersection.explode("onset")
        intersection["range_onset"] = intersection["onset"]
        intersection["range_offset"] = (
            intersection["range_onset"] + args.time_scale * 1000
        ).clip(upper=intersection["range_offset"])
    intersection["duration"] = (
        intersection["range_offset"] - intersection["range_onset"]
    )
    # the intersection holds one row per set (algo1, algo2, annotator), hence the
    # division by 3 when reporting the annotated duration in hours
    print(corpus, annotator, (intersection["duration"] / 1000 / 3).sum() / 3600)
    data = []
    # for child, ann in intersection.groupby("child"):
    groupby = {
        "recording": ["child_id", "recording"],
        "clip": ["child_id", "recording", "range_onset"],
    }[args.confusion_clips]
    for clip, ann in intersection.groupby(groupby):
        child = clip[0]
        recording = clip[1]
        # print(corpus, child)
        segments = am.get_collapsed_segments(ann)
        if "speaker_type" not in segments.columns:
            continue
        segments = segments[segments["speaker_type"].isin(speakers)]
        algo1 = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == args.algo1) & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            )
            # .support(collar=200)
            .get_timeline()
            for speaker in speakers
        }
        algo2 = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == args.algo2) & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            )
            # .support(collar=200)
            .get_timeline()
            for speaker in speakers
        }
        truth = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == annotator)
                    & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            )
            # .support(collar=200)
            .get_timeline()
            for speaker in speakers
        }
        d = {}
        for i, speaker_A in enumerate(speakers):
            d[f"truth_total_{i}"] = len(truth[speaker_A])
            d[f"algo1_total_{i}"] = len(algo1[speaker_A])
            d[f"algo2_total_{i}"] = len(algo2[speaker_A])
        d["child"] = child
        d["recording"] = recording
        d["duration"] = ann["duration"].sum() / 3 / 1000 / 3600
        d["clip_id"] = ann["clip_id"].iloc[0]
        d["clip_age"] = ann["age"].iloc[0]
        data.append(d)
    data = pd.DataFrame(data).assign(
        corpus=corpus,
        clip_rural=rural,
    )
    if args.require_clip_age:
        data.dropna(subset=["clip_age"], inplace=True)
    print(data)
    return data
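
# Per-recording speech-rate metrics derived from a human annotation set, restricted
# to the 10:00-18:00 window, via the ChildProject AclewMetrics pipeline.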
def rates(parameters):
    corpus = parameters["corpus"]
    annotator = parameters["annotator"]
    speakers = ["CHI", "OCH", "FEM", "MAL"]
    project = ChildProject(parameters["path"])
    am = AnnotationManager(project)
    am.read()
    pipeline = AclewMetrics(
        project,
        vtc=annotator,
        alice=None,
        vcm=None,
        from_time="10:00:00",
        to_time="18:00:00",
        by="recording_filename",
        threads=args.chains,
    )
    metrics = pipeline.extract()
    metrics = pd.DataFrame(metrics).assign(corpus=corpus, annotator=annotator)
    project.recordings["age"] = project.compute_ages()
    project.recordings["siblings"] = project.recordings.child_id.astype(str).map(
        children_siblings(corpus)
    )
    metrics = metrics.merge(
        project.recordings[["recording_filename", "age", "siblings"]]
    )
    metrics["duration"] = metrics[f"duration_{annotator}"] / 1000 / 3600
    metrics = metrics[metrics["duration"] > 0.01]
    metrics["child"] = corpus + "_" + metrics["child_id"].astype(str)
    # metrics.dropna(subset={f"voc_{speaker.lower()}_ph" for speaker in speakers}&set(metrics.columns), inplace=True)
    for i, speaker in enumerate(speakers):
        # if f"voc_{speaker.lower()}_ph" not in metrics.columns:
        #     metrics[f"speech_rate_{i}"] = pd.NA
        # else:
        # convert hourly vocalization rates back into counts over the annotated duration
        metrics[f"speech_rate_{i}"] = (
            (metrics[f"voc_{speaker.lower()}_ph"] * (metrics["duration"]))
            .fillna(0)
            .astype(int)
        )
    return metrics
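
# Compile and sample a Stan model, saving the posterior draws (.npz) and the input
# data dictionary (.pickle) under output/.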
def run_model(data, run, model_name):
    model = CmdStanModel(
        stan_file=f"code/models/{model_name}.stan",
        cpp_options={"STAN_THREADS": "TRUE"},
    )
    fit = model.sample(
        data=data,
        chains=args.chains,
        threads_per_chain=args.threads_per_chain,
        iter_sampling=args.samples,
        iter_warmup=args.warmup,
        step_size=0.1,
        max_treedepth=args.max_treedepth,
        adapt_delta=args.adapt_delta,
        # save_profile=True,
        show_console=True,
    )
    vars = fit.stan_variables()
    samples = {}
    for k, v in vars.items():
        samples[k] = v
    np.savez_compressed(f"output/aggregates_{run}_{model_name}.npz", **samples)
    samples = np.load(f"output/aggregates_{run}_{model_name}.npz")
    with open(f"output/aggregates_{run}_{model_name}.pickle", "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(fit.diagnose())
    return {"samples": samples}
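
# Per-recording vocalization counts from the two automated annotation sets only,
# for recordings fully covered by both sets over --duration hours starting at 10:00.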
def compile_recordings(corpus):
    project = ChildProject(corpus)
    am = AnnotationManager(project)
    am.read()
    project.recordings["age"] = project.compute_ages()
    project.recordings["siblings"] = project.recordings.child_id.astype(str).map(
        children_siblings(corpus)
    )
    annotations = am.annotations[am.annotations["set"].isin([args.algo1, args.algo2])]
    annotations = annotations.merge(
        project.recordings,
        left_on="recording_filename",
        right_on="recording_filename",
        how="inner",
    )
    annotations = annotations.merge(
        project.children,
        left_on="child_id",
        right_on="child_id",
        how="inner",
    )
    if "normative" in annotations.columns:
        print("filtering out non-normative kids")
        annotations = annotations[annotations["normative"] != "N"]
    recs = []
    for recording_filename, _annotations in annotations.groupby("recording_filename"):
        _annotations = am.get_within_time_range(
            _annotations,
            TimeInterval(
                datetime.datetime(1900, 1, 1, 10, 0),
                datetime.datetime(1900, 1, 1, 10 + args.duration, 0),
            ),
        )
        if len(_annotations) == 0:
            continue
        child_id = _annotations["child_id"].max()
        age = _annotations["age"].max()
        siblings = _annotations["siblings"].max()
        _annotations["duration"] = _annotations["range_offset"] - _annotations["range_onset"]
        durations = _annotations.groupby("set")["duration"].sum()
        # skip recordings that are not fully covered by every annotation set
        if (durations < args.duration * 3600 * 1000).any():
            continue
        duration = args.duration * 3600 * 1000
        _annotations["path"] = _annotations.apply(
            lambda r: opj(
                project.path,
                "annotations",
                r["set"],
                "converted",
                r["annotation_filename"],
            ),
            axis=1,
        )
        missing_annotations = _annotations[~_annotations["path"].map(exists)]
        if len(missing_annotations):
            datalad.api.get(list(missing_annotations["path"].unique()))
        segments = am.get_segments(_annotations)
        segments["segment_onset"] -= segments["segment_onset"].min()
        segments = segments[segments["segment_onset"] >= 0]
        segments = segments[segments["segment_onset"] < duration]
        if len(segments) == 0:
            continue
        segments = segments[segments["speaker_type"].isin(["CHI", "OCH", "FEM", "MAL"])]
        rec = {}
        for i, speaker in enumerate(speakers):
            rec[f"algo1_{i}"] = len(
                segments[(segments["speaker_type"] == speaker) & (segments["set"] == args.algo1)]
            )
            rec[f"algo2_{i}"] = len(
                segments[(segments["speaker_type"] == speaker) & (segments["set"] == args.algo2)]
            )
        rec["recording"] = recording_filename
        rec["children"] = f"{corpus}_{child_id}"
        rec["corpus"] = basename(corpus)
        rec["age"] = age
        rec["siblings"] = siblings
        recs.append(rec)
    recs = pd.DataFrame(recs)
    return recs
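
# Assemble the Stan input dictionary (confusion counts, speech rates, and
# recording-level counts) and run every requested model.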
if __name__ == "__main__":
    recs = pd.concat([compile_recordings(corpus) for corpus in args.corpora])
    recs["children"] = recs["children"].astype("category").cat.codes.astype(int) + 1
    annotators = pd.read_csv(args.calibration)
    annotators["path"] = annotators["corpus"].apply(lambda c: opj("input", c))
    with mp.Pool(processes=args.chains * args.threads_per_chain) as pool:
        data = pd.concat(pool.map(compute_counts, annotators.to_dict(orient="records")))
    corpora = sorted(list(set(data["corpus"].unique()) | set(recs["corpus"].unique())))
    corpora_map = {corpus: i for i, corpus in enumerate(corpora)}
    print(corpora_map)
    duration = data["duration"].sum()
    clip_duration = data["duration"].values
    clip_age = data["clip_age"].values
    clip_rural = data["clip_rural"].values
    truth_total = np.transpose([data[f"truth_total_{i}"].values for i in range(4)])
    algo1_total = np.transpose([data[f"algo1_total_{i}"].values for i in range(4)])
    algo2_total = np.transpose([data[f"algo2_total_{i}"].values for i in range(4)])
    # speech rates at the child level
    annotators = annotators[~annotators["annotator"].str.startswith("eaf_2021")]
    speech_rates = pd.concat([
        rates(annotator) for annotator in annotators.to_dict(orient="records")
    ])
    speech_rates.reset_index(inplace=True)
    speech_rate_matrix = np.transpose(
        [speech_rates[f"speech_rate_{i}"].values for i in range(4)]
    )
    speech_rate_age = speech_rates["age"].values
    speech_rate_siblings = speech_rates["siblings"].values.astype(int)
    speech_rates.to_csv("rates.csv")
    data["corpus"] = data["corpus"].astype("category")
    confusion_data = {
        "n_clips": truth_total.shape[0],
        "n_classes": truth_total.shape[1],
        "n_groups": data[args.confusion_groups].nunique(),
        "n_corpora": len(corpora),
        "n_validation": 0,
        "group": 1 + data[args.confusion_groups].astype("category").cat.codes.values,
        "conf_corpus": 1 + data["corpus"].map(corpora_map).astype(int).values,
        "truth_total": truth_total.astype(int),
        "algo1_total": algo1_total.astype(int),
        "algo2_total": algo2_total.astype(int),
        "clip_duration": clip_duration,
        "clip_age": clip_age,
        "clip_rural": clip_rural,
        "clip_id": 1 + data["clip_id"].astype("category").cat.codes.values,
        "n_unique_clips": data["clip_id"].nunique(),
        "speech_rates": speech_rate_matrix.astype(int),
        "speech_rate_age": speech_rate_age,
        "speech_rate_siblings": speech_rate_siblings,
        "group_corpus": (
            1 + speech_rates["corpus"].map(corpora_map).astype(int).values
        ),
        "speech_rate_child": 1 + speech_rates["child"].astype("category").cat.codes.values,
        "n_speech_rate_children": speech_rates["child"].nunique(),
        "durations": speech_rates["duration"].values,
        "n_rates": len(speech_rates),
    }
    n_recs = len(recs)
    children_corpus = (
        recs.groupby("children").agg(corpus=("corpus", "first")).sort_index()
    )
    children_corpus = 1 + children_corpus.corpus.map(corpora_map).astype(int).values
    analysis_data = {
        "n_recs": n_recs,
        "n_children": len(recs["children"].unique()),
        "children": recs["children"],
        "vocs_algo1": np.transpose([recs[f"algo1_{i}"].values for i in range(4)]),
        "vocs_algo2": np.transpose([recs[f"algo2_{i}"].values for i in range(4)]),
        "age": recs["age"],
        "siblings": recs["siblings"].astype(int),
        "corpus": children_corpus,
        "recs_duration": args.duration,
    }
    data = {**analysis_data, **confusion_data}
    data["threads"] = args.threads_per_chain
    print(data)
    output = {}
    for model_name in args.models:
        output[model_name] = run_model(data, args.run, model_name)