load_data.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. #!/user/bin/env python
  2. # coding=utf-8
  3. """
  4. @author: yannansu
  5. @created at: 23.09.21 16:07
  6. This module loads and reads data from `raw_data/subject_idx.yaml`
  7. Example usage:
  8. LD = LoadData('test', data_path='pilot_data')
  9. activ_df = LD.read_activity()
  10. df = LD.read_data()
  11. """
import datetime
import os
import re

import numpy as np
import pandas as pd
import yaml

from data_analysis.yml2dict import yml2dict
  18. class LoadData:
  19. def __init__(self, sub, data_path=None,
  20. sel_cfg=None, sel_par=None,
  21. sel_ses=None, rm_ses=None, start_date=None,
  22. sel_ses_idx=None):
  23. """
  24. Select and load data.
  25. :param sub: subject name
  26. :param data_path: data repository, default: 'data'
  27. :param sel_cfg: selected config keywords
  28. :param sel_par: selected param keywords
  29. :param sel_ses: selected session keywords
  30. :param rm_ses: removed session keywords
  31. :param start_date: selected starting date, e.g. '20210000
  32. :param sel_ses_idx: selected session index
  33. """
  34. self.sub = sub
  35. self.data_path = data_path
  36. if self.data_path is None:
  37. self.data_path = 'data'
  38. self.sel_cfg = sel_cfg
  39. self.sel_par = sel_par
  40. self.sel_ses = sel_ses
  41. self.rm_ses = rm_ses
  42. self.start_date = start_date
  43. self.sel_ses_idx = sel_ses_idx
  44. def read_activity(self):
  45. """
  46. Read subject activity log with given selectors.
  47. :return: a dataframe listing a summary of the selected session
  48. """
  49. # Read xrl file line by line
  50. activ_log = yml2dict(os.path.join(self.data_path, self.sub, self.sub + '.yaml'))
  51. activ_df = pd.DataFrame(activ_log).T
  52. # Select finished blocks
  53. activ_df = activ_df[activ_df.status == 'completed']
  54. # Filter data by input selector
  55. if self.sel_cfg is not None:
  56. cfg_pattern = '|'.join(self.sel_cfg)
  57. activ_df = activ_df[activ_df.cfg_file.str.contains(cfg_pattern)]
  58. if self.sel_par is not None:
  59. par_pattern = '|'.join(self.sel_par)
  60. activ_df = activ_df[activ_df.par_file.str.contains(par_pattern)]
  61. # Restrict specific sessions by date or time
  62. if self.sel_ses is not None:
  63. activ_df = activ_df.filter(like=self.sel_ses, axis=0)
  64. if self.rm_ses is not None:
  65. activ_df = activ_df.drop(self.rm_ses, axis=0)
  66. if self.start_date is not None:
  67. activ_df = activ_df[activ_df.index >= self.start_date]
  68. if self.sel_ses_idx is not None:
  69. activ_df = activ_df.groupby('par_file').nth(self.sel_ses_idx)
  70. activ_df = activ_df.reset_index()
  71. if 'index' in activ_df.columns:
  72. activ_df = activ_df.drop(columns=['index'])
  73. return activ_df
  74. def read_data(self, save_csv=None):
  75. """
  76. Read data from selected session.
  77. :param save_csv: '*.csv', save data in csv 'data/subject_xx/*.csv' if not None
  78. :return: a dataframe of trial-based data of all selected sessions
  79. """
  80. # Read activity log
  81. activ_df = self.read_activity()
  82. # Read data files
  83. yml_list = activ_df.data_file.to_list()
  84. df_list = []
  85. # Separate trials and save to dataframe
  86. for block_idx, yml in zip(activ_df.block_idx, yml_list):
  87. yml_data = pd.DataFrame({k: v for k, v in yml2dict(yml).items()
  88. if (k.startswith('trial') and any(c.isdigit() for c in k))
  89. }).T
  90. yml_data['block_index'] = block_idx
  91. yml_data['sub'] = self.sub
  92. yml_data = yml_data.reset_index(drop=True)
  93. df_list.append(yml_data)
  94. df = pd.concat(df_list, ignore_index=True)
  95. if save_csv is not None:
  96. df.to_csv(os.path.join(self.data_path, self.sub, save_csv))
  97. return df
  98. """
  99. # Test
  100. LD = LoadData('test', data_path='pilot_data')
  101. activ_df = LD.read_activity()
  102. df = LD.read_data()
  103. """
  104. # LoadData('s1', data_path='pilot2_data', sel_par='config/pilot2_param.yaml', start_date='20211207').read_data()