123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508 |
- #!/usr/bin/env python
- import pandas as pd
- import numpy as np
- from ChildProject.projects import ChildProject
- from ChildProject.annotations import AnnotationManager
- from ChildProject.metrics import segments_to_grid, segments_to_annotation
- from ChildProject.pipelines.metrics import AclewMetrics
- from ChildProject.utils import TimeInterval
- from cmdstanpy import CmdStanModel
- import datetime
- import matplotlib
- from matplotlib import pyplot as plt
# Render all figures through the PGF backend so they are typeset by
# XeLaTeX with a Times New Roman serif font (LaTeX handles the text).
matplotlib.use("pgf")
_PGF_RC = {
    "pgf.texsystem": "xelatex",
    "font.family": "serif",
    "font.serif": "Times New Roman",
    "text.usetex": True,
    "pgf.rcfonts": False,
}
matplotlib.rcParams.update(_PGF_RC)
- from collections import defaultdict
- import pickle
- import datalad.api
- from os.path import join as opj
- from os.path import basename, exists
- import multiprocessing as mp
- from pyannote.core import Annotation, Segment, Timeline
- import argparse
# ---------------------------------------------------------------------------
# Command-line interface.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--corpora", default=["input/bergelson"], nargs="+")
parser.add_argument("--duration", type=int, help="duration in hours", default=8)
parser.add_argument("--calibration", default="input/annotators.csv")
parser.add_argument("--run")  # label embedded in the output file names
parser.add_argument("--algo1", default="vtc", help="algorithm1 annotation set")
parser.add_argument("--algo2", default="its", help="algorithm2 annotation set")
parser.add_argument("--chains", type=int, default=1)
parser.add_argument("--warmup", type=int, default=250)
parser.add_argument("--samples", type=int, default=1000)
parser.add_argument("--threads-per-chain", type=int, default=4)
parser.add_argument("--models", default=["causal"], nargs="+")
parser.add_argument("--confusion-clips", choices=["recording", "clip"], default="clip")
parser.add_argument("--confusion-groups", choices=["child", "recording"], default="recording")
parser.add_argument("--time-scale", type=int, default=0)
# BUG FIX: `action="store_true"` combined with `default=True` made this flag
# a no-op (it could never be disabled).  Keep the original flag for backward
# compatibility and add an explicit off-switch writing to the same dest.
parser.add_argument("--require-clip-age", action="store_true", default=True)
parser.add_argument("--no-require-clip-age", dest="require_clip_age", action="store_false")
parser.add_argument("--adapt-delta", type=float, default=0.9)
parser.add_argument("--max-treedepth", type=int, default=14)
args = parser.parse_args()

# Speaker classes used throughout the analysis (key child, other children,
# female adults, male adults).
speakers = ["CHI", "OCH", "FEM", "MAL"]
def extrude(self, removed, mode: str = "intersection"):
    """Drop the `removed` regions from this timeline.

    Implemented by cropping the timeline to the gaps of `removed`
    within its own extent.  Because cropping is the complement of
    truncation, the "loose"/"strict" modes must be swapped before
    delegating to `crop`.
    """
    if isinstance(removed, Segment):
        removed = Timeline([removed])
    keep = removed.gaps(support=self.extent())
    # loose for truncate means strict for crop and vice-versa
    crop_mode = {"loose": "strict", "strict": "loose"}.get(mode, mode)
    return self.crop(keep, mode=crop_mode)
def children_siblings(corpus):
    """Return a child_id -> sibling-count mapping for one corpus.

    Reads input/siblings.csv, keeps the rows matching the corpus'
    basename, and returns a defaultdict yielding -1 for unknown children.
    """
    table = pd.read_csv("input/siblings.csv")
    table["child_id"] = table["child_id"].astype(str)
    corpus_rows = table[table["corpus"] == basename(corpus)].set_index("child_id")
    counts = corpus_rows["n_siblings"].to_dict()
    return defaultdict(lambda: -1, **counts)
