# utils.py — helpers for locating session folders and reading
# session metadata (.xml sampling rate, .clu unit counts, .h5 epochs).
  1. import os, json, h5py
  2. import numpy as np
  3. import xml.etree.ElementTree as ET
  4. from datetime import datetime
  5. def session_to_numbers(session_name):
  6. animal, e_type, e_date, e_time = session_name.split('_')
  7. dt = datetime.strptime('%s_%s' % (e_date, e_time), '%Y-%m-%d_%H-%M-%S')
  8. animal_code = int(animal)
  9. session_code = (dt.year-2000)*10**6 + dt.month*10**4 + dt.day*10**2 + dt.hour
  10. return animal_code, session_code
  11. def guess_filebase(sessionpath):
  12. return os.path.basename(os.path.normpath(sessionpath))
  13. def get_sessions_list(path_to_sessions_folder, animal):
  14. def convert(func, *args):
  15. try:
  16. return func(*args)
  17. except ValueError:
  18. return False
  19. is_dir = lambda x: os.path.isdir(os.path.join(path_to_sessions_folder, x))
  20. has_parts = lambda x: len(x.split('_')) > 2
  21. starts_by_animal = lambda x: x.split('_')[0] == animal
  22. has_timestamp = lambda x: convert(datetime.strptime, '%s_%s' % (x.split('_')[-2], x.split('_')[-1]), '%Y-%m-%d_%H-%M-%S')
  23. return sorted([x for x in os.listdir(path_to_sessions_folder) if
  24. is_dir(x) and has_parts(x) and starts_by_animal(x) and has_timestamp(x)])
  25. def get_sampling_rate(sessionpath):
  26. filebase = os.path.basename(sessionpath)
  27. xml_path = os.path.join(sessionpath, filebase + '.xml')
  28. if not os.path.exists(xml_path):
  29. return None
  30. root = ET.parse(xml_path).getroot()
  31. sampling_rate = root.findall('acquisitionSystem')[0].findall('samplingRate')[0]
  32. return int(sampling_rate.text)
  33. def unit_number_for_electrode(sessionpath, electrode_idx):
  34. filebase = ''
  35. try:
  36. filebase = guess_filebase(sessionpath)
  37. except ValueError:
  38. return 0 # no units on this electrode
  39. clu_file = os.path.join(sessionpath, '.'.join([filebase, 'clu', str(electrode_idx)]))
  40. cluster_map = np.loadtxt(clu_file, dtype=np.uint16)
  41. return len(np.unique(cluster_map)) - 1 # 1st cluster is noise
  42. def unit_number_for_session(sessionpath):
  43. electrode_idxs = [x.split('.')[2] for x in os.listdir(sessionpath) if x.find('.clu.') > -1]
  44. idxs = []
  45. for el in electrode_idxs:
  46. try:
  47. elem = int(el)
  48. idxs.append(elem)
  49. except ValueError:
  50. pass
  51. unit_counts = [unit_number_for_electrode(sessionpath, el_idx) for el_idx in np.unique(idxs)]
  52. return np.array(unit_counts).sum()
  53. def get_epochs(sessionpath):
  54. filebase = guess_filebase(sessionpath)
  55. h5name = os.path.join(sessionpath, filebase + '.h5')
  56. with h5py.File(h5name, 'r') as f:
  57. cfg = json.loads(f['processed'].attrs['parameters'])
  58. tp = np.array(cfg['experiment']['timepoints'])
  59. s_duration = cfg['experiment']['session_duration']
  60. tp = np.repeat(tp, 2)
  61. epochs = np.concatenate([np.array([0]), tp, np.array([s_duration])])
  62. return epochs.reshape(int(len(epochs)/2), 2)
  63. def cleaned_epochs(sessionpath):
  64. epochs = get_epochs(sessionpath)
  65. # if there are 5 epochs: session is continuous, remove speaker rotation periods
  66. if len(epochs) > 3:
  67. epochs = epochs[slice(0, 5, 2)]
  68. # add whole session as epoch, if not yet there
  69. if len(epochs) > 1:
  70. whole_session = np.array([[epochs[0][0], epochs[-1][1]]])
  71. return np.concatenate([epochs, whole_session])
  72. else:
  73. return epochs