Lucas Gautheron 2 months ago
parent
commit
5325c2b15e
1 changed files with 4 additions and 3 deletions
  1. 4 3
      code/models/enumeration.py

+ 4 - 3
code/models/enumeration.py

@@ -57,6 +57,7 @@ parser.add_argument("--confusion-clips", choices=["recording", "clip"], default=
 parser.add_argument("--confusion-groups", choices=["child", "recording"], default="recording")
 parser.add_argument("--time-scale", type=int, default=15)
 parser.add_argument("--require-clip-age", action="store_true", default=True)
+parser.add_argument("--age-threshold", type=int, default=0)
 parser.add_argument("--adapt-delta", type=float, default=0.9)
 parser.add_argument("--max-treedepth", type=int, default=14)
 args = parser.parse_args()
@@ -426,12 +427,12 @@ if __name__ == "__main__":
         "truth_total": truth_total.astype(int),
         "algo_total": algo_total.astype(int),
         "clip_duration": clip_duration,
-        "clip_age": clip_age,
+        "clip_age": clip_age if args.age_threshold == 0 else np.minimum(clip_age, args.age_threshold),
         "clip_rural": clip_rural,
         "clip_id": 1 + data["clip_id"].astype("category").cat.codes.values,
         "n_unique_clips": data["clip_id"].nunique(),
         "speech_rates": speech_rate_matrix.astype(int),
-        "speech_rate_age": speech_rate_age,
+        "speech_rate_age": speech_rate_age if args.age_threshold == 0 else np.minimum(speech_rate_age, args.age_threshold),
         "speech_rate_siblings": speech_rate_siblings,
         "group_corpus": (
             1 + speech_rates["corpus_code"].astype(int).values
@@ -454,7 +455,7 @@ if __name__ == "__main__":
         "n_children": len(recs["children"].unique()),
         "children": recs["children"],
         "vocs": np.transpose([recs[f"algo_{i}"].values for i in range(4)]),
-        "age": recs["age"],
+        "age": recs["age"] if args.age_threshold == 0 else np.minimum(recs["age"], args.age_threshold),
         "siblings": recs["siblings"].astype(int),
         "corpus": children_corpus,
         "recs_duration": args.duration,