def compute_counts(parameters):
    """Count vocalizations per clip for both algorithms and a human annotator.

    `parameters` is one row of the calibration CSV (keys: corpus, annotator,
    path, optionally type).  Only time ranges annotated by BOTH algorithm
    sets and the human annotator are kept.  Returns one row per clip (or per
    recording, depending on --confusion-clips) with truth_/algo1_/algo2_
    counts per speaker class, or an empty DataFrame when no overlap exists.
    """
    corpus = parameters["corpus"]
    annotator = parameters["annotator"]
    # clip_rural covariate: 1 for corpora typed "rural", else 0
    rural = ((parameters["type"] == "rural")*1) if "type" in parameters else 0
    project = ChildProject(parameters["path"])
    am = AnnotationManager(project)
    am.read()
    project.recordings["age"] = project.compute_ages()
    # time ranges covered by all three annotation sets simultaneously
    intersection = AnnotationManager.intersection(am.annotations, [args.algo1, args.algo2, annotator])
    if len(intersection) == 0:
        print(f"No intersection between '{args.algo1}', '{args.algo2}' and '{annotator}' in {corpus}!")
        return pd.DataFrame()
    intersection["path"] = intersection.apply(
        lambda r: opj(
            project.path, "annotations", r["set"], "converted", r["annotation_filename"]
        ),
        axis=1,
    )
    # fetch annotation files that are not present locally (datalad-managed)
    missing_annotations = intersection[~intersection["path"].map(exists)]
    if len(missing_annotations):
        datalad.api.get(list(missing_annotations["path"].unique()))
    intersection = intersection.merge(
        project.recordings[["recording_filename", "child_id", "age"]], how="left"
    )
    intersection["child"] = corpus + "_" + intersection["child_id"].astype(str)
    intersection["recording"] = corpus + "_" + intersection["recording_filename"].astype(str)
    # one clip per (recording, onset) pair
    intersection["clip_id"] = (
        corpus
        + "_"
        + intersection["recording_filename"].str.cat(intersection["range_onset"].astype(str))
    )
    if args.time_scale:
        # re-slice each annotated range into fixed windows of
        # time_scale seconds (range columns are in milliseconds)
        intersection["onset"] = intersection.apply(
            lambda r: np.arange(
                r["range_onset"], r["range_offset"], args.time_scale * 1000
            ),
            axis=1,
        )
        intersection = intersection.explode("onset")
        intersection["range_onset"] = intersection["onset"]
        intersection["range_offset"] = (
            intersection["range_onset"] + args.time_scale * 1000
        ).clip(upper=intersection["range_offset"])
    intersection["duration"] = (
        intersection["range_offset"] - intersection["range_onset"]
    )
    # total annotated hours; /3 presumably because each range appears once
    # per set (algo1, algo2, annotator) — TODO confirm
    print(corpus, annotator, (intersection["duration"] / 1000 / 3).sum() / 3600)
    data = []
    # for child, ann in intersection.groupby("child"):
    # grouping granularity: whole recordings or individual clips
    groupby = {
        "recording": ["child_id", "recording"],
        "clip": ["child_id", "recording", "range_onset"],
    }[args.confusion_clips]
    for clip, ann in intersection.groupby(groupby):
        # NOTE(review): clip[0] is the bare child_id, not the
        # corpus-prefixed "child" column — ids could collide across
        # corpora if --confusion-groups=child; verify intent.
        child = clip[0]
        recording = clip[1]
        # print(corpus, child)
        segments = am.get_collapsed_segments(ann)
        if "speaker_type" not in segments.columns:
            continue
        segments = segments[segments["speaker_type"].isin(speakers)]
        # one pyannote Timeline per speaker class, per source set
        algo1 = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == args.algo1) & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            )
            # .support(collar=200)
            .get_timeline()
            for speaker in speakers
        }
        algo2 = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == args.algo2) & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            )
            # .support(collar=200)
            .get_timeline()
            for speaker in speakers
        }
        truth = {
            speaker: segments_to_annotation(
                segments[
                    (segments["set"] == annotator)
                    & (segments["speaker_type"] == speaker)
                ],
                "speaker_type",
            )
            # .support(collar=200)
            .get_timeline()
            for speaker in speakers
        }
        # vocalization counts = number of timeline segments per class
        d = {}
        for i, speaker_A in enumerate(speakers):
            d[f"truth_total_{i}"] = len(truth[speaker_A])
            d[f"algo1_total_{i}"] = len(algo1[speaker_A])
            d[f"algo2_total_{i}"] = len(algo2[speaker_A])
        d["child"] = child
        d["recording"] = recording
        # duration in hours (/3: summed over the three sets — TODO confirm)
        d["duration"] = ann["duration"].sum() / 3 / 1000 / 3600
        d["clip_id"] = ann["clip_id"].iloc[0]
        d["clip_age"] = ann["age"].iloc[0]
        data.append(d)
    data = pd.DataFrame(data).assign(
        corpus=corpus,
        clip_rural=rural
    )
    if args.require_clip_age:
        # drop clips whose recording has no age information
        data.dropna(subset=["clip_age"], inplace=True)
    print(data)

    return data
