# show_feedback_data.py
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import yaml
  4. import aux
  5. from helpers import data_management as dm
  6. import modules.classifier as clf
  7. import analytics
  8. import os
  9. import itertools
  10. from collections.abc import Iterable
  11. params = aux.load_config(force_reload=True)
  12. if params.speller.type == 'exploration':
  13. data_tot, tt, triggers_tot, ch_rec_list, file_names = dm.get_raw(n_triggers=params.classifier.n_classes,
  14. exploration=True, trigger_pos=params.classifier.trigger_pos)
  15. elif params.speller.type == 'feedback':
  16. data_tot, tt, triggers_tot, ch_rec_list, file_names = dm.get_raw(n_triggers=params.classifier.n_classes,
  17. feedback=True, trigger_pos=params.classifier.trigger_pos)
  18. else:
  19. data_tot, tt, triggers_tot, ch_rec_list, file_names = dm.get_raw(n_triggers=params.classifier.n_classes,
  20. trigger_pos=params.classifier.trigger_pos)
  21. if params.classifier.trigger_pos == 'start':
  22. psth_win = [-2.0, 3.0] # in seconds!
  23. else:
  24. psth_win = [-5.0, 3.0] # in seconds!
  25. psth_win = np.array(psth_win) * 1000.0 / params.daq.spike_rates.loop_interval
  26. psth_win = psth_win.astype(int)
  27. for fids in range(data_tot.shape[0]):
  28. data = data_tot[fids,0]
  29. triggers = triggers_tot[fids,0][0]
  30. psth_down = np.zeros((triggers_tot[0,1].shape[1], np.diff(psth_win)[0], data_tot[0,0].shape[1])) # state x #stimuli x psth_window x #channels
  31. psth_up = np.zeros((triggers_tot[0,0].shape[1], np.diff(psth_win)[0], data_tot[0,0].shape[1])) # state x #stimuli x psth_window x #channels
  32. for ii,tr_id in enumerate(triggers_tot[0,1][0]):
  33. if tr_id+(psth_win[0])>=0 and tr_id+(psth_win[1])<=data.shape[0]:
  34. psth_down[ii,:,:] = data[tr_id+psth_win[0]:tr_id+psth_win[1],:]
  35. print(ii,tr_id)
  36. for ii,tr_id in enumerate(triggers_tot[0,0][0]):
  37. if tr_id+(psth_win[0])>=0 and tr_id+(psth_win[1])<=data.shape[0]:
  38. psth_up[ii,:,:] = data[tr_id+psth_win[0]:tr_id+psth_win[1],:]
  39. print(ii,tr_id)
  40. psth_xx = np.arange(psth_win[0],psth_win[1]) / 20.
  41. col = ['C0', 'C1']
  42. plt.figure(1)
  43. plt.clf()
  44. ### Calculate normalized firing rates for the PSTHs, according to configuration
  45. norm_rate = {}
  46. norm_rate['ch_ids'] = np.asarray([ch.id for ch in params.daq.normalization.channels])
  47. norm_rate['bottoms'] = np.asarray([ch.bottom for ch in params.daq.normalization.channels])
  48. norm_rate['tops'] = np.asarray([ch.top for ch in params.daq.normalization.channels])
  49. norm_rate['invs'] = [ch.invert for ch in params.daq.normalization.channels]
  50. clamped_rates_down = np.maximum(np.minimum(psth_down[:, :, norm_rate['ch_ids']], norm_rate['tops']), norm_rate['bottoms'])
  51. clamped_rates_up = np.maximum(np.minimum(psth_up[:, :, norm_rate['ch_ids']], norm_rate['tops']), norm_rate['bottoms'])
  52. norm_rates_down = (clamped_rates_down - norm_rate['bottoms']) / (norm_rate['tops'] - norm_rate['bottoms'])
  53. norm_rates_up = (clamped_rates_up - norm_rate['bottoms']) / (norm_rate['tops'] - norm_rate['bottoms'])
  54. norm_rates_up[:,:,norm_rate['invs']] = 1 - norm_rates_up[:,:,norm_rate['invs']]
  55. norm_rates_down[:,:,norm_rate['invs']] = 1 - norm_rates_down[:,:,norm_rate['invs']]
  56. ### Calculate firing rate average across channels and then normalize (to simulate 'use_all_channels' option)
  57. achcfg = params.daq.normalization.all_channels
  58. all_ch_rates_down = np.maximum(np.minimum(np.squeeze(np.nanmean(psth_down, axis=2)), achcfg.top), achcfg.bottom)
  59. all_ch_rates_up = np.maximum(np.minimum(np.squeeze(np.nanmean(psth_up, axis=2)), achcfg.top), achcfg.bottom)
  60. all_ch_rates_down = (all_ch_rates_down - achcfg.bottom) / (achcfg.top - achcfg.bottom)
  61. all_ch_rates_up = (all_ch_rates_up - achcfg.bottom) / (achcfg.top - achcfg.bottom)
  62. if achcfg.invert:
  63. all_ch_rates_down = 1.0 - all_ch_rates_down
  64. all_ch_rates_up = 1.0 - all_ch_rates_up
  65. # for multiple channels used for control, use each single one, the whole set, and each subset with one fewer item than the whole set
  66. n_ch = len(norm_rate['ch_ids'])
  67. ch_len_list = [i for i in [1, n_ch - 1, n_ch] if i != 0]
  68. ix_list = list(itertools.chain.from_iterable(itertools.combinations(range(len(norm_rate['ch_ids'])), i) for i in set(ch_len_list) ))
  69. for ch_id in [*range(128), *ix_list, 'all'] :
  70. if ch_id == 'all':
  71. psth_down_agg = all_ch_rates_down
  72. psth_up_agg = all_ch_rates_up
  73. filter_min_rate = False
  74. elif isinstance(ch_id, Iterable):
  75. psth_down_agg = np.squeeze(np.mean(norm_rates_down[:,:,ch_id], axis=2))
  76. psth_up_agg = np.squeeze(np.mean(norm_rates_up[:,:,ch_id], axis=2))
  77. filter_min_rate = False
  78. ch_id = np.sort(norm_rate['ch_ids'][np.array(ch_id)])
  79. else:
  80. psth_down_agg = psth_down[:, :, ch_id]
  81. psth_up_agg = psth_up[:, :, ch_id]
  82. filter_min_rate = params.plot.filter_min_rate
  83. ymin = 0
  84. ymax = max(psth_down_agg.max(), psth_up_agg.max())
  85. mu1 = np.mean(psth_down_agg, axis=0)
  86. md1 = np.median(psth_down_agg, axis=0)
  87. mu2 = np.mean(psth_up_agg, axis=0)
  88. md2 = np.median(psth_up_agg, axis=0)
  89. # if the mean firing rate is less than 4 Hz in all conditions and at all times, the channel
  90. # is probably not interesting
  91. if filter_min_rate and max(mu1.max(),mu2.max()) < filter_min_rate:
  92. continue
  93. plt.figure(1)
  94. plt.clf()
  95. plt.subplot(221)
  96. plt.plot(psth_xx, psth_down_agg.T, 'C0', alpha=0.5)
  97. plt.plot(psth_xx, mu1.T, color='k', alpha=0.8)
  98. plt.plot(psth_xx, md1.T, '--', color='k', alpha=0.8)
  99. # plt.plot(psth_xx, np.median(mu1, axis=0), color=col[ii],lw=2)
  100. # plt.ylim(0,8)
  101. plt.ylabel('sp/sec')
  102. plt.ylim(ymin, ymax)
  103. plt.title(f'down, n={psth_down_agg.shape[0]}')
  104. plt.subplot(222)
  105. plt.plot(psth_xx, psth_up_agg.T, 'C1', alpha=0.5)
  106. plt.plot(psth_xx, mu2.T, color='k', alpha=0.8)
  107. plt.plot(psth_xx, md2.T, '--', color='k', alpha=0.8)
  108. plt.title(f'up, n={psth_up_agg.shape[0]}')
  109. plt.ylim(ymin, ymax)
  110. plt.subplot(223)
  111. plt.plot(psth_xx, psth_down_agg.T, 'C0', alpha=0.5)
  112. plt.plot(psth_xx, psth_up_agg.T, 'C1', alpha=0.5)
  113. plt.plot(psth_xx, mu1.T, color='C0', alpha=1., lw=2)
  114. plt.plot(psth_xx, md1.T, '--', color='C0', alpha=0.8, lw=2)
  115. plt.plot(psth_xx, mu2.T, color='C1', alpha=1., lw=2)
  116. plt.plot(psth_xx, md2.T, '--', color='C1', alpha=0.8, lw=2)
  117. plt.ylim(ymin, ymax)
  118. plt.ylabel('sp/sec')
  119. plt.xlabel('sec')
  120. plt.subplot(224)
  121. # plt.plot(psth_xx, psth_down[:, :, ch_id].T, 'C0', alpha=0.5)
  122. # plt.plot(psth_xx, psth_up[:, :, ch_id].T, 'C1', alpha=0.5)
  123. plt.plot(psth_xx, mu1.T, color='C0', alpha=1., lw=2)
  124. plt.plot(psth_xx, md1.T, '--', color='C0', alpha=0.8, lw=2)
  125. plt.plot(psth_xx, mu2.T, color='C1', alpha=1., lw=2)
  126. plt.plot(psth_xx, md2.T, '--', color='C1', alpha=0.8, lw=2)
  127. plt.ylim(ymin, ymax)
  128. plt.xlabel('sec')
  129. fname = os.path.basename(file_names[0]).split('.')[0]
  130. os.makedirs(f'{params.file_handling.results}/{fname}', exist_ok=True)
  131. # plt.savefig(f'/media/vlachos/bck_disk1/kiap/recordings/fr/results/neurofeedback/{fname}_nf_{ch_id}.png')
  132. fname2 = f'{params.file_handling.results}/{fname}/{fname}_nf_{ch_id}.png'
  133. print(fname2)
  134. plt.savefig(fname2)
  135. # plt.plot(psth_xx, np.median(mu1, axis=0), color=col[ii],lw=2)
  136. # plt.ylim(0,8)
  137. # input()
  138. # plt.draw()
  139. # plt.show()
  140. plt.show()