runme.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
from os import listdir, makedirs
from os.path import isfile, join, isdir
import warnings

import numpy as np
import pylab as pl
import scipy.optimize as opt
import seaborn as sns

from Helpers.file_helpers import data_dictionary_di
  8. experiments = [2,3,4,5,6,8,9,10,11,12,13,15,19,21,24,25]
  9. pathtofiles = '/Users/ms81/Desktop/Disco_plots'
  10. datapath = pathtofiles + '/Lightdisco_raw/'
  11. figure_path = pathtofiles + '/img/'
  12. days = []
  13. disco_flag = []
  14. culture_flag = []
  15. cvs = []
  16. info = []
  17. rates = []
  18. waveforms = []
  19. for ex in experiments:
  20. print('Collecting...' + str(ex))
  21. di = data_dictionary_di(ex, verbose = 'True', infofile = '/Users/ms81/Desktop/Disco_plots/data.txt')
  22. age = di[ex]['DIV']
  23. discoflag = di[ex]['disco']
  24. good_channels_file = di[ex]['channels']
  25. culture = di[ex]['culture']
  26. saveresults = datapath + 'rawdata_' + str(ex)
  27. a = np.load(saveresults + '.npy', encoding='latin1')
  28. final_results = a.item()
  29. elenumber = final_results['data_dict']['elenumber']
  30. data = final_results['electrode_data']
  31. sorted_dict = final_results['sorted_dict']
  32. mean_waveforms = sorted_dict['mean_waveforms_good']
  33. for ele in range(len(elenumber)):
  34. if elenumber[ele] in good_channels_file:
  35. spikes = data['Electrode_' + str(elenumber[ele])]
  36. mu = np.mean(np.diff(spikes))
  37. sigma = np.std(np.diff(spikes))
  38. cv = sigma/mu
  39. rate = len(spikes)/np.max(spikes) #number of spike, until the last one
  40. print( str(elenumber[ele]) + ' ' + str(cv) + ' ' + str(rate) )
  41. cvs.append(cv)
  42. rates.append(rate)
  43. days.append(age)
  44. info.append(elenumber[ele])
  45. disco_flag.append(discoflag)
  46. culture_flag.append(culture)
  47. waveforms.append(mean_waveforms[ele])
  48. print('Collection done, now printing.')
  49. #~ sns.set_context("talk")
  50. sns.set_style("whitegrid")
  51. colors = sns.color_palette("Set3",n_colors=7)
  52. days_ref = []
  53. for s in days:
  54. if s ==12 or s == 13 or s == 15 or s == 16:
  55. days_ref.append(14)
  56. if s == 20 or s == 23:
  57. days_ref.append(21)
  58. if s == 9:
  59. days_ref.append(7)
  60. pl.figure(1, figsize = (15,5))
  61. pl.subplot(1,2,1)
  62. pl.title('CVs, Lightdisco vs. controls')
  63. sns.boxplot(x=days_ref, y=cvs, hue = disco_flag, palette=colors)
  64. pl.ylim([0,6])
  65. pl.xlabel('Age [DIV]')
  66. pl.ylabel('Coefficient of Variation')
  67. pl.subplot(1,2,2)
  68. pl.title('Firing rates, Lightdisco vs. controls')
  69. sns.boxplot(x=days_ref, y=rates, hue = disco_flag, palette="Set3")
  70. pl.ylim([0,5])
  71. pl.xlabel('Age [DIV]')
  72. pl.ylabel('Rates')
  73. pl.savefig(figure_path + 'fig1.pdf', bbox_inches='tight')
  74. pl.close()
  75. pl.figure(2, figsize = (15,5))
  76. pl.subplot(1,2,1)
  77. for i in range(len(days)):
  78. if disco_flag[i] == 'yes':
  79. pl.plot(days[i],cvs[i],'.', markerfacecolor = colors[int(culture_flag[i][0])])
  80. else:
  81. pl.plot(days[i]+.2,cvs[i],'b.')
  82. #~ pl.text(days[i]+.5,cvs[i],info[i])
  83. pl.xlabel('DIV')
  84. pl.ylabel('Coefficient of Variation')
  85. pl.ylim([0,6])
  86. pl.subplot(1,2,2)
  87. for i in range(len(days)):
  88. if disco_flag[i] == 'yes':
  89. pl.plot(days[i],rates[i],'.', markerfacecolor = colors[int(culture_flag[i][0])])
  90. else:
  91. pl.plot(days[i]+.2,rates[i],'b.')
  92. #~ pl.text(days[i]+.5,cvs[i],info[i])
  93. pl.xlabel('DIV')
  94. pl.ylabel('Firing Rate')
  95. pl.ylim([0,5])
  96. pl.savefig(figure_path + 'fig2.pdf', bbox_inches='tight')
  97. pl.close()
  98. pl.figure(3, figsize = (15,5))
  99. for i in range(len(days)):
  100. if disco_flag[i] == 'yes':
  101. pl.plot(cvs[i],rates[i],'r.')
  102. else:
  103. pl.plot(cvs[i],rates[i],'b.')
  104. pl.xlabel('CVs')
  105. pl.ylabel('Rate')
  106. pl.xlim([0,5])
  107. pl.ylim([0,15])
  108. pl.savefig(figure_path + 'fig3.pdf', bbox_inches='tight')
  109. pl.close()
  110. #This is the waveform figure
  111. def make_gaussian(xx, x0, stdx, a):
  112. xx = xx - x0
  113. gauss = a*np.exp(-(xx*xx/(2.*stdx*stdx)))
  114. return gauss
  115. ref = []
  116. fit = []
  117. fit_error = []
  118. discoshifted = []
  119. controlshifted = []
  120. c1, c2 = sns.color_palette("Set3", 2)
  121. pl.figure(4, figsize = (8,10))
  122. xx = np.linspace(0, 6, 150) #0ms to 6 ms in 40mus steps
  123. for i in range(len(rates)):
  124. wv = waveforms[i]
  125. normalized = wv*1e6 #-wv/np.min(wv)
  126. minwhere_ms = np.argmin(normalized)*6./150
  127. popt, pcov = opt.curve_fit(make_gaussian, xx, wv, p0=(2,0.2,-1e-4))
  128. perr = np.sqrt(np.diag(pcov))
  129. gaussian_fitted = make_gaussian(xx, *popt)
  130. fit.append(popt)
  131. fit_error.append(perr)
  132. if disco_flag[i] == 'yes':
  133. pl.subplot(3,2,1)
  134. pl.plot(xx-minwhere_ms, normalized, color = c1, alpha = .1)
  135. pl.ylim([-130,70])
  136. ref.append(0)
  137. wvsh = np.array([0 for i in range(300)])
  138. start = int(minwhere_ms*150/6.)
  139. wvsh[start:(150+start)] = normalized
  140. discoshifted.append(wvsh)
  141. else:
  142. pl.subplot(3,2,2)
  143. pl.plot(xx-minwhere_ms, normalized, color = c2, alpha = .05)
  144. pl.ylim([-130,70])
  145. ref.append(1)
  146. wvsh = np.array([0 for i in range(300)])
  147. start = int(minwhere_ms*150/6.)
  148. wvsh[start:(150+start)] = normalized
  149. controlshifted.append(wvsh)
  150. fit = np.array(fit)
  151. pl.subplot(3,2,3)
  152. sns.boxplot(x=ref, y=np.abs(fit[:,1]), palette="Set3")
  153. pl.ylabel("width [ms]")
  154. pl.ylim([0,0.35])
  155. pl.subplot(3,2,4)
  156. sns.boxplot(x=ref, y=np.abs(fit[:,2])*1e6, palette="Set3")
  157. pl.ylabel("Amplitude [muV]")
  158. pl.ylim([0,160])
  159. pl.subplot(3,2,5)
  160. pl.fill_between(x=np.arange(300)*6./150, y1=np.mean(discoshifted,axis=0) - np.std(discoshifted,axis=0), y2=np.mean(discoshifted,axis=0) + np.std(discoshifted,axis=0),alpha=0.5,facecolor = c1)
  161. pl.plot(np.arange(300)*6./150,np.mean(discoshifted,axis=0),color = c1)
  162. pl.fill_between(x=np.arange(300)*6./150, y1=np.mean(controlshifted,axis=0) - np.std(controlshifted,axis=0), y2=np.mean(controlshifted,axis=0) + np.std(controlshifted,axis=0), alpha=0.5, facecolor = c2)
  163. pl.plot(np.arange(300)*6./150,np.mean(controlshifted,axis=0), color = c2)
  164. pl.xlim([2,8])
  165. pl.subplot(3,2,6)
  166. controldistro = np.mean(controlshifted,axis=0)
  167. dt = 6/150.
  168. ref = []
  169. dd = []
  170. for i in range(len(rates)):
  171. wv = waveforms[i]
  172. normalized = wv*1e6 #-wv/np.min(wv)
  173. minwhere_ms = np.argmin(normalized)*6./150
  174. if disco_flag[i] == 'yes':
  175. ref.append(0)
  176. wvsh = np.array([0 for i in range(300)])
  177. start = int(minwhere_ms*150/6.)
  178. wvsh[start:(150+start)] = normalized
  179. dd.append(np.sqrt(dt*np.sum(controldistro - wvsh)**2))
  180. else:
  181. ref.append(1)
  182. wvsh = np.array([0 for i in range(300)])
  183. start = int(minwhere_ms*150/6.)
  184. wvsh[start:(150+start)] = normalized
  185. dd.append(np.sqrt(dt*np.sum(controldistro - wvsh)**2))
  186. sns.boxplot(x=ref, y=dd, palette="Set3")
  187. pl.ylabel("Deviation from control mean")
  188. pl.savefig(figure_path + 'fig4.pdf', bbox_inches='tight')
  189. pl.close()