def rates(parameters):
    """Per-recording vocalization counts for one corpus/annotator pair.

    Runs the ACLEW metrics pipeline on the 10:00-18:00 window, merges in
    recording age and sibling counts, and converts the per-hour
    vocalization rates back into integer counts per recording.

    `parameters` is one row of the calibration CSV (corpus, annotator, path).
    Returns a DataFrame with one row per recording.
    """
    corpus = parameters["corpus"]
    annotator = parameters["annotator"]
    project = ChildProject(parameters["path"])
    am = AnnotationManager(project)
    am.read()
    pipeline = AclewMetrics(
        project,
        vtc=annotator,
        alice=None,
        vcm=None,
        from_time="10:00:00",
        to_time="18:00:00",
        by="recording_filename",
        threads=args.chains
    )
    metrics = pipeline.extract()
    metrics = pd.DataFrame(metrics).assign(corpus=corpus, annotator=annotator)
    project.recordings["age"] = project.compute_ages()
    project.recordings["siblings"] = project.recordings.child_id.astype(str).map(
        children_siblings(corpus)
    )
    metrics = metrics.merge(
        project.recordings[["recording_filename", "age", "siblings"]]
    )
    # annotated duration in hours; drop recordings with essentially none
    metrics["duration"] = metrics[f"duration_{annotator}"] / 1000 / 3600
    metrics = metrics[metrics["duration"] > 0.01]
    metrics["child"] = corpus + "_" + metrics["child_id"].astype(str)
    # `speakers` is the module-level constant; the original redundantly
    # re-defined the same list twice inside this function.
    for i, speaker in enumerate(speakers):
        # voc_<speaker>_ph is vocalizations/hour; multiply by duration to
        # recover an integer count (missing rates count as 0)
        metrics[f"speech_rate_{i}"] = (
            (metrics[f"voc_{speaker.lower()}_ph"] * (metrics["duration"]))
            .fillna(0)
            .astype(int)
        )
    return metrics
