ClassDef_AmplitudeShift_Stable.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from neo import Spike2IO
  2. import numpy as np
  3. from scipy import stats
  4. from pathlib import Path
  5. import pandas as pd
  6. from sklearn.cluster import KMeans
  7. class AmpShift_Stable():
  8. def _init_(self):
  9. print('object created')
  def load_expt(self, exptname, data_folder):
      """Read a Spike2 ``.smr`` recording and keep its first segment.

      Args:
          exptname: File name of the recording without the ``.smr`` suffix.
          data_folder: Directory that contains the recording. May be a
              ``pathlib.Path`` or anything ``Path`` accepts (e.g. a string).

      Side effects:
          Prints the path that is opened, then stores the first segment of
          the block returned by ``Spike2IO(...).read_block()`` in
          ``self.seg`` and ``exptname`` in ``self.exptname``.
      """
      # Coerce to Path so a plain string folder works as well; the original
      # ``data_folder / Path(...)`` raised TypeError for str input.
      file_to_open = Path(data_folder) / (exptname + '.smr')
      print(file_to_open)

      reader = Spike2IO(file_to_open.as_posix(), try_signal_grouping=False)
      block = reader.read_block()
      self.seg = block.segments[0]
      self.exptname = exptname
  20. def set_amps(self,n_amp,amp_array):
  21. self.n_amp = n_amp
  22. self.amp_array = amp_array
  23. def set_channels(self,trigger_chan,vm_chan,spk_chan,siu_chan,marker_chan):
  24. self.trigger_chan = trigger_chan
  25. self.vm_chan = vm_chan
  26. self.spk_chan = spk_chan
  27. self.siu_chan = siu_chan
  28. self.marker_chan = marker_chan
  29. def get_marker_table(self):
  30. markerT = np.asarray(self.seg.events[[s.name for s in self.seg.events].index(self.marker_chan)].magnitude)
  31. # markerC = np.asarray([int(s.decode("utf-8")) for s in self.seg.events[[s.name for s in self.seg.events].index(self.marker_chan)].labels]);
  32. markerC = np.asarray([int(s) for s in self.seg.events[[s.name for s in self.seg.events].index(self.marker_chan)].labels]);
  33. markerC = np.asarray([chr(int(hex(n),16)) for n in markerC])
  34. marker_df = pd.DataFrame({
  35. 'time' : markerT,
  36. 'code' : markerC,
  37. })
  38. return marker_df
  39. def get_events(self,channame):
  40. events = np.asarray(self.seg.events[[s.name for s in self.seg.events].index(channame)].magnitude)
  41. return events
  42. def get_signal(self,channame):
  43. signal = self.seg.analogsignals[[s.name for s in self.seg.analogsignals].index(channame)]
  44. return signal
  45. def get_bout_win(self,boutstr,marker_chan):
  46. codeS = np.asarray([int(s) for s in self.seg.events[[s.name for s in self.seg.events].index(marker_chan)].get_labels()])
  47. codeT = np.asarray(self.seg.events[[s.name for s in self.seg.events].index(marker_chan)].magnitude)
  48. stopstr = 'S'
  49. startcode = ord(boutstr)
  50. stopcode = ord(stopstr)
  51. alloff = codeT[codeS==stopcode]
  52. bouton = codeT[codeS == startcode]
  53. boutoff = [np.min(alloff[alloff>t]) for t in bouton]
  54. bout_win = [[o,f] for o,f in zip(bouton,boutoff)]
  55. return bout_win
  56. def filter_marker_df_time(self,marker_df,time_win):
  57. sub_df = marker_df[(marker_df['time']>time_win[0][0]) & (marker_df['time']<time_win[0][1])]
  58. if np.shape(time_win)[0]>1:
  59. for t in time_win[1:]:
  60. sub_df = pd.concat([sub_df,marker_df[(marker_df['time']>t[0]) & (marker_df['time']<t[1])]])
  61. return sub_df
  62. def filter_marker_df_code(self,marker_df,codestr):
  63. sub_df = marker_df[marker_df['code'].str.match(codestr[0])]
  64. if np.shape(codestr)[0]>1:
  65. for c in codestr[1:]:
  66. sub_df = pd.concat([sub_df,marker_df[sub_df['code'].str.match(c)]])
  67. return sub_df
  68. def filter_events(self,events,time_win):
  69. these_events = []
  70. for sublist in np.asarray([events[np.where((events>win[0])&(events<win[1]))] for win in time_win]):
  71. for item in sublist:
  72. these_events.append(item)
  73. these_events = np.asarray(these_events)
  74. return these_events
  75. def get_sweepsmat(self,signal_chan,times,sweepdur):
  76. thischan = self.seg.analogsignals[[s.name for s in self.seg.analogsignals].index(signal_chan)]
  77. sweepsmat = []
  78. v = thischan.magnitude
  79. dt = float(thischan.sampling_period)
  80. nsamp = int(sweepdur/dt)
  81. for t in times:
  82. inds = [int(t/dt),int(t/dt)+nsamp]
  83. sweepsmat.append(v[inds[0]:inds[1]].flatten())
  84. sweepsmat = np.asarray(sweepsmat).T
  85. xtime = np.linspace(0,sweepdur,int(sweepdur/dt))*1000
  86. return xtime,sweepsmat
  87. def get_dt(self,signal_chan):
  88. thischan = self.seg.analogsignals[[s.name for s in self.seg.analogsignals].index(signal_chan)]
  89. dt = float(thischan.sampling_period)
  90. return dt
  def cluster_event_Amp(self, event_Amp, event_0_Amp):
      """Cluster event amplitudes into the nominal amplitude-shift levels.

      Each amplitude is expressed as a percent change relative to
      ``event_0_Amp`` and the changes are clustered with k-means using
      ``self.n_amp`` clusters. The clusters, taken in ascending order of
      their centre, are then paired with the entries of ``self.amp_array``
      in order.

      Args:
          event_Amp: Iterable of event amplitudes.
          event_0_Amp: Reference amplitude the changes are relative to.

      Returns:
          Tuple ``(ampshift, ampshift_round)``. ``ampshift`` is an
          ``(n, 1)`` array of percent changes. ``ampshift_round`` is a list
          with, for every event, the value assigned to its cluster.
      """
      percent_change = [((amp / event_0_Amp) * 100) - 100 for amp in event_Amp]
      ampshift = np.asarray(percent_change).reshape(-1, 1)

      model = KMeans(n_clusters=self.n_amp)
      labels = model.fit_predict(ampshift)

      level_of_cluster = [centre[0] for centre in model.cluster_centers_]
      # The ranking is computed once, before any centre is overwritten.
      # NOTE(review): presumably amp_array is sorted ascending and has
      # n_amp entries; a cluster without a partner keeps its raw centre.
      ascending = np.argsort(level_of_cluster)
      for cluster_idx, nominal in zip(ascending, self.amp_array):
          level_of_cluster[cluster_idx] = nominal

      ampshift_round = [level_of_cluster[label] for label in labels]
      return ampshift, ampshift_round