# adapters.py — Neurosuite / HDF5 session data adapters
  1. import numpy as np
  2. import os
  3. import h5py
  4. from scipy import signal
  5. from scipy.signal import butter, sosfilt
# color names used for per-epoch plotting; 'tab:*' looks like the matplotlib palette — TODO confirm
COLORS = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:gray']
# session epochs; order matches the per-epoch dimension described in the H5NAMES descriptors
EPOCH_NAMES = ('Original', 'Conflict', 'Control', 'All')
  8. class H5NAMES:
  9. inst_rate = {'name': 'inst_rate', 'dims': ['instantaneous firing rate at ~100Hz']}
  10. spike_times = {'name': 'spike_times', 'dims': ['spike times in seconds']}
  11. spike_idxs = {'name': 'spike_idxs', 'dims': ['indices to timeline when spikes occured']}
  12. mfr = {'name': 'mean_firing_rate', 'dims': ['epochs: original, conflict, control and all']}
  13. isi_cv = {'name': 'isi_coeff_var', 'dims': ['epochs: original, conflict, control and all']}
  14. isi_fano = {'name': 'isi_fano_factor', 'dims': ['epochs: original, conflict, control and all']}
  15. o_maps = {'name': 'occupancy_maps', 'dims': [
  16. 'epochs: original, conflict, control and all', 'X, bins', 'Y, bins'
  17. ]}
  18. f_maps = {'name': 'firing_rate_maps', 'dims': [
  19. 'epochs: original, conflict, control and all', 'X, bins', 'Y, bins'
  20. ]}
  21. sparsity = {'name': 'sparsity', 'dims': ['epochs: original, conflict, control and all']}
  22. selectivity = {'name': 'selectivity', 'dims': ['epochs: original, conflict, control and all']}
  23. spat_info = {'name': 'spatial_information', 'dims': ['epochs: original, conflict, control and all']}
  24. peak_FR = {'name': 'peak_firing_rate', 'dims': ['epochs: original, conflict, control and all']}
  25. f_patches = {'name': 'field_patches', 'dims': [
  26. 'epochs: original, conflict, control and all', 'X, bins', 'Y, bins'
  27. ]}
  28. f_sizes = {'name': 'field_sizes', 'dims': ['epochs: original, conflict, control and all']}
  29. f_COM = {'name': 'field_center_of_mass', 'dims': ['epochs: original, conflict, control and all', 'rho, phi in polar coords.']}
  30. pfr_center = {'name': 'field_center_of_firing', 'dims': ['epochs: original, conflict, control and all', 'rho, phi in polar coords.']}
  31. occ_info = {'name': 'occupancy_information', 'dims': ['epochs: original, conflict, control and all']}
  32. o_patches = {'name': 'occupancy_patches', 'dims': [
  33. 'epochs: original, conflict, control and all', 'X, bins', 'Y, bins'
  34. ]}
  35. o_COM = {'name': 'occupancy_center_of_mass', 'dims': ['epochs: original, conflict, control and all', 'rho, phi in polar coords.']}
  36. best_m_rot = {'name': 'best_match_rotation', 'dims': ['match between: A-B, B-C, A-C', 'correlation profile']}
  37. def load_clu_res(where):
  38. """
  39. Neurosuite files:
  40. dat - raw signal in binary (usually int16) format as a matrix channels x signal
  41. lfp - raw signal, downsampled (historically to 1250Hz)
  42. fet - list of feature vectors for every spike for a particular electrode
  43. spk - list of spike waveforms for every spike for a particular electrode, binary
  44. res - spike times in samples for all clusters (units) from a particular electrode
  45. clu - list of cluster (unit) numbers for each spike from 'res'
  46. Load spike times from 'clu' (clusters) and 'res' (spike times) files generated by KlustaKwik.
  47. :param where: path to the folder
  48. :param filebase: base name of the file (like 'foo' in 'foo.clu.3')
  49. :param index: index of the file (like '3' in 'foo.clu.3')
  50. :return: a dict in a form like {<clustered_unit_no>: <spike_times>, ...}
  51. """
  52. filebase = os.path.basename(where)
  53. clu_files = [f for f in os.listdir(where) if f.find('.clu.') > 0]
  54. if not len(clu_files) > 0:
  55. return {}
  56. idxs = [int(x.split('.')[2]) for x in clu_files] # electrode indexes
  57. all_units = {}
  58. for idx in idxs:
  59. clu_file = os.path.join(where, '.'.join([filebase, 'clu', str(idx)]))
  60. res_file = os.path.join(where, '.'.join([filebase, 'res', str(idx)]))
  61. if not os.path.isfile(clu_file) or not os.path.isfile(res_file):
  62. continue
  63. cluster_map = np.loadtxt(clu_file, dtype=np.uint16) # uint16 for clusters
  64. all_spikes = np.loadtxt(res_file, dtype=np.uint64) # uint64 for spike times
  65. cluster_map = cluster_map[1:] # remove the first element - number of clusters
  66. result = {}
  67. for cluster_no in np.unique(cluster_map)[1:]: # already sorted / remove 1st cluster - noise
  68. result[cluster_no] = all_spikes[cluster_map == cluster_no]
  69. all_units[idx] = result
  70. return all_units
  71. def create_dataset(h5name, where, descriptor, dataset):
  72. """
  73. h5name path to an HDF5 file
  74. where path inside the file
  75. descriptor H5NAMES style descriptor of the dataset
  76. dataset numpy array to store
  77. """
  78. with h5py.File(h5name, 'a') as f:
  79. target_group = f[where]
  80. if descriptor['name'] in target_group: # overwrite mode
  81. del target_group[descriptor['name']]
  82. ds = target_group.create_dataset(descriptor['name'], data=dataset)
  83. for i, dim in enumerate(descriptor['dims']):
  84. ds.attrs['dim%s' % i] = dim
  85. class DatProcessor:
  86. def __init__(self, dat_file):
  87. # read this from XML file
  88. self.s_rate = 30000
  89. self.ch_no = 64
  90. self.dat_file = dat_file
  91. @staticmethod
  92. def butter_bandpass_filter(data, lowcut, highcut, fs, order=6):
  93. def butter_bandpass(lowcut, highcut, fs, order=5):
  94. nyq = 0.5 * fs
  95. low = lowcut / nyq
  96. high = highcut / nyq
  97. sos = butter(order, [low, high], analog=False, btype='band', output='sos')
  98. return sos
  99. sos = butter_bandpass(lowcut, highcut, fs, order=order)
  100. y = sosfilt(sos, data)
  101. return y
  102. def read_block_from_dat(self, duration, offset):
  103. """
  104. duration in seconds
  105. offset in seconds
  106. """
  107. count = self.s_rate * self.ch_no * duration # number of values to read
  108. offset_in_bytes = offset * self.s_rate * self.ch_no * 2 # assuming int16 is 2 bytes
  109. block = np.fromfile(self.dat_file, dtype=np.int16, count=int(count), offset=int(offset_in_bytes))
  110. return block.reshape([int(self.s_rate * duration), self.ch_no])
  111. def get_single_channel(self, channel_no):
  112. size = os.path.getsize(self.dat_file)
  113. samples_no = size / (64 * 2)
  114. raw_signal = np.zeros(int(samples_no)) # length in time: samples_no / sample_rate
  115. offset = 0
  116. while offset < samples_no / self.s_rate - 1:
  117. block = self.read_block_from_dat(1, offset) # read in 1 sec blocks
  118. raw_signal[self.s_rate*offset:self.s_rate*(offset + 1)] = block[:, channel_no]
  119. offset += 1
  120. return raw_signal