% pcaNetworkLoops.m

  1. close all;
  2. clear all;
  3. avg_taus = [];
  4. taus_3 = [];
  5. mm_etas = [0:0.0125:2];
  6. %% one-time setup
  7. seed = 12345; %10 was working for many of these demos %56; %23, 10 for 100D 10 clusters
  8. % seed choice affects initial weight direction, which affects peaks of
  9. % error plots
  10. trackVars = 1; %turning on slows things WAY DOWN
  11. % Select Dataset
  12. %load('genData1.mat')
  13. % load('genData2D_4c_200ppc.mat')
  14. load('genData2D_3c_500ppc.mat')
  15. % load('genData100D_10.mat')
  16. % load('genData_realSpec_try2.mat')
  17. % load('genData_realSpec_5c.mat')
  18. % Define how data will be presented
  19. twoClust_dm_unshuff = allPts_dm_unshuff(1:100, :);
  20. twoClust_unshuff = allPts_unshuff(1:100, :);
  21. shuffIdx = 1:100;%randperm(500);
  22. part1_dm = twoClust_dm_unshuff(shuffIdx,:);
  23. part1 = twoClust_unshuff(shuffIdx, :);
  24. % bigY is input fed into clustering network - note that it's de-meaned
  25. bigY = [allPts_dm_unshuff']; %[part1_dm' allPts_dm'];%_unshuff'];% allPts_dm' allPts_dm'];%(1:500,:)'];% allPts_dm'];
  26. % mmY is input fed into mismatch network - just a non-de-meaned version of
  27. % bigY
  28. mmY = [allPts_unshuff']; %[part1' allPts'];%_unshuff'];% allPts' allPts'];%(1:500,:)'];% allPts'];
  29. % PCA CONFIGS
  30. pcaNet.changeThresh = 1e-4; % for output convergence
  31. initWeightMag = 0.1; %0.8 for some demos % should be 0.1
  32. pcaNet.capW = 10;%500; %max weight to each output cell -- CHANGING THIS DIDN'T MAKE A DIFFERENCE
  33. pcaNet.inhibCap = 10; %NEED 10 FOR SYNTHETIC, 25 FOR REAL... -- CHANGING THIS DOES MATTER
  34. pcaNet.etaW = 0.05; %originally each was 0.5 -- MAKING THIS TOO HIGH DOESN'T MATTER MUCH
  35. pcaNet.etaM = 0.05; % originally 0.5 -- CHANGING THIS DOESN'T MATTER MUCH
  36. pcaNet.maxM = 1;
  37. pcaNet.maxW = 1;
  38. % MISMATCH CONFIGS
  39. % CHANGING THIS DRASTICALLY AFFECTS SHAPE OF ERROR PLOTS
  40. mmNet.thresh = 0.05; % CHANGING THIS AFFECTS SHAPE OF ERROR PLOTS
  41. % MISTMATCH STRUCTURE CONFIGS
  42. mmNet.signed_synapses = 1; %force positive weight coeffs
  43. mmNet.c_plastic = 1; %
  44. mmNet.y_plastic = 0;
  45. % MISMATCH ARCHITECTURE CONFIGS
  46. yE_type = 'rand'; %'rand' or 'randnorm'
  47. cE_type = 'rand';
  48. yI_type = 'rand';
  49. cI_type = 'rand';
  50. yR_type = 'direct'; %'rand' or 'direct'
  51. cR_type = 'direct'; %'rand' or 'direct'
  52. % MISMATCH NETWORK CONFIGS
  53. nCells_y = size(mmY,1); %one input per input dim
  54. nCells_c = 10; % (500 works on synthetic) 10 clusters in 100D case (nb 100d is a lot)
  55. nCells_Ny = 5; %not super sure how high or low this needs to be
  56. nCells_Nc = 5; %but 100 and 100 should be enough to give us convex cone
  57. iterations = size(bigY,2);
  58. % Total Configs -- this doesn't matter, we're always learning
  59. % (learningThresh = 0)
  60. sigmaThresh = 2.5; %1 works for seed 10 %2.5;%2.5;%%0.25;%25; %1 worked well
  61. deltaLearn = 0.2; %(0.1) % no buffer -> incorrect clustering (not obvious why needed for shuffled data)
  62. learningThresh = 0;%1; %make it deltaLearn to basically eliminate deltaLearn %0.5; %always learning -> incorrect clustering
  63. pcaNet.sigmaThresh = sigmaThresh;
  64. pcaNet.learning = 1; %this is set to 0 if mismatch is low, 1 if high
  65. learningSig = 1;
  66. rng(seed)
  67. %% looping sections
  68. for mIdx = 1:length(mm_etas)
  69. mmNet.eta = mm_etas(mIdx);
  70. disp(['loop ' num2str(mIdx) ' of ' num2str(length(mm_etas))])
  71. W_init = initWeightMag*randn(nCells_c, nCells_y);
  72. M_init = 0*initWeightMag*rand(nCells_c); %initially 0
  73. for idx = 1:nCells_c
  74. M_init(idx, idx) = 0; %nrns don't drive selves
  75. end
  76. D_init = zeros(nCells_c,1);
  77. pcaNet.bigC = zeros(nCells_c, iterations);
  78. pcaNet.W = W_init;
  79. pcaNet.M = M_init;
  80. pcaNet.D = D_init;
  81. c = updateC_v5(...
  82. pcaNet.W, pcaNet.M, bigY(:,1), pcaNet.changeThresh);
  83. [maxx, thisCluster] = max(c);
  84. if maxx > 0
  85. pcaNet.clusters = thisCluster;
  86. pcaNet.bigC(thisCluster, 1) = c(thisCluster);
  87. else
  88. pcaNet.clusters = 0;
  89. end
  90. cT = zeros(size(c));
  91. cT(thisCluster) = c(thisCluster); % c at timestep t
  92. y = bigY(:,1);
  93. rng(seed);
  94. % Mismatch Setup
  95. switch yE_type
  96. case 'rand'
  97. mmNet.we_yn = (rand(nCells_Ny, nCells_y))./nCells_y;
  98. case 'randnorm'
  99. mmNet.we_yn = (randn(nCells_Ny, nCells_y))./nCells_y;
  100. end
  101. switch yI_type
  102. case 'rand'
  103. mmNet.wi_yn = (rand(nCells_Nc, nCells_y))./nCells_y;
  104. case 'randnorm'
  105. mmNet.wi_yn = (randn(nCells_Nc, nCells_y))./nCells_y;
  106. end
  107. switch yR_type
  108. case 'rand'
  109. mmNet.r_yn = rand(nCells_y);
  110. case 'direct'
  111. mmNet.r_yn = eye(nCells_y);
  112. end
  113. switch cE_type
  114. case 'rand'
  115. mmNet.we_cn = (rand(nCells_Nc, nCells_c));%./nCells_c;
  116. %
  117. case 'randnorm'
  118. mmNet.we_cn = (randn(nCells_Nc, nCells_c));%./nCells_c;
  119. end
  120. switch cI_type
  121. case 'rand'
  122. mmNet.wi_cn = (rand(nCells_Ny, nCells_c));%./nCells_c;
  123. %
  124. case 'randnorm'
  125. mmNet.wi_cn = (randn(nCells_Ny, nCells_c));%./nCells_c;
  126. end
  127. switch cR_type
  128. case 'rand'
  129. mmNet.r_cn = rand(nCells_c);
  130. case 'direct'
  131. mmNet.r_cn = eye(nCells_c);
  132. end
  133. timesteps = length(bigY(1,:));
  134. if trackVars == 1
  135. mmNet.Vs_y = zeros(nCells_Ny, timesteps);
  136. mmNet.Vs_c = zeros(nCells_Nc, timesteps);
  137. mmNet.Fs_y = zeros(nCells_Ny, timesteps);
  138. mmNet.Fs_c = zeros(nCells_Nc, timesteps);
  139. mmNet.yWs_e = zeros(nCells_Ny, nCells_y, timesteps+1);
  140. mmNet.cWs_e = zeros(nCells_Nc, nCells_c, timesteps+1);
  141. mmNet.yWs_i = zeros(nCells_Nc, nCells_y, timesteps+1);
  142. mmNet.cWs_i = zeros(nCells_Ny, nCells_c, timesteps+1);
  143. mmNet.wyChanges_e = zeros(nCells_Ny, nCells_y, timesteps);
  144. mmNet.wcChanges_e = zeros(nCells_Nc, nCells_c, timesteps);
  145. mmNet.wyChanges_i = zeros(nCells_Nc, nCells_y, timesteps);
  146. mmNet.wcChanges_i = zeros(nCells_Ny, nCells_c, timesteps);
  147. mmNet.yWs_e(:,:,1) = mmNet.we_yn;
  148. mmNet.cWs_e(:,:,1) = mmNet.we_cn;
  149. mmNet.yWs_i(:,:,1) = mmNet.wi_yn;
  150. mmNet.cWs_i(:,:,1) = mmNet.wi_cn;
  151. end
  152. mmNet.errors_y = zeros(timesteps, 1);
  153. mmNet.errors_c = zeros(timesteps, 1);
  154. mmNet.allErrors = zeros(timesteps,1);
  155. % we've already done 1 iter of pca, so here do 1 iter of mm
  156. mmNet = mismatchIter_v2(cT, mmY, 1, mmNet, trackVars);
  157. disp('setup complete')
  158. % Run through algorithm (make functions for PCAIter and MNIter)
  159. for ts_idx = 2:iterations
  160. pcaNet = pcaIter_v6(bigY, ts_idx, pcaNet);
  161. pcaC = pcaNet.bigC(:, ts_idx);
  162. pcaC = pcaC>eps;
  163. mmNet = mismatchIter_v2(pcaC, mmY, ts_idx, mmNet, trackVars);
  164. sigmaNode = mmNet.allErrors(ts_idx);
  165. if sigmaNode > sigmaThresh
  166. learningSig = learningSig + deltaLearn;
  167. if learningSig > 1
  168. learningSig = 1;
  169. end
  170. else
  171. learningSig = learningSig - deltaLearn;
  172. if learningSig < 0
  173. learningSig = 0;
  174. end
  175. end
  176. pcaNet.learning = 1;
  177. end
  178. % what are taus?
  179. down1 = mmNet.allErrors(1:2);
  180. down2 = mmNet.allErrors(501:502);
  181. down3 = mmNet.allErrors(1001:1002);
  182. decay1 = (down1(1) - down1(end));
  183. decay2 = (down2(1) - down2(end));
  184. decay3 = (down3(1) - down3(end));
  185. mean_decay = (decay1+decay2+decay3)/3;
  186. taus_3 = [taus_3; decay3/down3(1)];
  187. avg_taus = [avg_taus; mean_decay];
  188. % colors = {'r.', 'y.', 'c.', 'g.', 'm.', 'k.', 'b.'};
  189. % for cIdx = 1:7
  190. % thisClust = find(pcaNet.clusters == cIdx);
  191. % figure(5);
  192. % plot(bigY(1, thisClust), bigY(2, thisClust), colors{cIdx})
  193. % % if ts_idx == 2
  194. % hold on;
  195. % % end
  196. % end
  197. figure(888); plot(mmNet.allErrors, 'k-'); xlim([995, 1025]); hold on;
  198. plot(1001:1002,mmNet.allErrors(1001:1002),'b*'); hold off;
  199. figure(999); plot(mm_etas(1:mIdx), taus_3(1:mIdx)); makepretty; xlim([0,max(mm_etas)]); pause(eps);
  200. % close(5);
  201. end