# analytics2.py
# Exploratory script: per-block PCA, per-channel firing-rate histograms,
# and filtering / FFT of a single channel.
  1. import importlib
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from sklearn.decomposition.pca import PCA
  5. import aux
  6. from helpers import data_management as dm
  7. from analytics import analytics1
  8. importlib.reload(analytics1)
  9. params_plot = {'legend.fontsize': 'small',
  10. 'figure.figsize': (10, 5),
  11. 'axes.labelsize': 'small',
  12. 'axes.titlesize':'small',
  13. 'xtick.labelsize':'small',
  14. 'ytick.labelsize':'small'}
  15. plt.rcParams.update(params_plot)
  16. params = aux.load_config()
  17. data_tot, tt, triggers, ch_rec_list = dm.get_raw(n_triggers=params.classifier.n_classes)
  18. if params.classifier.max_active_ch_nr != []:
  19. params.classifier.include_channels = np.argsort(np.max(data_tot[0,0],axis=0))[-params.classifier.max_active_ch_nr:] # get channels with high firing rates
  20. if params.classifier.exclude_channels == []:
  21. exclude_channels = list(set(range(128))-set(params.classifier.include_channels))
  22. data_tot = analytics1.exclude_channels(data_tot, exclude_channels) # exclude channels from analysis
  23. if params.daq.spike_rates.correct_bl_model:
  24. data_tot = analytics1.correct_init_baseline(data_tot, idx1, params.daq.spike_rates.bl_offset) #correct baseline as well
  25. for block_id in range(data_tot.shape[0]):
  26. # get channels with highest firing rate
  27. data = data_tot[block_id,0]
  28. ch_ids = np.argsort(np.max(data,axis=0))[-20:]
  29. print(f'session: {block_id}, ch_ids with highest max firing rates: {ch_ids}')
  30. pca1 = PCA(n_components=2)
  31. pca1.fit(data)
  32. res1 = pca1.transform(data)
  33. cc1 = np.corrcoef(np.max(data, axis=0), np.abs(pca1.components_[0]))
  34. cc2 = np.corrcoef(np.max(data, axis=0), np.abs(pca1.components_[1]))
  35. print(f'block: {block_id}, corr coefs: {cc1[0,1]}, {cc2[0,1]}')
  36. # plt.plot(np.max(data,axis=0))
  37. # plt.plot(pca1.components_)
  38. # show PC1, PC2
  39. plt.figure(1)
  40. plt.clf()
  41. plt.plot(res1[:,0])
  42. plt.plot(res1[:,1])
  43. plt.draw()
  44. # show all channels
  45. plt.figure(2)
  46. plt.clf()
  47. plt.imshow(data.T, aspect='auto')
  48. plt.draw()
  49. for block_id in range(data_tot.shape[0]):
  50. plt.figure(figsize=(6,6))
  51. plt.clf()
  52. data = data_tot[block_id,0]
  53. for ii in range(data.shape[1]):
  54. plt.subplot(8,8,ii+1)
  55. plt.hist(data[:,ii],bins=range(25))
  56. plt.xlim(-1,30)
  57. # plt.ylim(0,2000)
  58. if ii==56:
  59. plt.xlabel('Sp/sec', fontsize=8)
  60. if ii !=56:
  61. plt.xticks([])
  62. plt.yticks([])
  63. plt.suptitle(f'Firing rate distributions, array: M1, block: {block_id}')
  64. plt.draw()
  65. plt.show()
  66. fs = 20.
  67. fc = [1, 0]
  68. order = 5
  69. data1 = data[:,55:56]
  70. data2,_,_ = analytics1.bw_filter(data1.T, fc, fs, order, plot=True)
  71. xx = np.arange(data1.shape[0]) / fs
  72. Y, ff = analytics1.calc_fft(data2, fs)
  73. plt.figure(3)
  74. plt.clf()
  75. plt.subplot(211)
  76. plt.plot(xx, data1[:, 0], 'C1', label='original')
  77. plt.plot(xx, data2[:, 0], 'C0', label='filtered')
  78. plt.xlabel('Time (sec)')
  79. plt.ylabel('Sp/sec')
  80. plt.legend()
  81. plt.subplot(212)
  82. plt.plot(ff, Y[0, :], 'C0', label='power\nspectrum')
  83. plt.xlim(0, 1)
  84. plt.xlabel('Frequency (Hz)')
  85. plt.legend()
  86. plt.show()