def run_model(data, run, model_name):
    """Compile and sample code/models/<model_name>.stan, persisting outputs.

    Saves posterior draws to output/aggregates_<run>_<model_name>.npz and
    pickles the *input* `data` dict next to it for reproducibility.
    Returns {"samples": <lazily-loaded npz archive>}.
    """
    model = CmdStanModel(
        stan_file=f"code/models/{model_name}.stan",
        cpp_options={"STAN_THREADS": "TRUE"},
    )
    fit = model.sample(
        data=data,
        chains=args.chains,
        threads_per_chain=args.threads_per_chain,
        iter_sampling=args.samples,
        iter_warmup=args.warmup,
        step_size=0.1,
        max_treedepth=args.max_treedepth,
        adapt_delta=args.adapt_delta,
        # save_profile=True,
        show_console=True,
    )
    # `vars` shadowed the builtin and was copied element-by-element into a
    # fresh dict; pass the variables dict straight to savez_compressed.
    stan_variables = fit.stan_variables()
    np.savez_compressed(f"output/aggregates_{run}_{model_name}.npz", **stan_variables)
    # reload lazily so the returned arrays are backed by the on-disk archive
    samples = np.load(f"output/aggregates_{run}_{model_name}.npz")
    # snapshot the model input alongside the draws
    with open(f"output/aggregates_{run}_{model_name}.pickle", "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(fit.diagnose())
    return {"samples": samples}
def compile_recordings(corpus):
    """Per-recording vocalization counts for both algorithm sets in `corpus`.

    Keeps only recordings whose annotations fully cover the daily window
    10:00 -> 10:00 + args.duration hours for every annotation set, and
    (when available) only normative children.  Returns one row per
    retained recording with algo1_i / algo2_i counts per speaker class.
    """
    project = ChildProject(corpus)
    am = AnnotationManager(project)
    am.read()
    project.recordings["age"] = project.compute_ages()
    project.recordings["siblings"] = project.recordings.child_id.astype(str).map(
        children_siblings(corpus)
    )
    annotations = am.annotations[am.annotations["set"].isin([args.algo1, args.algo2])]
    annotations = annotations.merge(
        project.recordings,
        left_on="recording_filename",
        right_on="recording_filename",
        how="inner",
    )
    annotations = annotations.merge(
        project.children,
        left_on="child_id",
        right_on="child_id",
        how="inner"
    )
    if "normative" in annotations.columns:
        print("filtering out non-normative kids")
        annotations = annotations[annotations["normative"] != "N"]
    recs = []
    for recording_filename, _annotations in annotations.groupby("recording_filename"):
        # clamp to the daily window (dummy 1900-01-01 date, times only)
        _annotations = am.get_within_time_range(
            _annotations,
            TimeInterval(
                datetime.datetime(1900, 1, 1, 10, 0),
                datetime.datetime(1900, 1, 1, 10 + args.duration, 0),
            ),
        )
        if len(_annotations) == 0:
            continue

        # all rows of one recording share one child; max() collapses them
        child_id = _annotations["child_id"].max()
        age = _annotations["age"].max()
        siblings = _annotations["siblings"].max()
        _annotations["duration"] = _annotations["range_offset"]-_annotations["range_onset"]
        # require FULL window coverage (in ms) by every annotation set
        durations = _annotations.groupby("set")["duration"].sum()
        if (durations < args.duration * 3600 * 1000).any():
            continue
        duration = args.duration * 3600 * 1000
        _annotations["path"] = _annotations.apply(
            lambda r: opj(
                project.path,
                "annotations",
                r["set"],
                "converted",
                r["annotation_filename"],
            ),
            axis=1,
        )
        # fetch annotation files missing locally (datalad-managed repo)
        missing_annotations = _annotations[~_annotations["path"].map(exists)]
        if len(missing_annotations):
            datalad.api.get(list(missing_annotations["path"].unique()))
        segments = am.get_segments(_annotations)
        # re-base onsets to t=0, then keep only in-window segments
        segments["segment_onset"] -= segments["segment_onset"].min()
        segments = segments[segments["segment_onset"] >= 0]
        segments = segments[segments["segment_onset"] < duration]
        if len(segments) == 0:
            continue
        segments = segments[segments["speaker_type"].isin(["CHI", "OCH", "FEM", "MAL"])]
        # count segments per speaker class for each algorithm set
        rec = {}
        for i, speaker in enumerate(speakers):
            rec[f"algo1_{i}"] = len(segments[(segments["speaker_type"] == speaker)&(segments["set"]==args.algo1)])
            rec[f"algo2_{i}"] = len(segments[(segments["speaker_type"] == speaker)&(segments["set"]==args.algo2)])

        rec["recording"] = recording_filename
        rec["children"] = f"{corpus}_{child_id}"
        rec["corpus"] = basename(corpus)
        rec["age"] = age
        rec["siblings"] = siblings
        recs.append(rec)
    recs = pd.DataFrame(recs)
    return recs
if __name__ == "__main__":
    # 1) Full-day per-recording counts for both algorithm sets.
    recs = pd.concat([compile_recordings(corpus) for corpus in args.corpora])
    # contiguous 1-based child indices for Stan
    recs["children"] = recs["children"].astype("category").cat.codes.astype(int) + 1
    # 2) Calibration clips: algorithms vs. human annotators, in parallel.
    annotators = pd.read_csv(args.calibration)
    annotators["path"] = annotators["corpus"].apply(lambda c: opj("input", c))
    with mp.Pool(processes=args.chains * args.threads_per_chain) as pool:
        data = pd.concat(pool.map(compute_counts, annotators.to_dict(orient="records")))
    # corpus name -> 0-based index (converted to 1-based for Stan below)
    corpora = sorted(list(set(data["corpus"].unique()) | set(recs["corpus"].unique())))
    corpora_map = {
        corpus: i
        for i, corpus in enumerate(corpora)
    }
    print(corpora_map)
    duration = data["duration"].sum()  # NOTE(review): computed but unused below
    clip_duration = data["duration"].values
    clip_age = data["clip_age"].values
    clip_rural = data["clip_rural"].values
    # (n_clips x 4) count matrices, one column per speaker class
    truth_total = np.transpose([data[f"truth_total_{i}"].values for i in range(4)])
    algo1_total = np.transpose([data[f"algo1_total_{i}"].values for i in range(4)])
    algo2_total = np.transpose([data[f"algo2_total_{i}"].values for i in range(4)])
    # speech rates at the child level
    annotators = annotators[~annotators["annotator"].str.startswith("eaf_2021")]
    speech_rates = pd.concat([
        rates(annotator) for annotator in annotators.to_dict(orient="records")
    ])
    speech_rates.reset_index(inplace=True)
    speech_rate_matrix = np.transpose(
        [speech_rates[f"speech_rate_{i}"].values for i in range(4)]
    )
    speech_rate_age = speech_rates["age"].values
    speech_rate_siblings = speech_rates["siblings"].values.astype(int)
    speech_rates.to_csv("rates.csv")
    data["corpus"] = data["corpus"].astype("category")

    # Stan inputs for the confusion-model component.
    confusion_data = {
        "n_clips": truth_total.shape[0],
        "n_classes": truth_total.shape[1],
        "n_groups": data[args.confusion_groups].nunique(),
        "n_corpora": len(corpora),
        "n_validation": 0,
        "group": 1 + data[args.confusion_groups].astype("category").cat.codes.values,
        "conf_corpus": 1 + data["corpus"].map(corpora_map).astype(int).values,
        "truth_total": truth_total.astype(int),
        "algo1_total": algo1_total.astype(int),
        "algo2_total": algo2_total.astype(int),
        "clip_duration": clip_duration,
        "clip_age": clip_age,
        "clip_rural": clip_rural,
        "clip_id": 1 + data["clip_id"].astype("category").cat.codes.values,
        "n_unique_clips": data["clip_id"].nunique(),
        "speech_rates": speech_rate_matrix.astype(int),
        "speech_rate_age": speech_rate_age,
        "speech_rate_siblings": speech_rate_siblings,
        "group_corpus": (
            1 + speech_rates["corpus"].map(corpora_map).astype(int).values
        ),
        "speech_rate_child": 1+speech_rates["child"].astype("category").cat.codes.values,
        "n_speech_rate_children": speech_rates["child"].nunique(),
        "durations": speech_rates["duration"].values,
        "n_rates": len(speech_rates),
    }
    n_recs = len(recs)
    # 1-based corpus index of each child, ordered by child index
    children_corpus = (
        recs.groupby("children").agg(corpus=("corpus", "first")).sort_index()
    )
    children_corpus = 1 + children_corpus.corpus.map(corpora_map).astype(int).values
    # Stan inputs for the main analysis component.
    analysis_data = {
        "n_recs": n_recs,
        "n_children": len(recs["children"].unique()),
        "children": recs["children"],
        "vocs_algo1": np.transpose([recs[f"algo1_{i}"].values for i in range(4)]),
        "vocs_algo2": np.transpose([recs[f"algo2_{i}"].values for i in range(4)]),
        "age": recs["age"],
        "siblings": recs["siblings"].astype(int),
        "corpus": children_corpus,
        "recs_duration": args.duration,
    }
    data = {**analysis_data, **confusion_data}
    data["threads"] = args.threads_per_chain
    print(data)
    # fit every requested Stan model on the combined data
    output = {}
    for model_name in args.models:
        output[model_name] = run_model(data, args.run, model_name)
|