performance.py

import os
import json

import h5py
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
def when_successful(traj, x_isl, y_isl, r_isl, t_sit):
    """
    traj - t, x, y - an Nx3 matrix of position data, equally sampled!

    Returns the time at which the animal has stayed inside the island
    (center x_isl, y_isl, radius r_isl) for t_sit seconds, or None if it
    never did within this trajectory.
    """
    # samples inside the island
    splits = np.where((traj[:, 1] - x_isl)**2 + (traj[:, 2] - y_isl)**2 < r_isl**2)[0]
    # gaps between in-island samples mark the starts of separate visits
    df = np.where(np.diff(splits) > 5)[0]

    if len(splits) > 0:
        # [start, end] index pairs (into splits) for each visit to the island
        periods = [[0, df[0] if len(df) > 0 else len(splits) - 1]]
        if len(df) > 1:
            for point in df[1:]:
                periods.append([periods[-1][1] + 1, point])
        if len(df) > 0:
            periods.append([periods[-1][1] + 1, len(splits) - 1])

        for period in periods:
            if traj[splits[period[1]]][0] - traj[splits[period[0]] - 5][0] > t_sit:  # -5 is a hack
                return traj[splits[period[0]]][0] + t_sit
    return None
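
# Illustrative sketch (not part of the original pipeline): calling when_successful
# on a synthetic trajectory. The sampling rate, island size and sit time below are
# made up for illustration only.
def _demo_when_successful():
    t = np.arange(0, 10, 0.1)            # 10 s of positions sampled at 10 Hz
    x = np.where(t < 1.0, 1.0, 0.0)      # outside the island until t = 1 s, then at the center
    y = np.zeros_like(t)
    traj = np.column_stack([t, x, y])
    # with a 0.1 m island at the origin and t_sit = 2 s this returns ~3.0 s
    return when_successful(traj, 0.0, 0.0, 0.1, 2.0)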
def calculate_performance(tl, trial_idxs, cfg):
    """
    Returns a matrix of time_bins x metrics, usually (12 x 7), of
    performance_median, performance_lower_CI, performance_upper_CI,
    chance_median, chance_lower_CI, chance_upper_CI, time
    """
    arena_r = cfg['position']['floor_r_in_meters']
    target_r = cfg['experiment']['target_radius']
    t_sit = cfg['experiment']['target_duration']
    timepoints = cfg['experiment']['timepoints']
    s_duration = cfg['experiment']['session_duration']

    trial_time = tl[trial_idxs[:, 1].astype(np.int32)][:, 0] - tl[trial_idxs[:, 0].astype(np.int32)][:, 0]
    correct_trial = (trial_idxs[:, 5] == 1)

    time_bin_length = 5  # in secs
    N_time_slot = int(cfg['experiment']['trial_duration'] / time_bin_length)  # usually 12 bins
    time_x_2plot = (np.arange(N_time_slot) + 1) * time_bin_length

    # number of trials that ended before each time bin
    amount_trials = len(trial_time)
    amount_correct = np.zeros(N_time_slot, dtype=np.int32)
    for i, t_bin in enumerate(time_x_2plot):
        amount_correct[i] = len(np.where(trial_time < t_bin)[0])
    proportion_correct = amount_correct / amount_trials

    # bootstrapping real trials
    bs_count = 1000
    proportion_correct_bs = np.zeros((bs_count, N_time_slot))
    confidence_interval_real = np.zeros((2, N_time_slot))
    for i in range(N_time_slot):
        for bs in range(bs_count):
            temp_index = np.random.randint(0, amount_trials, amount_trials)
            temp_correct = np.zeros(amount_trials)
            temp_correct[:amount_correct[i]] = 1
            proportion_correct_bs[bs, i] = temp_correct[temp_index].sum() / float(amount_trials)
        confidence_interval_real[0, i] = np.percentile(proportion_correct_bs[:, i], 84.1) - np.median(proportion_correct_bs[:, i])
        confidence_interval_real[1, i] = np.percentile(proportion_correct_bs[:, i], 15.9) - np.median(proportion_correct_bs[:, i])

    # creating a list of fake islands that do not overlap with the target islands
    no_fake_islands = 1000
    fake_island_centers_x = np.empty((no_fake_islands, amount_trials))
    fake_island_centers_y = np.empty((no_fake_islands, amount_trials))
    fake_island_centers_x[:] = np.nan
    fake_island_centers_y[:] = np.nan
    for i in range(amount_trials):
        X_target, Y_target = trial_idxs[i][2], trial_idxs[i][3]
        count = 0
        while np.isnan(fake_island_centers_x[:, i]).any():
            # sample a point uniformly within the arena
            angle = 2 * np.pi * np.random.rand()
            r = arena_r * np.sqrt(np.random.rand())
            x_temp = r * np.cos(angle)  # add center of the arena if not centered
            y_temp = r * np.sin(angle)  # add center of the arena if not centered
            # keep it if it is far enough from the target and fully inside the arena
            if np.sqrt((x_temp - X_target)**2 + (y_temp - Y_target)**2) > 2 * target_r and \
               np.sqrt(x_temp**2 + y_temp**2) < arena_r - target_r:
                fake_island_centers_x[count, i] = x_temp
                fake_island_centers_y[count, i] = y_temp
                count += 1

    # surrogate islands ready, now calculate the chance performance
    surrogate_correct = np.zeros((no_fake_islands, amount_trials))
    pos_downsample = 10  # think about reducing
    for trial in range(amount_trials):
        temp_index = np.arange(trial_idxs[trial][0], trial_idxs[trial][1], pos_downsample).astype(np.int32)
        temp_traj = tl[temp_index]
        temp_traj[:, 0] -= temp_traj[0][0]  # time relative to trial start
        for surr in range(no_fake_islands):
            x_fake, y_fake = fake_island_centers_x[surr, trial], fake_island_centers_y[surr, trial]
            fake_island_time_finish = when_successful(temp_traj, x_fake, y_fake, target_r, t_sit)
            if fake_island_time_finish is not None:
                surrogate_correct[surr, trial] = fake_island_time_finish

    # build the same cumulative curve as for the real trials, but for surrogate_correct
    surr_for_deleting = np.array(surrogate_correct)
    proportion_correct_bs_fake = np.zeros((bs_count, N_time_slot))
    confidence_interval_bs_fake = np.zeros((2, N_time_slot))
    for time_slot in range(N_time_slot):
        # for real trials shorter than this time bin, drop surrogates that were
        # never reached or reached only after the real trial had already ended
        fake_trials_to_remove = np.where(trial_time < (time_slot + 1) * time_bin_length)[0]
        for trial in fake_trials_to_remove:
            idxs = np.where((surrogate_correct[:, trial] == 0) | (surrogate_correct[:, trial] > trial_time[trial]))[0]
            for idx in idxs:
                surr_for_deleting[idx, trial] = np.nan
        scwr = surr_for_deleting.flatten()
        scwr = scwr[~np.isnan(scwr)]
        for bs in range(bs_count):
            temp_index = np.random.randint(0, len(scwr), amount_trials)
            temp_correct = np.logical_and(scwr[temp_index] < (time_slot + 1) * time_bin_length, scwr[temp_index] > 0)
            proportion_correct_bs_fake[bs, time_slot] = temp_correct.sum() / float(amount_trials)
        confidence_interval_bs_fake[0, time_slot] = np.percentile(proportion_correct_bs_fake[:, time_slot], 84.1) - np.median(proportion_correct_bs_fake[:, time_slot])
        confidence_interval_bs_fake[1, time_slot] = np.percentile(proportion_correct_bs_fake[:, time_slot], 15.9) - np.median(proportion_correct_bs_fake[:, time_slot])

    # compute performance metrics (in percent)
    c_median = 100 * np.median(proportion_correct_bs_fake, axis=0)
    c_lower_CI = 100 * confidence_interval_bs_fake[1]
    c_upper_CI = 100 * confidence_interval_bs_fake[0]
    p_median = 100 * np.median(proportion_correct_bs, axis=0)
    p_lower_CI = 100 * confidence_interval_real[1]
    p_upper_CI = 100 * confidence_interval_real[0]

    return np.column_stack([p_median, p_lower_CI, p_upper_CI, c_median, c_lower_CI, c_upper_CI, time_x_2plot])
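
# Illustrative sketch (an assumption, mirroring the loading code in
# get_finish_times() below): read a session's timeline, trial indices and config,
# then compute and return the performance/chance matrix.
def _demo_calculate_performance(session_path):
    session = os.path.basename(os.path.normpath(session_path))
    with open(os.path.join(session_path, session + '.json'), 'r') as f:
        cfg = json.load(f)
    with h5py.File(os.path.join(session_path, session + '.h5'), 'r') as f:
        tl = np.array(f['processed']['timeline'])
        trial_idxs = np.array(f['processed']['trial_idxs'])
    return calculate_performance(tl, trial_idxs, cfg)  # N_time_slot x 7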
def dump_performance_to_H5(h5name, ds_name, dataset):
    with h5py.File(h5name, 'a') as f:
        if 'analysis' not in f:
            f.create_group('analysis')
        anal_group = f['analysis']
        if ds_name in anal_group:  # overwrite an existing dataset with the same name
            del anal_group[ds_name]
        anal_group.create_dataset(ds_name, data=dataset)
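
# Illustrative sketch: store a performance matrix under /analysis and read it back.
# The dataset name 'performance' is an arbitrary example, not prescribed by this module.
def _demo_dump_performance(h5name, performance_matrix):
    dump_performance_to_H5(h5name, 'performance', performance_matrix)
    with h5py.File(h5name, 'r') as f:
        return np.array(f['analysis']['performance'])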
def get_finish_times(session_path):
    """
    For every trial, compute the time at which each island (target and
    distractors) was occupied for the required duration; 0 if it never was.
    """
    # loading session
    session = os.path.basename(os.path.normpath(session_path))
    h5name = os.path.join(session_path, session + '.h5')
    jsname = os.path.join(session_path, session + '.json')

    with open(jsname, 'r') as f:
        cfg = json.load(f)

    with h5py.File(h5name, 'r') as f:
        tl = np.array(f['processed']['timeline'])            # time, X, Y, speed
        trial_idxs = np.array(f['processed']['trial_idxs'])  # idx start, idx end, X, Y, R, trial result (idx to tl)
        islands = np.array(f['raw']['islands'])              # tgt_x, tgt_y, tgt_r, d1_x, etc.

    # compute finish times
    finish_times = np.zeros((islands.shape[0], int(islands.shape[1] / 3)))
    for i, trial in enumerate(trial_idxs):
        traj = tl[int(trial[0]):int(trial[1])]
        current = islands[i].reshape((3, 3))  # one (x, y, r) row per island
        for j, island in enumerate(current):
            finish_time = when_successful(traj, island[0], island[1], island[2], cfg['experiment']['target_duration'])
            if finish_time is not None:
                finish_times[i][j] = round(finish_time, 2)
    return finish_times
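
# Illustrative sketch: finish times for one session. Rows are trials and columns
# are islands in the /raw/islands order (target first, then distractors); a zero
# means that island was never occupied long enough. The path is hypothetical.
def _demo_get_finish_times():
    finish_times = get_finish_times('/data/sessions/2021-01-01_01-session')  # hypothetical path
    print(finish_times.shape)  # (number_of_trials, number_of_islands)
    return finish_times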
def get_finish_times_rates(finish_times):
    """
    For each island, count the trials in which it was the first island to be
    finished; the last entry counts trials in which no island was finished.
    """
    isl_no = finish_times.shape[1]
    rates = np.zeros((isl_no + 1,))  # last one is when no island was successful
    for i in range(isl_no):
        isl_idx = isl_no - i - 1
        successful = finish_times[finish_times[:, isl_idx] > 0]
        count = 0
        for succ_trial in successful:
            # islands of this trial that were finished before the current one
            finished_earlier = [x for x in succ_trial if x > 0 and x < succ_trial[isl_idx]]
            if len(finished_earlier) == 0:
                count += 1
        rates[isl_idx] = count
    # add fully unsuccessful trials
    rates[-1] = len([x for x in finish_times if x.sum() == 0])
    return rates
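
# Illustrative sketch with made-up numbers: three trials, three islands each.
# Trial 0 finishes the target (column 0) first, trial 1 finishes distractor 1
# before the target, and trial 2 finishes nothing.
def _demo_get_finish_times_rates():
    finish_times = np.array([
        [12.5,  0.0, 0.0],   # target finished first
        [30.1, 15.2, 0.0],   # distractor 1 finished before the target
        [ 0.0,  0.0, 0.0],   # no island finished
    ])
    return get_finish_times_rates(finish_times)  # -> array([1., 1., 0., 1.])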