target.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import numpy as np
  2. def build_tgt_matrix(tl, trials, aeps_events):
  3. # compute timeline / AEP indices of entrances / exist to the target
  4. tl_tgt_start_idxs = [] # timeline indices of first target pulse
  5. tl_tgt_end_idxs = [] # timeline indices of last target pulse
  6. aeps_tgt_start_idxs = [] # indices of first AEPs in target
  7. aeps_tgt_end_idxs = [] # indices of last AEPs in target
  8. for i in range(len(tl) - 1):
  9. if tl[i][6] < 2 and tl[i+1][6] == 2:
  10. nearest_aep_idx = np.abs(aeps_events[:, 0] - tl[i+1][0]).argmin()
  11. aeps_tgt_start_idxs.append(nearest_aep_idx)
  12. t_event = aeps_events[nearest_aep_idx][0]
  13. tl_tgt_start_idxs.append(np.abs(tl[:, 0] - t_event).argmin())
  14. if tl[i][6] == 2 and tl[i+1][6] < 2:
  15. nearest_aep_idx = np.abs(aeps_events[:, 0] - tl[i][0]).argmin()
  16. aeps_tgt_end_idxs.append(nearest_aep_idx)
  17. t_event = aeps_events[nearest_aep_idx][0]
  18. tl_tgt_end_idxs.append(np.abs(tl[:, 0] - t_event).argmin())
  19. # ignore first/last target if not ended
  20. if tl_tgt_start_idxs[-1] > tl_tgt_end_idxs[-1]:
  21. tl_tgt_start_idxs = tl_tgt_start_idxs[:-1]
  22. aeps_tgt_start_idxs = aeps_tgt_start_idxs[:-1]
  23. if tl_tgt_end_idxs[0] < tl_tgt_start_idxs[0]:
  24. tl_tgt_end_idxs = tl_tgt_end_idxs[1:]
  25. aeps_tgt_end_idxs = aeps_tgt_end_idxs[1:]
  26. tl_tgt_start_idxs = np.array(tl_tgt_start_idxs)
  27. tl_tgt_end_idxs = np.array(tl_tgt_end_idxs)
  28. # successful / missed
  29. tgt_results = np.zeros(len(tl_tgt_start_idxs))
  30. for idx_tl_success_end in trials[trials[:, 5] == 1][:, 1]:
  31. idx_succ = np.abs(tl_tgt_end_idxs - idx_tl_success_end).argmin()
  32. tgt_results[idx_succ] = 1
  33. # tl_idx_start, tl_idx_end, aep_idx_start, aer_idx_end, success / miss
  34. return np.column_stack([
  35. tl_tgt_start_idxs,
  36. tl_tgt_end_idxs,
  37. aeps_tgt_start_idxs,
  38. aeps_tgt_end_idxs,
  39. tgt_results
  40. ]).astype(np.int32)
  41. def build_silence_matrix(tl):
  42. idxs_silence_start, idxs_silence_end = [], []
  43. for i in range(len(tl) - 1):
  44. if tl[i][6] != 0 and tl[i+1][6] == 0: # silence start
  45. idxs_silence_start.append(i+1)
  46. elif tl[i][6] == 0 and tl[i+1][6] != 0: # silence end
  47. idxs_silence_end.append(i)
  48. if len(idxs_silence_start) > len(idxs_silence_end):
  49. idxs_silence_start = idxs_silence_start[:-1]
  50. idxs_silence_start = np.array(idxs_silence_start)
  51. idxs_silence_end = np.array(idxs_silence_end)
  52. return np.column_stack([idxs_silence_start, idxs_silence_end])
  53. def get_idxs_of_event_periods(tl, event_type):
  54. # event_type: -1, 0, 1, 2 (noise, silence, background, target)
  55. # returns: indices to timeline for periods of event_type
  56. idxs_events = np.where(tl[:, 6] == event_type)[0]
  57. idxs_to_idxs = np.where(np.diff(idxs_events) > 1)[0]
  58. # periods - indices to TL where was silent
  59. periods = np.zeros([len(idxs_to_idxs) + 1, 2])
  60. periods[0] = np.array([0, idxs_to_idxs[0]])
  61. periods[1:-1] = np.column_stack([idxs_to_idxs[:-1] + 1, idxs_to_idxs[1:]])
  62. periods[-1] = np.array([idxs_to_idxs[-1], len(idxs_events) - 1])
  63. periods = periods.astype(np.int32)
  64. # convert to TL indices
  65. return np.column_stack([idxs_events[periods[:, 0]], idxs_events[periods[:, 1]]])
  66. def build_silence_and_noise_events(tl, offset, latency, drift):
  67. # build hallucination pulses in silence and noise
  68. duration = tl[-1][0]
  69. # all pulses with drift
  70. #pulse_times = np.linspace(0, int(duration - latency), int(duration - latency)*4 + 1) + offset # if latency 0.25
  71. pulse_times = np.array([i*latency for i in range(int((duration - latency)/latency) + 10)]) + offset
  72. pulse_times = pulse_times[pulse_times < duration]
  73. pulse_times += np.arange(len(pulse_times)) * drift/len(pulse_times)
  74. # filter silence times only
  75. pulses_silence = []
  76. pulses_noise = []
  77. pulses_bgr = []
  78. pulses_tgt = []
  79. tl_idx = 0 # index of current pulse in the timeline
  80. for t_pulse in pulse_times:
  81. while tl[tl_idx][0] < t_pulse:
  82. tl_idx += 1
  83. if tl[tl_idx][6] == 0:
  84. pulses_silence.append(t_pulse)
  85. elif tl[tl_idx][6] == -1:
  86. pulses_noise.append(t_pulse)
  87. elif tl[tl_idx][6] == 1:
  88. pulses_bgr.append(t_pulse)
  89. elif tl[tl_idx][6] == 2:
  90. pulses_tgt.append(t_pulse)
  91. pulses_silence = np.array(pulses_silence)
  92. pulses_noise = np.array(pulses_noise)
  93. pulses_bgr = np.array(pulses_bgr)
  94. pulses_tgt = np.array(pulses_tgt)
  95. return pulses_silence, pulses_noise, pulses_bgr, pulses_tgt
  96. def build_event_mx(tl, offset, latency):
  97. drift_coeff = 0.055/2400
  98. duration = tl[-1][0]
  99. drift = duration * drift_coeff
  100. # all pulses with drift
  101. pulse_times = np.array([i*latency for i in range(int((duration - latency)/latency) + 10)]) + offset
  102. pulse_times += np.arange(len(pulse_times)) * drift/len(pulse_times)
  103. pulse_times = pulse_times[pulse_times < duration] # filter out if more pulses
  104. event_mx = np.zeros([len(pulse_times), 2])
  105. tl_idx = 0 # index of current pulse in the timeline
  106. for i, t_pulse in enumerate(pulse_times):
  107. while tl[tl_idx][0] < t_pulse:
  108. tl_idx += 1
  109. event_mx[i] = np.array([t_pulse, tl[tl_idx][6]])
  110. return event_mx[:-1]
  111. def get_spike_times_at(tl, s_times, periods, mode='sequence'):
  112. # 'sequence' - periods follow each other in a sequence
  113. # 'overlay' - all periods aligned to time zero
  114. all_spikes = [] # collect as groups
  115. sil_dur = 0
  116. for period in periods:
  117. idxs_tl_l, idxs_tl_r = period[0], period[1]
  118. spikes = s_times[(s_times > tl[idxs_tl_l][0]) & (s_times < tl[idxs_tl_r][0])]
  119. spikes -= tl[idxs_tl_l][0] # align to time 0
  120. if mode == 'sequence':
  121. spikes += sil_dur # adjust to already processed silence periods
  122. all_spikes.append(spikes)
  123. sil_dur += tl[idxs_tl_r][0] - tl[idxs_tl_l][0]
  124. return all_spikes #np.array([item for sublist in all_spikes for item in sublist])
  125. def get_spike_counts(spk_times, pulse_times, hw=0.25, bin_count=51):
  126. collected = []
  127. for t_pulse in pulse_times:
  128. selected = spk_times[(spk_times > t_pulse - hw) & (spk_times < t_pulse + hw)]
  129. collected += [x for x in selected - t_pulse]
  130. collected = np.array(collected)
  131. bins = np.linspace(-hw, hw, bin_count)
  132. counts, _ = np.histogram(collected, bins=bins)
  133. counts = counts / len(pulse_times) # * 1/((2. * hw)/float(bin_count - 1))
  134. counts = counts / (bins[1] - bins[0]) # divide by bin size to get firing rate
  135. return bins, counts