% H_generateFigure7A_Decoding1.m (11 KB)

  1. %% generate figure 7A. Decoding brain region ID without PCA
  2. % code from David Ottenheimer 04/19/2018
  3. %get the data
  4. clear all
  5. close all
  6. load('Rearlytraining_light.mat')
  7. Xaxis1=[-0.25 0.25];
  8. Ishow=find(Tm>=Xaxis1(1) & Tm<=Xaxis1(2));
  9. for i=4:13 %loop thru ACQ session
  10. DLSselection=Ses(i).Coord(:,4)==10 & Ses(i).Celltype(:,1)==1;
  11. DMSselection=Ses(i).Coord(:,4)==20 & Ses(i).Celltype(:,1)==1;
  12. DLSconcat_ACQ=[];DMSconcat_ACQ=[];
  13. DLSconcat_ACQ = cat(2,DLSconcat_ACQ,Ses(i).Ev(7).Meanz(DLSselection,1));
  14. DMSconcat_ACQ = cat(2,DMSconcat_ACQ,Ses(i).Ev(7).Meanz(DMSselection,1));
  15. for j=1:5
  16. DLSconcat_ACQ = cat(2,DLSconcat_ACQ,Ses(i).Ev(j).MeanzPRE(DLSselection,1));
  17. DLSconcat_ACQ = cat(2,DLSconcat_ACQ,Ses(i).Ev(j).Meanz(DLSselection,1));
  18. DMSconcat_ACQ = cat(2,DMSconcat_ACQ,Ses(i).Ev(j).MeanzPRE(DMSselection,1));
  19. DMSconcat_ACQ = cat(2,DMSconcat_ACQ,Ses(i).Ev(j).Meanz(DMSselection,1));
  20. end
  21. DLSconcat_ACQ = cat(2,DLSconcat_ACQ,Ses(i).Ev(6).MeanzPRE(DLSselection,1));
  22. DMSconcat_ACQ = cat(2,DMSconcat_ACQ,Ses(i).Ev(6).MeanzPRE(DMSselection,1));
  23. PSTHzDecodeACQ(i-3).DLS=DLSconcat_ACQ(~isnan(mean(DLSconcat_ACQ,2)),:);
  24. PSTHzDecodeACQ(i-3).DMS=DMSconcat_ACQ(~isnan(mean(DMSconcat_ACQ,2)),:);
  25. end
  26. clear Ses
  27. load('Rextendedtraining_light.mat');
  28. load('Celltype_extendedTraining.mat');
  29. DLSselection=Coord(:,4)==10 & Celltype(:,1)==1;
  30. DMSselection=Coord(:,4)==20 & Celltype(:,1)==1;
  31. DLSconcat_OT = Ev(7).Meanz(DLSselection,1);
  32. DMSconcat_OT = Ev(7).Meanz(DMSselection,1);
  33. for i=1:5
  34. DLSconcat_OT = cat(2,DLSconcat_OT,Ev(i).MeanzPRE(DLSselection,1));
  35. DLSconcat_OT = cat(2,DLSconcat_OT,Ev(i).Meanz(DLSselection,1));
  36. DMSconcat_OT = cat(2,DMSconcat_OT,Ev(i).MeanzPRE(DMSselection,1));
  37. DMSconcat_OT = cat(2,DMSconcat_OT,Ev(i).Meanz(DMSselection,1));
  38. end
  39. DLSconcat_OT = cat(2,DLSconcat_OT,Ev(6).MeanzPRE(DLSselection,1));
  40. DMSconcat_OT = cat(2,DMSconcat_OT,Ev(6).MeanzPRE(DMSselection,1));
  41. MeanzDecodeOT.DLS=DLSconcat_OT;
  42. MeanzDecodeOT.DMS=DMSconcat_OT;
  43. save('MeanzDecodeACQ','MeanzDecodeOT')
  44. %% ACQ PCcomponents Comparison sessions.
  45. %PCA first
  46. %first do PCA on an equal number of neurons in each region
  47. nbsession=0;
  48. for k=1:10
  49. clear rep
  50. countRep=0;
  51. PSTHzDecode.DLS=PSTHzDecodeACQ(k).DLS;
  52. PSTHzDecode.DMS=PSTHzDecodeACQ(k).DMS;
  53. if length(PSTHzDecode.DLS(:,1))<length(PSTHzDecode.DMS(:,1))
  54. PCAneurons=length(PSTHzDecode.DLS(:,1));
  55. else
  56. PCAneurons=length(PSTHzDecode.DMS(:,1));
  57. end
  58. for l=1:50 % loop added to account for selection bias in VSel and NSel
  59. %pick which neurons to use
  60. SetupDMSel=cat(1,ones(PCAneurons,1),zeros(length(PSTHzDecode.DMS(:,1))-PCAneurons,1));
  61. VSel=(SetupDMSel(randperm(length(SetupDMSel)))==1);
  62. %pick which neurons to use
  63. SetupDLSel=cat(1,ones(PCAneurons,1),zeros(length(PSTHzDecode.DLS(:,1))-PCAneurons,1));
  64. NSel=(SetupDLSel(randperm(length(SetupDLSel)))==1);
  65. rep(l).concat=cat(1,PSTHzDecode.DLS(NSel,:),PSTHzDecode.DMS(VSel,:));
  66. end
  67. %% Decoding
  68. clear Partitions SVMModel prediction actual correct DecodeRegion
  69. folds = 10;%PCAneurons; %number of times cross-validated
  70. shuffs = 1; %number of times shuffled
  71. DiscrimType= 'linear';
  72. %setup the event identity matrix for decoding
  73. DecodeRegion(1:PCAneurons,1)=1;
  74. DecodeRegion((PCAneurons+1):(PCAneurons*2),1)=2;
  75. for l=1:50
  76. countRep=countRep+1;
  77. clear CVacc CVaccSh
  78. %Setup spikes matrix for decoding
  79. DecodeMeanZ=rep(l).concat;
  80. %normal model
  81. for r = 1:folds
  82. Partitions = cvpartition(DecodeRegion,'KFold',folds);
  83. SVMModel = fitcdiscr(DecodeMeanZ(Partitions.training(r),:),DecodeRegion(Partitions.training(r)),'DiscrimType',DiscrimType);
  84. prediction = predict(SVMModel,DecodeMeanZ(Partitions.test(r),:));
  85. actual = DecodeRegion(Partitions.test(r));
  86. correct = prediction - actual;
  87. CVacc(r,1) = sum(correct==0) / length(correct);
  88. end
  89. %shuffled model
  90. for q=1:shuffs
  91. DecodeRsSh=DecodeRegion(randperm(length(DecodeRegion)));
  92. PartitionsSh = cvpartition(DecodeRsSh,'KFold',folds);
  93. for s = 1:folds
  94. SVMModelSh = fitcdiscr(DecodeMeanZ(PartitionsSh.training(s),:),DecodeRsSh(PartitionsSh.training(s)),'DiscrimType',DiscrimType);
  95. predictionSh = predict(SVMModelSh,DecodeMeanZ(PartitionsSh.test(s),:));
  96. actualSh = DecodeRsSh(PartitionsSh.test(s));
  97. correctSh = predictionSh - actualSh;
  98. CVaccSh(s,1) = sum(correctSh==0) / length(correctSh);
  99. end
  100. AccShuff(q,1) = nanmean(CVaccSh);
  101. end
  102. PCADecodeAccAA{k,1}(countRep,1)=nanmean(CVacc);
  103. PCADecodeAccShAA{k,1}(countRep,1)=nanmean(AccShuff);
  104. fprintf(['RepSelection #' num2str(countRep) '\n']);
  105. end
  106. end
  107. %% plotting
  108. %colors
  109. inh=[0 0 0.6];
  110. exc=[0.8 0 0];
  111. NumSession=[1 2 3 4 5 6 7 8 9 10];
  112. for i=1:10
  113. subplot(1,2,1);
  114. hold on;
  115. errorbar(NumSession(i),nanmean(PCADecodeAccAA{i,1}),nanstd(PCADecodeAccAA{i,1}),'o','color',exc);
  116. errorbar(NumSession(i),nanmean(PCADecodeAccShAA{i,1}),nanstd(PCADecodeAccShAA{i,1}),'o','color',inh);
  117. end
  118. axis([0 10+1 0.35 0.85]);
  119. title('Decoding of region across session');
  120. xlabel('Sessions');
  121. ylabel('Accuracy');
  122. legend({'True','Shuffled'},'location','northwest');
  123. %% Overtraining as a function of PC and number of neurons included
  124. for k=1:5
  125. countRep=0;
  126. clear rep
  127. PSTHzDecode.DLS=MeanzDecodeOT.DLS;
  128. PSTHzDecode.DMS=MeanzDecodeOT.DMS;
  129. nbneurons=[30 60 100 200 300];
  130. PCAneurons=nbneurons(k);
  131. for l=1:50 % loop added to account for selection bias in VSel and NSel
  132. SetupDMSel=cat(1,ones(PCAneurons,1),zeros(length(PSTHzDecode.DMS(:,1))-PCAneurons,1));
  133. VSel=(SetupDMSel(randperm(length(SetupDMSel)))==1);
  134. %pick which neurons to use
  135. SetupDLSel=cat(1,ones(PCAneurons,1),zeros(length(PSTHzDecode.DLS(:,1))-PCAneurons,1));
  136. NSel=(SetupDLSel(randperm(length(SetupDLSel)))==1);
  137. rep(l).concatOT=cat(1,PSTHzDecode.DLS(NSel,:),PSTHzDecode.DMS(VSel,:));
  138. end
  139. %% Decoding
  140. clear Partitions SVMModel prediction actual correct DecodeRegion
  141. folds = 10;%PCAneurons; %number of times cross-validated
  142. shuffs = 1; %number of times shuffled
  143. repetitions=20; %how many times to run the analysis
  144. DiscrimType= 'linear';
  145. %setup the event identity matrix for decoding
  146. DecodeRegion(1:PCAneurons,1)=1;
  147. DecodeRegion((PCAneurons+1):(PCAneurons*2),1)=2;
  148. for l=1:50
  149. countRep=countRep+1;
  150. clear CVacc CVaccSh
  151. %Setup spikes matrix for decoding
  152. DecodeMeanZ=rep(l).concatOT;
  153. %normal model
  154. for r = 1:folds
  155. Partitions = cvpartition(DecodeRegion,'KFold',folds);
  156. SVMModel = fitcdiscr(DecodeMeanZ(Partitions.training(r),:),DecodeRegion(Partitions.training(r)),'DiscrimType',DiscrimType);
  157. prediction = predict(SVMModel,DecodeMeanZ(Partitions.test(r),:));
  158. actual = DecodeRegion(Partitions.test(r));
  159. correct = prediction - actual;
  160. CVacc(r,1) = sum(correct==0) / length(correct);
  161. end
  162. %shuffled model
  163. for q=1:shuffs
  164. DecodeRsSh=DecodeRegion(randperm(length(DecodeRegion)));
  165. PartitionsSh = cvpartition(DecodeRsSh,'KFold',folds);
  166. for s = 1:folds
  167. SVMModelSh = fitcdiscr(DecodeMeanZ(PartitionsSh.training(s),:),DecodeRsSh(PartitionsSh.training(s)),'DiscrimType',DiscrimType);
  168. predictionSh = predict(SVMModelSh,DecodeMeanZ(PartitionsSh.test(s),:));
  169. actualSh = DecodeRsSh(PartitionsSh.test(s));
  170. correctSh = predictionSh - actualSh;
  171. CVaccSh(s,1) = sum(correctSh==0) / length(correctSh);
  172. end
  173. AccShuff(q,1) = nanmean(CVaccSh);
  174. end
  175. PCADecodeAccOT{1,k}(countRep,1)=nanmean(CVacc);
  176. PCADecodeAccShOT{1,k}(countRep,1)=nanmean(AccShuff);
  177. fprintf(['Rep #' num2str(countRep) '\n']);
  178. end
  179. end
  180. %% plotting
  181. %colors
  182. inh=[0 0 0.6];
  183. exc=[0.8 0 0];
  184. Shuffle=[];
  185. for k=1:5
  186. subplot(1,2,2);
  187. hold on;
  188. errorbar(nbneurons(k),nanmean(PCADecodeAccOT{1,k}),nanstd(PCADecodeAccOT{1,k}),'o','color',exc(1,:));
  189. errorbar(nbneurons(k),nanmean(PCADecodeAccShOT{1,k}),nanstd(PCADecodeAccShOT{1,k}),'o','color',inh(1,:));
  190. end
  191. axis([0.5 300 0.35 0.85]);
  192. title('Ext training: Decoding DMS vs DLS');
  193. xlabel('Ext Training - ensemble size');
  194. ylabel('Accuracy');
  195. %save('Decode_withoutPCA.mat','PCADecodeAccAA','PCADecodeAccOT','PCADecodeAccShAA','PCADecodeAccShOT');
  196. %% stat
  197. for i=1:10
  198. tableAcc(:,i)=cat(1,PCADecodeAccAA{i,1},PCADecodeAccShAA{i,1});
  199. Between_factor=cat(1,zeros(50,1),ones(50,1));
  200. [p,tbl]=anova1(tableAcc(:,i),Between_factor,'off');
  201. p_value(i)=p;
  202. stat(i).t=tbl;
  203. end
  204. for i=1:5
  205. tableAccOT(:,i)=cat(1,PCADecodeAccOT{1,i},PCADecodeAccShOT{1,i});
  206. [p,tbl]=anova1(tableAccOT(:,i),Between_factor,'off');
  207. pOT_value(i)=p;
  208. statOT(i).t=tbl;
  209. end
  210. tableAccACQ_OT=cat(1,PCADecodeAccAA{10,1},PCADecodeAccOT{1,1});
  211. [p_ACQ_OT,tbl_ACQ_OT]=anova1(tableAccACQ_OT,Between_factor,'off');
  212. %%
  213. for i=1:length(PCADecodeAccAA)
  214. allsamples=cat(1,PCADecodeAccAA{i,1},PCADecodeAccShAA{i,1});
  215. for p=1:10000
  216. shuffsamples=allsamples(randperm(length(allsamples)));
  217. shuffdiff(p,1)=abs(mean(shuffsamples(1:length(allsamples)/2))-mean(shuffsamples(length(allsamples)/2+1:end)));
  218. end
  219. diff=abs(nanmean(PCADecodeAccAA{i,1})-nanmean(PCADecodeAccShAA{i,1}));
  220. trainingpval(i,1)=(sum(shuffdiff>diff)+1)/(length(shuffdiff)+1);
  221. end
  222. %% overtraining
  223. for i=1:length(PCADecodeAccOT)
  224. allsamples=cat(1,PCADecodeAccOT{1,i},PCADecodeAccShOT{1,i});
  225. for p=1:10000
  226. shuffsamples=allsamples(randperm(length(allsamples)));
  227. shuffdiff(p,1)=abs(mean(shuffsamples(1:length(allsamples)/2))-mean(shuffsamples(length(allsamples)/2+1:end)));
  228. end
  229. diff=abs(nanmean(PCADecodeAccOT{1,i})-nanmean(PCADecodeAccShOT{1,i}));
  230. OTpval(1,i)=(sum(shuffdiff>diff)+1)/(length(shuffdiff)+1);
  231. end