tests.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. from ChildProject.projects import ChildProject
  2. from ChildProject.annotations import AnnotationManager
  3. from datetime import datetime
  4. import multiprocessing as mp
  5. import os
  6. import pandas as pd
  7. import re
  8. import sys
  9. class DatasetTester:
  10. def __init__(self, path: str, threads: int = 1):
  11. self.project = ChildProject(path)
  12. self.am = AnnotationManager(self.project)
  13. self.am.read()
  14. threads = int(threads)
  15. self.threads = threads if threads >= 1 else mp.cpu_count()
  16. def test_metadata(self):
  17. errors, warnings = self.project.validate(ignore_files = True)
  18. assert len(errors) == 0, 'project validation failed'
  19. def test_annotations(self):
  20. errors, warnings = self.am.validate(threads = self.threads)
  21. assert len(errors) == 0, 'annotations validation failed'
  22. def test_age(self):
  23. children = self.project.children.copy()
  24. recordings = self.project.recordings.copy()
  25. recordings = recordings.merge(
  26. children,
  27. how = 'left',
  28. left_on = 'child_id',
  29. right_on = 'child_id'
  30. )
  31. recordings['date_iso'] = recordings['date_iso'].apply(
  32. lambda s: datetime.strptime(s, '%Y-%m-%d')
  33. )
  34. recordings['child_dob'] = recordings['child_dob'].apply(
  35. lambda s: datetime.strptime(s, '%Y-%m-%d')
  36. )
  37. assert all(recordings.apply(
  38. lambda row: row['date_iso'] > row['child_dob'],
  39. axis = 1
  40. ))
  41. def test_ses(self):
  42. children = self.project.children.copy()
  43. children = children.dropna(subset = ['ses'])
  44. children = children[children['ses'] != 'NA']
  45. children['ses'] = children['ses'].astype(int)
  46. assert (children['ses'].values >= 1).all() and (children['ses'].values <= 5).all(), "ses should be >= 1 and <= 5"
  47. def test_language(self):
  48. children = self.project.children.copy()
  49. confidential_children_md_path = os.path.join(self.project.path, 'metadata/confidential/children.csv')
  50. if os.path.exists(confidential_children_md_path):
  51. children = children.merge(
  52. pd.read_csv(confidential_children_md_path),
  53. how = 'left',
  54. left_on = 'child_id',
  55. right_on = 'child_id'
  56. )
  57. if 'languages' in children.columns:
  58. children['languages'] = children['languages'].apply(lambda s: s.split(','))
  59. is_valid = children['languages'].apply(lambda l: all([s.isalpha for s in l]))
  60. assert(is_valid.all())
  61. elif 'language' in children.columns:
  62. assert(children['language'].str.isalpha().all())
  63. else:
  64. raise KeyError("neither 'languages' or 'language' present in the metadata")
  65. if 'monoling' in children.columns:
  66. assert children['monoling'].str.lower().isin(['y', 'n']).all(), "monoling not always y or n"
  67. else:
  68. raise KeyError("missing 'monoling' field")
  69. def test_sex(self):
  70. children = self.project.children.copy()
  71. children['child_sex'] = children['child_sex'].str.lower()
  72. assert children['child_sex'].isin(['m', 'f']).all(), "children sex not always m or f"
  73. def test_normativity(self):
  74. children = self.project.children.copy()
  75. if 'normative' in children.columns:
  76. assert children['normative'].str.lower().isin(['y', 'n']).all(), "normative not always y or n"
  77. else:
  78. raise KeyError("missing 'normative' field")