generate_validation_tasks.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import pandas as pd
  2. import numpy as np
  3. import random
  4. import argparse
  5. parser = argparse.ArgumentParser()
  6. parser.add_argument("annotation_id")
  7. parser.add_argument("n_tasks", type=int)
  8. parser.add_argument("--weight", action="store_true", default=False)
  9. args = parser.parse_args()
  10. categories = pd.read_csv("analyses/validation.csv").drop_duplicates(["topic", "pacs"])
  11. topics = pd.read_csv("output/hep-ct-75-0.1-0.001-130000-20/descriptions.csv")["description"].tolist()
  12. n_topics = len(topics)
  13. n_top = 10
  14. if args.weight:
  15. probs = pd.read_parquet("output/hep-ct-75-0.1-0.001-130000-20/topics_0.parquet")[["probs"]]
  16. probs = np.stack(probs.probs.values)
  17. probs = np.mean(probs,axis=0)
  18. weighted = "weighted"
  19. else:
  20. probs = np.zeros(n_topics)+1
  21. weighted = "unweighted"
  22. tasks = []
  23. for i in range(args.n_tasks):
  24. n1 = random.choices(np.arange(n_topics), probs, k=1)[0]
  25. t1 = topics[n1]
  26. t2 = ""
  27. u = categories[categories["topic"] == t1].head(n_top)
  28. mixture = random.choice([True, False])
  29. if mixture:
  30. n2 = random.choices(np.delete(np.arange(n_topics),n1), np.delete(probs,n1), k=1)[0]
  31. t2 = topics[n2]
  32. u1 = u.sample(int(n_top/2))
  33. u2 = categories[(categories["topic"] == t2)].head(n_top)
  34. u2 = u2[~u2["description"].isin(u1["description"])].sample(int(n_top/2))
  35. u1 = u1["description"].tolist()
  36. u2 = u2["description"].tolist()
  37. else:
  38. u = u.sample(frac=1)
  39. u1 = u["description"][:int(n_top/2)].tolist()
  40. u2 = u["description"][int(n_top/2):int(n_top)].tolist()
  41. tasks.append({
  42. 'question': i,
  43. 'topic1': t1,
  44. 'topic2': t2,
  45. 'categories1': u1,
  46. "categories2": u2
  47. })
  48. tasks = pd.DataFrame(tasks)
  49. tasks.to_csv(f"analyses/truth_{args.annotation_id}_{weighted}.csv")
  50. questions = tasks.copy().set_index("question")[["categories1", "categories2"]]
  51. questions["categories1"] = questions["categories1"].map(lambda l: "\n".join(l))
  52. questions["categories2"] = questions["categories2"].map(lambda l: "\n".join(l))
  53. questions["1 topic or 2 topics ?"] = ""
  54. questions.to_excel(f"analyses/questions_{args.annotation_id}_{weighted}.xlsx", merge_cells=True)