Scheduled service maintenance on November 22

On Friday, November 22, 2024, between 06:00 CET and 18:00 CET, GIN services will undergo planned maintenance. Extended service interruptions should be expected. We will try to keep downtime to a minimum, but we recommend that users avoid critical tasks, large data uploads, or DOI requests during this time.

We apologize for any inconvenience.

8.1 KB

import argparse
import os

import numpy as np
import pandas as pd
from sklearn.model_selection import RepeatedStratifiedKFold, StratifiedKFold, cross_val_score
from sklearn.svm import SVC
from tqdm import tqdm

from parameters import DATAPATH, MINRATE, NIMFCYCLES, NSPIKES
from phase_tuning import HHT
from util import (angle_subtract, circmean, filter_units, get_psth, get_responses,
                  get_trials, load_data)
  11. if __name__ == "__main__":
  12. parser = argparse.ArgumentParser()
  13. parser.add_argument('e_name')
  14. args = parser.parse_args()
  15. NSPLITS = 5
  16. NPHASEBINS = 2
  17. PHASEBINS = np.linspace(-np.pi, np.pi, NPHASEBINS + 1)
  18. df_pupil = load_data('pupil', [args.e_name])
  19. df_trials = load_data('trials', [args.e_name])
  20. df_trials.rename(columns={'trial_on_time':'trial_on_times', 'trial_off_time':'trial_off_times'}, inplace=True)
  21. df_trials = df_trials.apply(get_trials, stim_id=0, axis='columns')
  22. df = pd.merge(df_pupil, df_trials).set_index(['m', 's', 'e'])
  23. df_spikes = load_data('spikes', [args.e_name]).set_index(['m', 's', 'e'])
  24. #df_spikes = df_spikes[df_spikes.index.isin(df_pupil.index)]
  25. df_spikes = filter_units(df_spikes, MINRATE)
  26. df_tuning = load_data('phasetuning', [args.e_name], tranges='noopto').set_index(['m', 's', 'e', 'u'])
  27. seriess = []
  28. for idx, row in tqdm(df.iterrows(), total=len(df)):
  29. pupil_area = row['pupil_area']
  30. pupil_tpts = row['pupil_tpts']
  31. pupil_fs = 1 / np.diff(pupil_tpts).mean()
  32. # Get IMFs
  33. hht = HHT(pupil_area, pupil_fs)
  34. hht.emd()
  35. hht.hsa()
  36. hht.check_number_of_phasebin_visits(ncycles=NIMFCYCLES, remove_invalid=True)
  37. imf_freqs = hht.characteristic_frequency
  38. imf_phases = hht.phase.T
  39. trial_starts = row['trial_on_times']
  40. #trial_stops = row['trial_off_times']
  41. #trial_duration = (trial_stops - trial_starts).mean()
  42. trial_duration = 5
  43. trial_starts = trial_starts[(trial_starts + trial_duration) < pupil_tpts.max()]
  44. stim_labels = np.repeat(np.arange(NSPLITS), len(trial_starts))
  45. # Get units for this experiment
  46. try:
  47. df_units = df_spikes.loc[idx]
  48. df_units = df_units.reset_index().set_index(['m', 's', 'e', 'u'])
  49. except KeyError:
  50. print("Spikes missing for {}".format(idx))
  51. continue
  52. for u_idx, unit in df_units.iterrows():
  53. print(u_idx)
  54. for imf_i, phase in enumerate(imf_phases):
  55. # Skip if no phase tuning analysis was done for this unit
  56. df_unittuning = df_tuning.loc[u_idx].query('imf == %d' % (imf_i + 1))
  57. if len(df_unittuning) < 1:
  58. continue
  59. print(imf_i)
  60. data = {
  61. 'm': idx[0],
  62. 's': idx[1],
  63. 'e': idx[2],
  64. 'u': u_idx[-1],
  65. 'imf': imf_i + 1,
  66. 'freq': imf_freqs[imf_i]
  67. }
  68. phase_raster, raster_tpts = get_responses(trial_starts, phase, pupil_tpts, post=trial_duration)
  69. split_inds = np.linspace(0, len(raster_tpts), NSPLITS + 1).astype(int)
  70. # Mean phase for each stimulus segment
  71. phase_means = np.concatenate(([circmean(phase_raster[:, i0:i1], axis=1)[1] for i0, i1 in zip(split_inds[:-1], split_inds[1:])]))
  72. # Decode IMF phase using each spike type
  73. for spike_type in ['tonicspk', 'burst']:
  74. print(spike_type)
  75. # get raster for spike type and split into segments
  76. spike_times = unit['%s_times' % spike_type]
  77. if len(spike_times) <= NSPIKES:
  78. continue
  79. spike_raster, spike_tpts = get_psth(trial_starts, spike_times, post=trial_duration)
  80. split_inds = np.linspace(0, len(spike_tpts), NSPLITS + 1).astype(int)
  81. X = np.row_stack([spike_raster[:, i0:i1] for i0, i1 in zip(split_inds[:-1], split_inds[1:])])
  82. # use tuning phase to set phase bins
  83. tuning_phase = df_unittuning['%s_phase' % spike_type][0]
  84. if np.isnan(tuning_phase):
  85. continue
  86. phase_shift = angle_subtract(tuning_phase, -1 * np.pi / 2) - np.pi
  87. phase_means_shifted = angle_subtract(phase_means, phase_shift) - np.pi
  88. phase_labels = np.digitize(phase_means_shifted, bins=PHASEBINS)
  89. phase_labels = phase_labels.clip(1, NPHASEBINS) - 1
  90. if len(np.unique(phase_labels)) < 2:
  91. raise RuntimeError
  92. # predict phase bin
  93. print("decoding phase")
  94. classifier = SVC(kernel='rbf')
  95. crossval = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
  96. scores = cross_val_score(classifier, X, phase_labels, cv=crossval)
  97. data['%s_phase' % spike_type] = scores.mean()
  98. ## Decode stimulus across phase bins using all spike times
  99. spike_times = unit['spk_times']
  100. if len(spike_times) <= NSPIKES:
  101. continue
  102. spike_raster, spike_tpts = get_psth(trial_starts, spike_times, post=trial_duration)
  103. #i0, i1 = spike_tpts.searchsorted([0, trial_duration])
  104. #spike_raster = spike_raster[:, i0:i1]
  105. #spike_tpts = spike_tpts[i0:i1]
  106. split_inds = np.linspace(0, len(spike_tpts), NSPLITS + 1).astype(int)
  107. X = np.row_stack([spike_raster[:, i0:i1] for i0, i1 in zip(split_inds[:-1], split_inds[1:])])
  108. # use tonic phase to set the phase bins
  109. tuning_phase = df_unittuning['tonicspk_phase'][0]
  110. if np.isnan(tuning_phase):
  111. continue
  112. phase_shift = angle_subtract(tuning_phase, -1 * np.pi / 2) - np.pi
  113. phase_means_shifted = angle_subtract(phase_means, phase_shift) - np.pi
  114. phase_labels = np.digitize(phase_means_shifted, bins=PHASEBINS)
  115. phase_labels = phase_labels.clip(1, NPHASEBINS) - 1
  116. if ((phase_labels == 0).mean() < 0.25) or ((phase_labels == 1).mean() < 0.25):
  117. print("Phase split biased")
  118. continue
  119. # split segments based on phase bin
  120. X1 = X[phase_labels.astype(bool)]
  121. y1 = stim_labels[phase_labels.astype(bool)]
  122. if len(np.unique(y1)) < 5:
  123. raise RuntimeError
  124. X2 = X[~phase_labels.astype(bool)]
  125. y2 = stim_labels[~phase_labels.astype(bool)]
  126. if len(np.unique(y2)) < 5:
  127. raise RuntimeError
  128. # train on phase bin 2 & test
  129. print("decoding stimulus, set 2")
  130. classifier = SVC(kernel='linear').fit(X2, y2)
  131. data['stim_train2_test2'] = classifier.score(X2, y2)
  132. data['stim_train2_test1'] = classifier.score(X1, y1)
  133. # train on the second phase bin & test
  134. print("decoding stimulus, set 1")
  135. classifier = SVC(kernel='linear').fit(X1, y1)
  136. data['stim_train1_test1'] = classifier.score(X1, y1)
  137. data['stim_train1_test2'] = classifier.score(X2, y2)
  138. # random
  139. splitter = RepeatedStratifiedKFold(n_splits=2, n_repeats=5, random_state=0)
  140. shf_diffs = np.full(10, np.nan)
  141. y = stim_labels
  142. for shf_i, (train, test) in enumerate(splitter.split(X, stim_labels)):
  143. classifier = SVC(kernel='linear').fit(X[train], y[train])
  144. shf_diffs[shf_i] = classifier.score(X[test], y[test]) - classifier.score(X[train], y[train])
  145. data['stim_testshf'] = shf_diffs
  146. seriess.append(pd.Series(data=data))
  147. df_decoding = pd.DataFrame(seriess)
  148. filename = 'imfdecoding_{}_norm.pkl'.format(args.e_name)
  149. df_decoding.to_pickle(DATAPATH + filename)