H_generateFigure6CD_Decoding2.m 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  %% Generate figure 06 and S07
  % Decoding using PCA.
  %   Upper left : compare early-training sessions across numbers of PCs.
  %   Upper right: compare different ensemble sizes of extended training across PCs.
  % Code from David Ottenheimer, 04/19/2018.
  %
  % Decode region: classify individual neurons as DMS or DLS.
  %
  % This section collects the PSTHz data used by every decoding section below:
  %   PSTHzDecodeACQ(1:10) - one entry per early-training session Ses(4:13)
  %   PSTHzDecodeOT        - extended training
  % Both carry the fields .DLS and .DMS; rows are neurons, columns are the
  % selected time bins of Ev(7) followed by those of Ev(1) ... Ev(6).
  clear all
  close all

  % ---- early training (ACQ) ------------------------------------------------
  load('Rearlytraining_light.mat') % must provide Tm and Ses, both used below
  Xaxis1=[-0.25 0.25];             % window on the Tm axis kept for decoding
  Ishow=find(Tm>=Xaxis1(1) & Tm<=Xaxis1(2));
  eventOrder=[7 1:6];              % Ev(7) is concatenated first, then Ev(1:6)

  for i=4:13 %loop thru ACQ session
      % Column 4 of Coord is 10 for the DLS selection and 20 for the DMS
      % selection; only rows with Celltype(:,1)==1 are kept.
      % NOTE(review): the meaning of Celltype==1 is not visible here - confirm.
      DLSselection=Ses(i).Coord(:,4)==10 & Ses(i).Celltype(:,1)==1;
      DMSselection=Ses(i).Coord(:,4)==20 & Ses(i).Celltype(:,1)==1;

      DLSconcat_ACQ=[];DMSconcat_ACQ=[];
      for j=eventOrder
          DLSconcat_ACQ = cat(2,DLSconcat_ACQ,Ses(i).Ev(j).PSTHz(DLSselection,Ishow));
          DMSconcat_ACQ = cat(2,DMSconcat_ACQ,Ses(i).Ev(j).PSTHz(DMSselection,Ishow));
      end

      % Drop neurons whose concatenated response contains a NaN (the row mean
      % is NaN as soon as one bin is NaN).
      PSTHzDecodeACQ(i-3).DLS=DLSconcat_ACQ(~isnan(mean(DLSconcat_ACQ,2)),:);
      PSTHzDecodeACQ(i-3).DMS=DMSconcat_ACQ(~isnan(mean(DMSconcat_ACQ,2)),:);
  end
  clear Ses

  % ---- extended training (OT) ----------------------------------------------
  load('Rextendedtraining_light.mat'); % expected to provide Ev (and Coord - TODO confirm)
  load('Celltype_extendedtraining.mat');
  DLSselection=Coord(:,4)==10 & Celltype(:,1)==1;
  DMSselection=Coord(:,4)==20 & Celltype(:,1)==1;

  DLSconcat_OT=[];DMSconcat_OT=[];
  for i=eventOrder
      DLSconcat_OT = cat(2,DLSconcat_OT,Ev(i).PSTHz(DLSselection,Ishow));
      DMSconcat_OT = cat(2,DMSconcat_OT,Ev(i).PSTHz(DMSselection,Ishow));
  end
  % NOTE(review): unlike the ACQ data, NaN rows are not removed here - confirm
  % that the extended-training PSTHz contains none.
  PSTHzDecodeOT.DLS=DLSconcat_OT;
  PSTHzDecodeOT.DMS=DMSconcat_OT;

  % The first argument of save is the FILE name. The original call
  % save('PSTHzDecodeACQ','PSTHzDecodeOT') therefore wrote only PSTHzDecodeOT
  % into PSTHzDecodeACQ.mat. The file name is kept and both variables are saved.
  save('PSTHzDecodeACQ','PSTHzDecodeACQ','PSTHzDecodeOT')
  %% ACQ PCcomponents Comparison sessions.
  % For sessions 1, 4, 7 and 10 of PSTHzDecodeACQ: run PCA on an equal number
  % of DLS and DMS neurons, then decode the region of each neuron from its
  % loadings on the first 1..10 PCs with a cross-validated linear discriminant
  % (fitcdiscr). The same decoding is run on shuffled region labels as control.
  %
  % Outputs (cell arrays indexed {number-of-PCs index, session index}, each
  % holding one accuracy per repetition):
  %   PCADecodeAccACQ   - accuracy with the real labels
  %   PCADecodeAccShACQ - accuracy with shuffled labels
  %   PCAneuronsACQ     - neurons used per region in each session
  folds = 10;  %number of cross-validation folds
  shuffs = 1;  %number of times shuffled
  NumCoeffs = [1 2 3 4 5 6 7 8 9 10]; %number of PCs used in decoding
  DiscrimType= 'linear';

  nbsession=0;
  for k=1:3:10
      countRep=0;
      nbsession=nbsession+1;
      PSTHzDecode.DLS=PSTHzDecodeACQ(k).DLS;
      PSTHzDecode.DMS=PSTHzDecodeACQ(k).DMS;

      % use the size of the smaller population for both regions
      PCAneurons=min(size(PSTHzDecode.DLS,1),size(PSTHzDecode.DMS,1));
      PCAneuronsACQ(nbsession,1)=PCAneurons;

      % Region label per neuron (1 = DLS rows, 2 = DMS rows of concat).
      % Built from scratch for every session: the original only overwrote the
      % first 2*PCAneurons entries, so a session with fewer neurons than an
      % earlier one kept stale trailing labels and no longer matched coeff.
      DecodeRegion=[ones(PCAneurons,1); 2*ones(PCAneurons,1)];

      for l=1:50 % loop added to account for selection bias in VSel and NSel
          %pick which DMS neurons to use
          SetupDMSel=cat(1,ones(PCAneurons,1),zeros(length(PSTHzDecode.DMS(:,1))-PCAneurons,1));
          VSel=(SetupDMSel(randperm(length(SetupDMSel)))==1);
          %pick which DLS neurons to use
          SetupDLSel=cat(1,ones(PCAneurons,1),zeros(length(PSTHzDecode.DLS(:,1))-PCAneurons,1));
          NSel=(SetupDLSel(randperm(length(SetupDLSel)))==1);

          % After the transpose, columns are neurons, so each row of coeff
          % holds the PC loadings of one neuron.
          concat=cat(1,PSTHzDecode.DLS(NSel,:),PSTHzDecode.DMS(VSel,:))';
          [coeff,score,~,~,explained] = pca(concat);

          countRep=countRep+1;
          for i=1:length(NumCoeffs)
              DecodePCs=coeff(:,1:NumCoeffs(i));

              %normal model
              % One partition for all folds, exactly as for the shuffled model
              % below. The original re-drew the partition inside the fold
              % loop, so the test sets of different folds could overlap.
              Partitions = cvpartition(DecodeRegion,'KFold',folds);
              CVacc=nan(folds,1);
              for r = 1:folds
                  SVMModel = fitcdiscr(DecodePCs(Partitions.training(r),:),DecodeRegion(Partitions.training(r)),'DiscrimType',DiscrimType);
                  prediction = predict(SVMModel,DecodePCs(Partitions.test(r),:));
                  actual = DecodeRegion(Partitions.test(r));
                  CVacc(r,1) = sum(prediction==actual) / length(actual);
              end

              %shuffled model
              AccShuff=nan(shuffs,1);
              for q=1:shuffs
                  DecodeRsSh=DecodeRegion(randperm(length(DecodeRegion)));
                  PartitionsSh = cvpartition(DecodeRsSh,'KFold',folds);
                  CVaccSh=nan(folds,1);
                  for s = 1:folds
                      SVMModelSh = fitcdiscr(DecodePCs(PartitionsSh.training(s),:),DecodeRsSh(PartitionsSh.training(s)),'DiscrimType',DiscrimType);
                      predictionSh = predict(SVMModelSh,DecodePCs(PartitionsSh.test(s),:));
                      actualSh = DecodeRsSh(PartitionsSh.test(s));
                      CVaccSh(s,1) = sum(predictionSh==actualSh) / length(actualSh);
                  end
                  AccShuff(q,1) = nanmean(CVaccSh);
              end

              PCADecodeAccACQ{i,nbsession}(countRep,1)=nanmean(CVacc);
              PCADecodeAccShACQ{i,nbsession}(countRep,1)=nanmean(AccShuff);
          end
          fprintf(['RepACQ #' num2str(countRep) '\n']);
      end
  end

  %plotting
  %colors
  inh=[0 0.3333 1; 0 0.5 1; 0 0.67 1; 0 0.8353 1];
  exc=[1 0.1647 0; 0.9020 0.4510 0; 0.9020 0.6 0; 1 0.8353 0];
  subplot(2,8,[1 2 3 4]);
  hold on;
  for i=1:length(NumCoeffs)
      for k=1:4
          errorbar(NumCoeffs(i),nanmean(PCADecodeAccACQ{i,k}),nanstd(PCADecodeAccACQ{i,k}),'o','color',exc(k,:));
      end
      % shuffled control: only the 4th session is shown (drawn once per PC
      % count instead of once per session as before)
      errorbar(NumCoeffs(i),nanmean(PCADecodeAccShACQ{i,4}),nanstd(PCADecodeAccShACQ{i,4}),'o','color',inh(1,:));
  end
  axis([0 NumCoeffs(end)+1 0.30 0.85]);
  title('Early training: Decoding DMS vs DLS');
  xlabel('Number of PCs');
  ylabel('Accuracy');
  %% Overtraining as a function of PC and number of neurons included
  % For each ensemble size in nbneurons: draw that many DLS and DMS neurons
  % from PSTHzDecodeOT, run PCA, and decode the region of each neuron from its
  % loadings on the first 1..10 PCs with a cross-validated linear discriminant
  % (fitcdiscr). A shuffled-label control is run alongside.
  %
  % Outputs (cell arrays indexed {number-of-PCs index, ensemble-size index},
  % one accuracy per repetition):
  %   PCADecodeAccOT   - accuracy with the real labels
  %   PCADecodeAccShOT - accuracy with shuffled labels
  nbneurons=[30 60 100 200 300]; % neurons drawn per region
  folds = 10;  %number of cross-validation folds
  shuffs = 1;  %number of times shuffled
  NumCoeffs = [1 2 3 4 5 6 7 8 9 10]; %number of PCs used in decoding
  DiscrimType= 'linear';

  PSTHzDecode.DLS=PSTHzDecodeOT.DLS;
  PSTHzDecode.DMS=PSTHzDecodeOT.DMS;
  nDLSavailable=size(PSTHzDecode.DLS,1);
  nDMSavailable=size(PSTHzDecode.DMS,1);

  for k=1:5
      countRep=0;
      PCAneurons=nbneurons(k);

      % Asking for more neurons than were recorded used to fail later with an
      % obscure indexing error; report the actual problem instead.
      if PCAneurons>min(nDLSavailable,nDMSavailable)
          error('Decoding:notEnoughNeurons', ...
              'Ensemble size %d exceeds the available neurons (DLS: %d, DMS: %d).', ...
              PCAneurons,nDLSavailable,nDMSavailable);
      end

      % region label per neuron (1 = DLS rows, 2 = DMS rows of concat)
      DecodeRegion=[ones(PCAneurons,1); 2*ones(PCAneurons,1)];

      for l=1:50 % loop added to account for selection bias in VSel and NSel
          %pick which DMS neurons to use
          SetupDMSel=cat(1,ones(PCAneurons,1),zeros(nDMSavailable-PCAneurons,1));
          VSel=(SetupDMSel(randperm(length(SetupDMSel)))==1);
          %pick which DLS neurons to use
          SetupDLSel=cat(1,ones(PCAneurons,1),zeros(nDLSavailable-PCAneurons,1));
          NSel=(SetupDLSel(randperm(length(SetupDLSel)))==1);

          % After the transpose, columns are neurons, so each row of coeff
          % holds the PC loadings of one neuron.
          concat=cat(1,PSTHzDecode.DLS(NSel,:),PSTHzDecode.DMS(VSel,:))';
          [coeff,score,~,~,explained] = pca(concat);

          % Decoding
          countRep=countRep+1;
          for i=1:length(NumCoeffs)
              DecodePCs=coeff(:,1:NumCoeffs(i));

              %normal model
              % One partition for all folds, exactly as for the shuffled model
              % below. The original re-drew the partition inside the fold
              % loop, so the test sets of different folds could overlap.
              Partitions = cvpartition(DecodeRegion,'KFold',folds);
              CVacc=nan(folds,1);
              for r = 1:folds
                  SVMModel = fitcdiscr(DecodePCs(Partitions.training(r),:),DecodeRegion(Partitions.training(r)),'DiscrimType',DiscrimType);
                  prediction = predict(SVMModel,DecodePCs(Partitions.test(r),:));
                  actual = DecodeRegion(Partitions.test(r));
                  CVacc(r,1) = sum(prediction==actual) / length(actual);
              end

              %shuffled model
              AccShuff=nan(shuffs,1);
              for q=1:shuffs
                  DecodeRsSh=DecodeRegion(randperm(length(DecodeRegion)));
                  PartitionsSh = cvpartition(DecodeRsSh,'KFold',folds);
                  CVaccSh=nan(folds,1);
                  for s = 1:folds
                      SVMModelSh = fitcdiscr(DecodePCs(PartitionsSh.training(s),:),DecodeRsSh(PartitionsSh.training(s)),'DiscrimType',DiscrimType);
                      predictionSh = predict(SVMModelSh,DecodePCs(PartitionsSh.test(s),:));
                      actualSh = DecodeRsSh(PartitionsSh.test(s));
                      CVaccSh(s,1) = sum(predictionSh==actualSh) / length(actualSh);
                  end
                  AccShuff(q,1) = nanmean(CVaccSh);
              end

              PCADecodeAccOT{i,k}(countRep,1)=nanmean(CVacc);
              PCADecodeAccShOT{i,k}(countRep,1)=nanmean(AccShuff);
          end
          fprintf(['RepOT #' num2str(countRep) '\n']);
      end
  end

  %% plotting
  %colors
  inh=[0 0 0.6; 0 0.3333 1; 0 0.5 1; 0 0.67 1; 0 0.8353 1];
  exc=[0.8 0 0; 1 0.1647 0; 0.9020 0.4510 0; 0.9020 0.6 0; 1 0.8353 0];
  subplot(2,8,[5 6 7 8]);
  hold on;
  for i=1:length(NumCoeffs)
      for k=1:5
          errorbar(NumCoeffs(i),nanmean(PCADecodeAccOT{i,k}),nanstd(PCADecodeAccOT{i,k}),'o','color',exc(k,:));
      end
      % shuffled control: only the 5th ensemble size is shown (drawn once per
      % PC count instead of once per ensemble size as before)
      errorbar(NumCoeffs(i),nanmean(PCADecodeAccShOT{i,5}),nanstd(PCADecodeAccShOT{i,5}),'o','color',inh(1,:));
  end
  axis([0 NumCoeffs(end)+1 0.30 0.85]);
  title('Ext Training: Decoding DMS vs DLS');
  xlabel('Number of PCs');
  ylabel('Accuracy');
  %% Decoding DMS vs DLS across session, based on weight of 3 first PCs
  % For each of the 10 entries of PSTHzDecodeACQ: run PCA on an equal number
  % of DLS and DMS neurons and decode the region of each neuron from its
  % loadings on the first 3 PCs (cross-validated linear discriminant), with a
  % shuffled-label control.
  %
  % Outputs (10x1 cell arrays, one accuracy per repetition):
  %   PCADecodeAccAA   - accuracy with the real labels
  %   PCADecodeAccShAA - accuracy with shuffled labels
  % The second half of the section plots these results together with the
  % 3-PC extended-training results (PCADecodeAccOT / PCADecodeAccShOT and
  % nbneurons, computed by the overtraining section) and saves everything.
  folds = 10;  %number of cross-validation folds
  shuffs = 1;  %number of times shuffled
  NumCoeffs = 3; %number of PCs used in decoding
  DiscrimType= 'linear';

  for i=1:10
      countRep=0;
      PSTHzDecode.DLS=PSTHzDecodeACQ(i).DLS;
      PSTHzDecode.DMS=PSTHzDecodeACQ(i).DMS;

      % use the size of the smaller population for both regions
      PCAneurons=min(size(PSTHzDecode.DLS,1),size(PSTHzDecode.DMS,1));

      % region label per neuron (1 = DLS rows, 2 = DMS rows of concat)
      DecodeRegion=[ones(PCAneurons,1); 2*ones(PCAneurons,1)];

      for l=1:50
          %pick which DMS neurons to use
          SetupDMSel=cat(1,ones(PCAneurons,1),zeros(length(PSTHzDecode.DMS(:,1))-PCAneurons,1));
          VSel=(SetupDMSel(randperm(length(SetupDMSel)))==1);
          %pick which DLS neurons to use
          SetupDLSel=cat(1,ones(PCAneurons,1),zeros(length(PSTHzDecode.DLS(:,1))-PCAneurons,1));
          NSel=(SetupDLSel(randperm(length(SetupDLSel)))==1);

          % After the transpose, columns are neurons, so each row of coeff
          % holds the PC loadings of one neuron.
          concat=cat(1,PSTHzDecode.DLS(NSel,:),PSTHzDecode.DMS(VSel,:))';
          [coeff,score,~,~,explained] = pca(concat);

          % Decoding
          countRep=countRep+1;
          DecodePCs=coeff(:,1:NumCoeffs);

          %normal model
          % One partition for all folds, exactly as for the shuffled model
          % below. The original re-drew the partition inside the fold loop, so
          % the test sets of different folds could overlap.
          Partitions = cvpartition(DecodeRegion,'KFold',folds);
          CVacc=nan(folds,1);
          for r = 1:folds
              SVMModel = fitcdiscr(DecodePCs(Partitions.training(r),:),DecodeRegion(Partitions.training(r)),'DiscrimType',DiscrimType);
              prediction = predict(SVMModel,DecodePCs(Partitions.test(r),:));
              actual = DecodeRegion(Partitions.test(r));
              CVacc(r,1) = sum(prediction==actual) / length(actual);
          end

          %shuffled model
          AccShuff=nan(shuffs,1);
          for q=1:shuffs
              DecodeRsSh=DecodeRegion(randperm(length(DecodeRegion)));
              PartitionsSh = cvpartition(DecodeRsSh,'KFold',folds);
              CVaccSh=nan(folds,1);
              for s = 1:folds
                  SVMModelSh = fitcdiscr(DecodePCs(PartitionsSh.training(s),:),DecodeRsSh(PartitionsSh.training(s)),'DiscrimType',DiscrimType);
                  predictionSh = predict(SVMModelSh,DecodePCs(PartitionsSh.test(s),:));
                  actualSh = DecodeRsSh(PartitionsSh.test(s));
                  CVaccSh(s,1) = sum(predictionSh==actualSh) / length(actualSh);
              end
              AccShuff(q,1) = nanmean(CVaccSh);
          end

          PCADecodeAccAA{i,1}(countRep,1)=nanmean(CVacc);
          PCADecodeAccShAA{i,1}(countRep,1)=nanmean(AccShuff);
          fprintf(['RepACQses #' num2str(countRep) '\n']);
      end
  end

  %% plotting
  %colors
  inh=[0 0 0.6; 0 0.3333 1; 0 0.5 1; 0 0.67 1; 0 0.8353 1];
  exc=[0.8 0 0; 1 0.1647 0; 0.9020 0.4510 0; 0.9020 0.6 0; 1 0.8353 0];

  % accuracy per early-training session (real labels vs shuffled labels)
  NumSession=[1 2 3 4 5 6 7 8 9 10];
  subplot(2,8,[9 10 11 12]);
  hold on;
  for i=1:10
      errorbar(NumSession(i),nanmean(PCADecodeAccAA{i,1}),nanstd(PCADecodeAccAA{i,1}),'o','color',exc(1,:));
      errorbar(NumSession(i),nanmean(PCADecodeAccShAA{i,1}),nanstd(PCADecodeAccShAA{i,1}),'o','color',inh(1,:));
  end
  axis([0 10+1 0.35 0.85]);
  title('Decoding of region across session');
  xlabel('Sessions');
  ylabel('Accuracy');

  % extended training, 3 PCs (row 3 of the OT cell arrays), per ensemble size
  subplot(2,8,[13 14 15]);
  hold on;
  for k=1:5
      errorbar(nbneurons(k),nanmean(PCADecodeAccOT{3,k}),nanstd(PCADecodeAccOT{3,k}),'o','color',exc(1,:));
      errorbar(nbneurons(k),nanmean(PCADecodeAccShOT{3,k}),nanstd(PCADecodeAccShOT{3,k}),'o','color',inh(1,:));
  end
  axis([0 300 0.35 0.85]);
  xlabel('OT');
  ylabel('Accuracy');

  save('PCADecode.mat','PCADecodeAccAA','PCADecodeAccACQ','PCADecodeAccOT','PCADecodeAccShAA','PCADecodeAccShACQ','PCADecodeAccShOT');
  %% stat
  % One-way ANOVA (figure display off) of real-label vs shuffled-label decoding
  % accuracy. Group vector: 50 zeros for the real-label repetitions followed by
  % 50 ones for the shuffled ones.
  Between_factor=[zeros(50,1); ones(50,1)];

  % early training: one test per session
  for sessIdx=1:10
      tableAcc(:,sessIdx)=[PCADecodeAccAA{sessIdx,1}; PCADecodeAccShAA{sessIdx,1}];
      [p,tbl]=anova1(tableAcc(:,sessIdx),Between_factor,'off');
      p_value(sessIdx)=p;
      stat(sessIdx).t=tbl;
  end

  % extended training, 3 PCs: one test per ensemble size
  for sizeIdx=1:5
      tableAccOT(:,sizeIdx)=[PCADecodeAccOT{3,sizeIdx}; PCADecodeAccShOT{3,sizeIdx}];
      [p,tbl]=anova1(tableAccOT(:,sizeIdx),Between_factor,'off');
      pOT_value(sizeIdx)=p;
      statOT(sizeIdx).t=tbl;
  end

  % last early-training session vs extended training (first ensemble size, 3 PCs)
  tableAccACQ_OT=[PCADecodeAccAA{10,1}; PCADecodeAccOT{3,1}];
  [p_ACQ_OT,tbl_ACQ_OT]=anova1(tableAccACQ_OT,Between_factor,'off');
  %% Permutation test
  % For every early-training session, compare the observed absolute difference
  % between mean real-label and mean shuffled-label accuracy with the
  % distribution obtained by randomly reassigning the pooled accuracies to two
  % equal halves. Result: trainingpval(i,1), one p-value per session.
  nPerm=10000; % number of random relabelings
  for i=1:length(PCADecodeAccAA)
      allsamples=cat(1,PCADecodeAccAA{i,1},PCADecodeAccShAA{i,1});
      half=length(allsamples)/2; % both groups hold the same number of repetitions

      shuffdiff=zeros(nPerm,1); % preallocated instead of grown inside the loop
      for iPerm=1:nPerm
          shuffsamples=allsamples(randperm(length(allsamples)));
          shuffdiff(iPerm,1)=abs(mean(shuffsamples(1:half))-mean(shuffsamples(half+1:end)));
      end

      % Named obsDiff rather than diff: a workspace variable called diff hides
      % MATLAB's built-in diff function for the rest of the session.
      obsDiff=abs(nanmean(PCADecodeAccAA{i,1})-nanmean(PCADecodeAccShAA{i,1}));

      % +1 in numerator and denominator keeps the p-value strictly above zero.
      % NOTE(review): the comparison is strict (>); permutation p-values are
      % often defined with >= - confirm which is intended.
      trainingpval(i,1)=(sum(shuffdiff>obsDiff)+1)/(length(shuffdiff)+1);
  end
  %% overtraining
  % Permutation test of real-label vs shuffled-label decoding accuracy for the
  % extended-training data decoded with 3 PCs (row 3 of the OT cell arrays),
  % one test per ensemble size. Result: OTpval(1,i). Each null distribution is
  % shown as a histogram with the observed difference marked by a vertical line.
  figure;
  nPerm=10000;                    % number of random relabelings
  nSizes=size(PCADecodeAccOT,2);  % number of ensemble sizes tested
  for i=1:nSizes
      allsamples=cat(1,PCADecodeAccOT{3,i},PCADecodeAccShOT{3,i});
      half=length(allsamples)/2; % both groups hold the same number of repetitions

      shuffdiff=zeros(nPerm,1);
      for iPerm=1:nPerm
          shuffsamples=allsamples(randperm(length(allsamples)));
          shuffdiff(iPerm,1)=abs(mean(shuffsamples(1:half))-mean(shuffsamples(half+1:end)));
      end

      % The observed difference must come from the same 3-PC data as the null
      % distribution. The original read row 1 (1 PC) here, comparing a 1-PC
      % effect against a 3-PC null.
      % Named obsDiff rather than diff so the built-in diff stays usable.
      obsDiff=abs(nanmean(PCADecodeAccOT{3,i})-nanmean(PCADecodeAccShOT{3,i}));
      OTpval(1,i)=(sum(shuffdiff>obsDiff)+1)/(length(shuffdiff)+1);

      subplot(nSizes,1,i)
      hist(shuffdiff);
      hold on;
      plot([obsDiff obsDiff],[0 4000]);
  end