#
# MIT License
#
# Copyright (c) 2019 Keisuke Sehara
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import sys as _sys
import json as _js
import re as _re
from pathlib import Path as _Path
import collections as _cl
from warnings import warn as _warn
import numpy as _np
import pandas as _pd

DATASETS_METADATA_FILE = 'DATASETS.json'
HOW_TO_USE = """
------
This 'helper.py' is written to work from the **root repository directory**.
1. Please make sure that the directory structure of the datasets remains
   unchanged (it is fine if some data files are missing, though).
2. Please place this file inside the root directory (where you can find
   'REPOSITORY.json').
3. Change the current directory to the root repository directory.
4. From a Python session, run `import helper`.
"""
SESSION_PATTERN = _re.compile(r'([a-zA-Z]+)([0-9]{4}-[0-9]{2}-[0-9]{2})-([0-9]+)')
SUBDOMAIN_PATTERN = _re.compile(r'-([a-zA-Z0-9-]+)$')
RUN_PATTERN = _re.compile(r'_run([0-9]+)_')
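# For illustration, these patterns parse names like the following (hypothetical
# examples): a session directory 'session2019-03-01-2' yields the type
# 'session', the date '2019-03-01' and the index '2'; a file stem
# 'video-raw-left' yields the subdomains ('raw', 'left'); a file name
# 'trace_run3_data.csv' yields run number 3.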
rootdir = _Path(__file__).parent
datasetdir = rootdir / "datasets"

def __read_root_metadata(rootdir):
    rootdir = _Path(rootdir)
    if not rootdir.is_dir():
        raise RuntimeError(f"not a directory: {rootdir}")
    metadata_file = rootdir / DATASETS_METADATA_FILE
    if not metadata_file.is_file():
        raise RuntimeError(f"not a file: {metadata_file}")
    with open(metadata_file, 'r') as src:
        return _js.load(src, object_hook=_cl.OrderedDict)

def __read_csv_metadata(filename):
    metadata_file = datasetdir / filename
    if not metadata_file.is_file():
        print(f"***cannot read from: {filename}", file=_sys.stderr)
        return None
    return _pd.read_csv(str(metadata_file))

def _errormsg(msg):
    # a single leading underscore on purpose: a double-underscore name would
    # be mangled when referenced from inside the class bodies below
    print(f"***{msg} {HOW_TO_USE}", file=_sys.stderr)

root_metadata = None
try:
    root_metadata = __read_root_metadata(rootdir)
except RuntimeError as e:
    _errormsg(f"failed to read from '{DATASETS_METADATA_FILE}' ({e})")
subjects_metadata = __read_csv_metadata("SUBJECTS.csv")
sessions_metadata = __read_csv_metadata("SESSIONS.csv")

session_params = ('session_name', 'session_type', 'date', 'session_index')
base_params = tuple(set(('dataset', 'subject', 'domain', 'file', 'subdomain')
                        + session_params))
parameters = tuple(set(base_params
                       + tuple(subjects_metadata.columns)
                       + tuple(sessions_metadata.columns)))

def _update_with_subject(orig_context, subject_name):
    row = subjects_metadata.loc[subjects_metadata.name == subject_name, :]
    if row.shape[0] == 0:
        raise RuntimeError(f"subject not found in metadata: {subject_name}")
    cxt = row.iloc[0].to_dict()
    del cxt['name']
    cxt['subject'] = subject_name
    for key, val in orig_context.items():
        cxt[key] = val
    return cxt

def _update_with_session(orig_context, session_name, session_type, date, session_index):
    matches = _np.array(sessions_metadata.subject == orig_context['subject']) \
              * _np.array(sessions_metadata.date == date) \
              * (_np.array(sessions_metadata['index'], dtype=int) == int(session_index))
    row = sessions_metadata.loc[matches, :]
    if row.shape[0] == 0:
        raise RuntimeError(f"session not found in metadata: {session_name}")
    cxt = row.iloc[0].to_dict()
    for key in ('subject', 'date', 'index'):
        del cxt[key]
    cxt['session_name'] = session_name
    cxt['session_type'] = session_type
    cxt['date'] = date
    cxt['session_index'] = session_index
    for key, val in orig_context.items():
        cxt[key] = val
    return cxt
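# Note: in both helpers above, keys that are already present in `orig_context`
# take precedence over the columns read from the metadata tables, because the
# original context is written into the merged dict last.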
def describe_datasets(indent=2):
    if root_metadata is None:
        return _errormsg("metadata has not been initialized properly.")
    if isinstance(indent, int):
        indent = ' ' * indent
    if len(root_metadata) > 0:
        print("Available datasets")
        for ds_name, ds_desc in root_metadata.items():
            print(f"--------------------\n\ndataset '{ds_name}':")
            desc = ds_desc.get('description', None)
            if desc:
                print(f"{indent*1}(description)")
                print(f"{indent*2}{desc}")
            domains = ds_desc.get("domains", {})
            if len(domains) > 0:
                print(f"{indent*1}(domains)")
                for key, dom_desc in domains.items():
                    suffix = dom_desc.get('suffix', '')
                    if len(suffix.strip()) == 0:
                        suffix = 'no suffix'
                    desc = dom_desc.get('description', '(no description)')
                    print(f"{indent*2}- domain '{key}' ({suffix})")
                    print(f"{indent*3}{desc}")
            else:
                print(f"{indent*1}(no available domains)")
    else:
        print("***no datasets available in this directory!", file=_sys.stderr)

class dataspec(_cl.namedtuple('_dataspec', ('context', 'data'))):
    def __getattr__(self, name):
        # __getattr__ is consulted only after normal attribute lookup fails,
        # so `name` here is never 'context' or 'data'
        if name in self.context.keys():
            return self.context[name]
        raise AttributeError(name)

    def convert_data(self, datafunc):
        return self.__class__(self.context, datafunc(self.data))
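# A sketch of dataspec.convert_data (assuming `spec.data` points to a CSV file
# that pandas can read):
#
#     table_spec = spec.convert_data(lambda path: _pd.read_csv(str(path)))
#     # the context is kept as-is; only the `data` field is replaced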
class predicate:
    _retrievable = ('datasets', 'subjects', 'domains', 'files', 'subdomains') \
                   + ('session_names', 'session_types', 'dates', 'session_indices')

    def __init__(self):
        self.__cached = {}

    def __getattr__(self, name):
        # the 'file'/'subdomain' checks must come before the general
        # `parameters` test, because both names are also listed there
        if name == 'file':
            raise NameError("use 'files' to retrieve file paths")
        elif name == 'subdomain':
            raise ValueError("use '<context>.has_subdomain(<subdom>)' expression for restricting to a subdomain")
        elif name in parameters:
            return parameter(self, name)
        elif name in self._retrievable:
            return self.retrieve(name)
        else:
            raise AttributeError(name)

    def get_datasets(self, as_spec=True, recalculate=False):
        if ('datasets' not in self.__cached.keys()) or recalculate:
            dss = []
            for ds_name in root_metadata.keys():
                spec = dataspec(dict(dataset=ds_name), datasetdir / ds_name)
                if spec.data.is_dir():
                    if self.__validate__('dataset', spec.context):
                        # print(f"adding dataset: {ds_name}")
                        dss.append(spec)
            self.__cached['datasets'] = dss
            self.__cached['dataset_names'] = [item.context['dataset'] for item in dss]
        if as_spec:
            return tuple(self.__cached['datasets'])
        else:
            return tuple(self.__cached['dataset_names'])

    def get_subjects(self, as_spec=True, recalculate=False):
        if ('subjects' not in self.__cached.keys()) or recalculate:
            subs = []
            for ds in self.get_datasets(as_spec=True, recalculate=recalculate):
                for child in ds.data.iterdir():
                    if not child.is_dir():
                        continue
                    cxt = _update_with_subject(ds.context, child.name)
                    if self.__validate__('subject', cxt):
                        # print(f"adding: {ds.context['dataset']}/{child.name}")
                        spec = dataspec(cxt, child)
                        subs.append(spec)
            self.__cached['subjects'] = subs
            self.__cached['subject_names'] = sorted(set(item.context['subject'] for item in subs))
        if as_spec:
            return tuple(self.__cached['subjects'])
        else:
            return tuple(self.__cached['subject_names'])

    def get_session_names(self, as_spec=True, recalculate=False):
        return self.get_sessions_impl(mode='session_names', as_spec=as_spec, recalculate=recalculate)

    def get_session_types(self, as_spec=True, recalculate=False):
        return self.get_sessions_impl(mode='session_types', as_spec=as_spec, recalculate=recalculate)

    def get_dates(self, as_spec=True, recalculate=False):
        return self.get_sessions_impl(mode='dates', as_spec=as_spec, recalculate=recalculate)

    def get_session_indices(self, as_spec=True, recalculate=False):
        return self.get_sessions_impl(mode='session_indices', as_spec=as_spec, recalculate=recalculate)

    def get_sessions_impl(self, mode='dates', as_spec=True, recalculate=False):
        if ('sessions' not in self.__cached.keys()) or recalculate:
            sessions = []
            for sub in self.get_subjects(as_spec=True, recalculate=recalculate):
                for child in sub.data.iterdir():
                    if not child.is_dir():
                        continue
                    # print(f"child: {child.name}")
                    is_session = SESSION_PATTERN.search(child.name)
                    if not is_session:
                        continue
                    stype = is_session.group(1)
                    date = is_session.group(2)
                    idx = is_session.group(3)
                    sname = f"{stype}{date}-{idx}"
                    cxt = _update_with_session(sub.context, sname, stype, date, idx)
                    # print(f"session: name={sname}; type={stype}; date={date}; idx={idx}")
                    if self.__validate__('session_name', cxt):
                        # print(f"adding: {sub.context['dataset']}/{sub.context['subject']}/{date}")
                        spec = dataspec(cxt, child)
                        sessions.append(spec)
            self.__cached['sessions'] = sessions
            self.__cached['dates'] = sorted(set(item.context['date'] for item in sessions))
            self.__cached['session_names'] = sorted(set(item.context['session_name'] for item in sessions))
            self.__cached['session_types'] = sorted(set(item.context['session_type'] for item in sessions))
            self.__cached['session_indices'] = sorted(set(item.context['session_index'] for item in sessions))
        if as_spec:
            return tuple(self.__cached['sessions'])
        else:
            return tuple(self.__cached[mode])
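    # Note: attribute access such as `datasets.dates` or `datasets.session_names`
    # goes through predicate.__getattr__ and retrieve() below with as_spec=False,
    # and therefore yields sorted, de-duplicated strings rather than dataspecs.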
    def get_domains(self, as_spec=True, recalculate=False):
        if ('domains' not in self.__cached.keys()) or recalculate:
            doms = []
            # note: get_dates(as_spec=True) returns the session dataspecs themselves
            for session in self.get_dates(as_spec=True, recalculate=recalculate):
                for child in session.data.iterdir():
                    if not child.is_dir():
                        continue
                    dom = child.name
                    cxt = session.context.copy()
                    cxt['domain'] = dom
                    # print(f"domain={child.name}")
                    if self.__validate__('domain', cxt):
                        # print(f"adding: {session.context['dataset']}/{session.context['subject']}/{session.context['date']}/{dom}")
                        spec = dataspec(cxt, child)
                        doms.append(spec)
            self.__cached['domains'] = doms
            self.__cached['domain_names'] = sorted(set(item.context['domain'] for item in doms))
        if as_spec:
            return tuple(self.__cached['domains'])
        else:
            return tuple(self.__cached['domain_names'])

    def get_files(self, as_spec=True, recalculate=False):
        return self.get_subdomains_impl(mode='files', as_spec=as_spec, recalculate=recalculate)

    def get_subdomains(self, as_spec=True, recalculate=False):
        return self.get_subdomains_impl(mode='subdomains', as_spec=as_spec, recalculate=recalculate)

    def get_subdomains_impl(self, mode='subdomains', as_spec=True, recalculate=False):
        if ('files' not in self.__cached.keys()) or recalculate:
            files = []
            for dom in self.get_domains(as_spec=True, recalculate=recalculate):
                for child in dom.data.iterdir():
                    # skip dot files
                    if child.name.startswith('.'):
                        continue
                    has_subdomain = SUBDOMAIN_PATTERN.search(child.stem)
                    if has_subdomain:
                        subdomains = tuple(has_subdomain.group(1).split('-'))
                    else:
                        subdomains = ()
                    has_run = RUN_PATTERN.search(child.name)
                    cxt = dom.context.copy()
                    cxt['subdomains'] = subdomains
                    if has_run:
                        cxt['run'] = int(has_run.group(1))
                    # print(f"name={child.name}; subdomains={subdomains}")
                    if self.__validate__('subdomain', cxt):
                        spec = dataspec(cxt, child)
                        files.append(spec)
            self.__cached['files'] = files
            self.__cached['file_paths'] = sorted(str(item.data) for item in files)
            self.__cached['subdomains'] = sorted(set(item.context['subdomains'] for item in files))
        if as_spec:
            return tuple(self.__cached['files'])
        elif mode == 'subdomains':
            return tuple(self.__cached['subdomains'])
        else:
            return tuple(self.__cached['file_paths'])

    def retrieve(self, param, recalculate=False):
        if root_metadata is None:
            return _errormsg("metadata has not been initialized properly.")
        options = dict(as_spec=False, recalculate=recalculate)
        if param == 'datasets':
            return self.get_datasets(**options)
        elif param == 'subjects':
            return self.get_subjects(**options)
        elif param in ('dates', 'session_names', 'session_types', 'session_indices'):
            return self.get_sessions_impl(mode=param, **options)
        elif param == 'domains':
            return self.get_domains(**options)
        elif param in ('files', 'subdomains'):
            return self.get_subdomains_impl(mode=param, **options)
        else:
            raise ValueError(f"unknown object type for retrieval: {param}")

    def __validate__(self, param, value):
        raise NotImplementedError(f"{self.__class__.__name__}.__validate__")

    def __join__(self, op, other):
        if not isinstance(other, predicate):
            raise ValueError(f"cannot join {other.__class__} (expected conditional or joined)")
        return joined(op, self, other)

    def __add__(self, other):
        return self.__join__('add', other)

    def __mul__(self, other):
        return self.__join__('mul', other)

    def has_subdomain(self, subdom):
        return conditional('has', parameter(self, 'subdomains'), subdom)
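# A sketch of combining predicates with the operators above ('behavior' and
# 'raw' are hypothetical names): `*` keeps entries that match both sides,
# `+` those that match either side:
#
#     sel = datasets.has_subdomain('raw') * (datasets.session_type == 'behavior')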
class _datasets(predicate):
    """manages file retrieval from datasets."""
    def __init__(self):
        super().__init__()

    def __repr__(self):
        return '<any>'

    def __validate__(self, param, value):
        return True

class parameter:
    """manages contexts."""
    def __init__(self, parent, name):
        self.__parent = parent
        self.__name = name

    def __getattr__(self, name):
        if name == 'name':
            return self.__name
        else:
            raise AttributeError(name)

    def __repr__(self):
        parent = repr(self.__parent)
        return f"{parent}.{self.__name}"

    def __cond__(self, op, name):
        if not isinstance(name, str):
            raise ValueError(f"cannot compare to {name.__class__} (expected a string)")
        return conditional(op, self, name)

    def __eq__(self, value):
        return self.__cond__('eq', value)

    def __ne__(self, value):
        return self.__cond__('ne', value)

    def __gt__(self, value):
        return self.__cond__('gt', value)

    def __lt__(self, value):
        return self.__cond__('lt', value)

    def __ge__(self, value):
        return self.__cond__('ge', value)

    def __le__(self, value):
        return self.__cond__('le', value)

    def __validate__(self, param, context):
        return self.__parent.__validate__(param, context)
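# Because parameter overrides the comparison operators, an expression such as
# `datasets.date >= '2019-01-01'` (a hypothetical date) builds a `conditional`
# instead of evaluating to a boolean; values are compared as strings, which
# orders ISO-formatted dates correctly.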
class conditional(predicate):
    """manages conditions in contexts."""
    _opcodes = dict(eq='==',
                    ne='!=',
                    gt='>',
                    ge='>=',
                    lt='<',
                    le='<=',
                    has='has')
    _ops = {
        'eq': (lambda _x, _v: _x == _v),
        'ne': (lambda _x, _v: _x != _v),
        'gt': (lambda _x, _v: _x > _v),
        'ge': (lambda _x, _v: _x >= _v),
        'lt': (lambda _x, _v: _x < _v),
        'le': (lambda _x, _v: _x <= _v),
        'has': (lambda _x, _v: _v in _x)
    }

    def __init__(self, op, param, value):
        super().__init__()
        self.__op = op
        self.__param = param
        self.__value = value

    def __getattr__(self, name):
        if name == 'opcode':
            opcode = self._opcodes.get(self.__op, None)
            if opcode:
                return opcode
            else:
                raise ValueError(f'unknown operation: {self.__op}')
        else:
            return super().__getattr__(name)

    def __repr__(self):
        return f"({self.__param} {self.opcode} {repr(self.__value)})"

    def __validate__(self, param, context):
        if not self.__param.__validate__(param, context):
            return False
        elif self.__param.name not in base_params:
            # subject-, session- or trial-related variables:
            # pass if the variable does not occur in this context
            if self.__param.name not in context.keys():
                return True
        elif (self.__param.name not in session_params) \
                or (param not in session_params):
            # defer judgement until the hierarchy level matches the parameter
            if param != self.__param.name:
                return True
        op = self._ops.get(self.__op, None)
        val = context.get(self.__param.name, None)
        if op:
            if val:
                return op(val, self.__value)
            else:
                return True
        else:
            raise ValueError(f'unknown operation: {self.__op}')
class joined(predicate):
    """joins two contexts."""
    def __init__(self, op, set1, set2):
        super().__init__()
        self.__op = op
        self.__set1 = set1
        self.__set2 = set2

    def __getattr__(self, name):
        if name == 'opcode':
            op = self.__op
            if op == 'mul':
                return '*'
            elif op == 'add':
                return '+'
            else:
                raise ValueError(f'unknown operation: {op}')
        else:
            return super().__getattr__(name)

    def __repr__(self):
        return f"({self.__set1} {self.opcode} {self.__set2})"

    def __validate__(self, param, value):
        cond1 = self.__set1.__validate__(param, value)
        cond2 = self.__set2.__validate__(param, value)
        op = self.__op
        if op == 'mul':
            return (cond1 and cond2)
        elif op == 'add':
            return (cond1 or cond2)
        else:
            raise ValueError(f'unknown operation: {op}')
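# The repr of a joined predicate reads like the expression that built it, e.g.
# `((<any>.subject == 'subjectA') + (<any>.subject == 'subjectB'))` for a union
# over two hypothetical subjects.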
### start script upon import
datasets = _datasets()
if __name__ != '__main__':
    if root_metadata is not None:
        describe_datasets()