metrics.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. #!/usr/bin/env python3
  2. # Original file: https://gin.g-node.org/EL1000/metrics
  3. import argparse
  4. import datetime
  5. from functools import reduce
  6. import os
  7. import warnings
  8. import datalad.api
  9. import pandas as pd
  10. from ChildProject.projects import ChildProject
  11. from ChildProject.annotations import AnnotationManager
  12. from ChildProject.pipelines.metrics import LenaMetrics, AclewMetrics
  13. def date_is_valid(date, fmt):
  14. try:
  15. datetime.datetime.strptime(date, fmt)
  16. except:
  17. return False
  18. return True
  19. def compute_metrics(args):
  20. if len(args.experiments):
  21. experiments = args.experiments
  22. else:
  23. datasets = datalad.api.subdatasets(path='DATASETS/')
  24. experiments = [os.path.basename(dataset["path"]) for dataset in datasets]
  25. print(
  26. "pipeline '\033[1m{}\033[0m' will run on experiments '\033[1m{}\033[0m'".format(
  27. args.pipeline, ",".join(experiments)
  28. )
  29. )
  30. data = []
  31. columns = []
  32. for experiment in experiments:
  33. project = ChildProject(os.path.join("DATASETS", experiment), enforce_dtypes=True)
  34. am = AnnotationManager(project)
  35. if args.pipeline == "aclew":
  36. if "vtc" not in am.annotations["set"].tolist():
  37. print(f"skipping {experiment} (no VTC annotation)")
  38. continue
  39. metrics = AclewMetrics(
  40. project,
  41. vtc="vtc",
  42. alice="alice",
  43. vcm="vcm",
  44. by="session_id",
  45. threads=args.threads,
  46. ).extract()
  47. elif args.pipeline == "lena":
  48. metrics = LenaMetrics(
  49. project, set="its", types=["OLN"], by="session_id", threads=args.threads
  50. ).extract()
  51. elif args.pipeline == "children":
  52. data.append(project.children.assign(experiment=experiment))
  53. columns.append(project.children.columns)
  54. continue
  55. else:
  56. raise ValueError("undefined pipeline '{}'".format(args.pipeline))
  57. metrics = metrics.assign(experiment=experiment)
  58. if not len(metrics):
  59. print(
  60. "warning: experiment '{}' did not return any metrics for pipeline '{}'".format(
  61. experiment, args.pipeline
  62. )
  63. )
  64. continue
  65. # compute ages
  66. metrics = metrics.merge(
  67. project.recordings[["session_id", "date_iso"]].drop_duplicates(
  68. "session_id", keep="first"
  69. ),
  70. how="left",
  71. left_on="session_id",
  72. right_on="session_id",
  73. )
  74. metrics = metrics.merge(
  75. project.children[["child_id", "child_dob"]],
  76. how="left",
  77. left_on="child_id",
  78. right_on="child_id",
  79. )
  80. metrics["age"] = (
  81. metrics[["date_iso", "child_dob"]]
  82. .apply(
  83. lambda r: (
  84. datetime.datetime.strptime(r["date_iso"], "%Y-%m-%d")
  85. - datetime.datetime.strptime(r["child_dob"], "%Y-%m-%d")
  86. )
  87. if (
  88. date_is_valid(r["child_dob"], "%Y-%m-%d")
  89. and date_is_valid(r["date_iso"], "%Y-%m-%d")
  90. )
  91. else None,
  92. axis=1,
  93. )
  94. .apply(lambda dt: dt.days / (365.25 / 12) if dt else None)
  95. .apply(lambda a: int(a) if not pd.isnull(a) else "NA")
  96. )
  97. recordings = project.recordings
  98. if "session_offset" not in recordings.columns:
  99. recordings = recordings.assign(session_offset=0)
  100. # compute missing audio
  101. metrics = metrics.merge(
  102. recordings[["session_id", "session_offset", "duration"]]
  103. .sort_values("session_offset")
  104. .groupby("session_id")
  105. .agg(
  106. last_offset=("session_offset", lambda x: x.iloc[-1]),
  107. last_duration=("duration", lambda x: x.iloc[-1]),
  108. total=("duration", "sum"),
  109. )
  110. .reset_index(),
  111. how="left",
  112. left_on="session_id",
  113. right_on="session_id",
  114. )
  115. metrics["missing_audio"] = (
  116. metrics["last_offset"] + metrics["last_duration"] - metrics["total"]
  117. )
  118. metrics.drop(columns=["last_offset", "last_duration", "total"], inplace=True)
  119. data.append(metrics)
  120. if args.pipeline != "children":
  121. pd.concat(data).set_index(["experiment", "session_id", "child_id"]).to_csv(
  122. args.output
  123. )
  124. else:
  125. data = pd.concat(data)
  126. columns = reduce(lambda x, y: x & set(y), columns, columns[0]) | {
  127. "normative",
  128. "ses",
  129. }
  130. data = data[columns]
  131. data.set_index("child_id").to_csv(args.output)
  132. def main(args):
  133. compute_metrics(args)
  134. def _parse_args(argv):
  135. warnings.filterwarnings("ignore")
  136. parser = argparse.ArgumentParser(description="compute metrics")
  137. parser.add_argument(
  138. "pipeline", help="pipeline to run", choices=["aclew", "lena", "children", "period"]
  139. )
  140. parser.add_argument("output", help="output file")
  141. parser.add_argument("--experiments", nargs="+", default=[])
  142. parser.add_argument("--threads", default=0, type=int)
  143. parser.add_argument("--period", default=None, type=str)
  144. args = parser.parse_args(argv)
  145. return args
  146. if __name__ == '__main__':
  147. import sys
  148. pgrm_name, argv = sys.argv[0], sys.argv[1:]
  149. args = _parse_args(argv)
  150. main(**args)