pcaPlusMN_2D_v4.m 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. %%% pca + mismatch network
  2. % y comes into pca network, which runs one iter. (immediately clusters)
  3. % output of pca network is unary context signal
  4. % y and context signal are then fed into mismatch network
  5. % pca network starts with relatively low excitability, and excitability
  6. % increases every time total mismatch is over some threshold
  7. %%% Important notes:
  8. %%% -this specific implementation could fail if we're unlucky with the
  9. %%% initial directions of the weights (i.e. if no weight is close enough
  10. %%% to input vector, we'd get into trouble)
  11. %%% -this version has been edited to work with very high dim inputs!
  12. %%% -v2 tries to introduce normalizations to mm weights
  13. %%%
  14. %%%
  15. %%% Things to check!
  16. %%% -Mismatch error should be decreasing over time, although that may
  17. %%% take a while
  18. %%% -Rank of W should initially be close to 100 (or whatever
  19. %%% initialized), but should end up close to number of clusters
  20. %%% -Weight magnitudes (for W and M) shouldn't be too crazy
  21. %%% -Cluster IDs should be consistent - after running network once,
  22. %%% re-running on same points should give same clusters (unless network
  23. %%% takes very long to converge)
  24. %%% -Mismatch error should be much lower if same data run on trained network
  25. %%% -Context signal should be relatively sparse!
  clear;
  close('all');

  % Open figures 1 and 5 at fixed screen positions. The coordinates
  % (including the negative y offsets) are presumably tuned to the author's
  % multi-monitor layout. Then block until a key is pressed.
  h1 = figure(1);
  set(h1, 'Position', [2209, -31, 512, 384]);
  h2 = figure(5);
  set(h2, 'Position', [1370, -319, 822, 672]);
  % figure(10)
  % figure(25)
  pause;
  %% Configs
  % Random seed. It sets the initial weight directions, which affects where
  % the peaks of the error plots fall. Other values tried: 10 (worked for
  % many demos), 56, 23, 1123; 10 was also used for 100D / 10 clusters.
  seed = 12345;
  % 1 = record full activity/weight histories in mmNet (slows the run
  % down a lot), 0 = don't.
  trackVars = 1;

  % ---- Dataset (uncomment exactly one load) ----
  %load('genData1.mat')
  % load('genData2D_4c_200ppc.mat')
  load('genData2D_3c_500ppc.mat')
  % load('genData100D_10.mat')
  % load('genData_realSpec_try2.mat')
  % load('genData_realSpec_5c.mat')

  % ---- Presentation of the data ----
  % Rows of allPts_* are points. Extract the first 100 points in identity
  % order (randperm(500) was used previously). These variables are not
  % used by the visible code below; presumably they are kept for the
  % alternative input sequences noted next.
  twoClust_dm_unshuff = allPts_dm_unshuff(1:100, :);
  twoClust_unshuff    = allPts_unshuff(1:100, :);
  shuffIdx            = 1:100;
  part1_dm            = twoClust_dm_unshuff(shuffIdx, :);
  part1               = twoClust_unshuff(shuffIdx, :);

  % Samples are columns.
  %   bigY : de-meaned samples fed to the clustering (PCA) network.
  %   mmY  : the same samples without de-meaning, fed to the mismatch network.
  % Alternatives tried: [part1_dm' allPts_dm'] / [part1' allPts'], and
  % repeated copies of allPts_dm' / allPts'.
  bigY = allPts_dm_unshuff';
  mmY  = allPts_unshuff';

  % ---- PCA network ----
  initWeightMag = 0.1; % scale of the initial W (0.8 in some demos; 0.1 intended)

  % ---- Mismatch network learning ----
  mmNet = struct( ...
      'eta',             0.025, ... % learning rate; changing it drastically changes the error plots
      'thresh',          0.05,  ... % changing it also changes the shape of the error plots
      'signed_synapses', 1,     ... % force positive weight coefficients
      'c_plastic',       1,     ...
      'y_plastic',       0);

  % ---- Mismatch architecture: how each weight matrix is initialized ----
  % E/I matrices: 'rand' or 'randnorm'. R matrices: 'rand' or 'direct'.
  yE_type = 'rand';
  cE_type = 'rand';
  yI_type = 'rand';
  cI_type = 'rand';
  yR_type = 'direct';
  cR_type = 'direct';

  % ---- Mismatch network sizes ----
  nCells_y  = size(mmY, 1); % one input cell per input dimension
  nCells_c  = 10;           % context cells (500 worked on synthetic data; 10 clusters in the 100D case)
  nCells_Ny = 5;            % not clear how many are needed; 100 and 100 were
  nCells_Nc = 5;            % thought sufficient to give a convex cone
  iterations = size(bigY, 2); % one iteration per sample

  % ---- Learning gate ----
  % With learningThresh = 0 and learningSig kept in [0, 1] by the main
  % loop, the gate always leaves the PCA network learning.
  sigmaThresh    = 2.5; % mismatch-error threshold (1 worked for seed 10; 0.25 and 25 also tried)
  deltaLearn     = 0.2; % learningSig step (0.1 also tried); no buffer -> incorrect clustering
  learningThresh = 0;   % setting it to deltaLearn basically eliminates deltaLearn; 0.5 also tried
  learningSig    = 1;

  % ---- PCA network parameters ----
  pcaNet = struct( ...
      'changeThresh', 1e-4,        ... % output-convergence tolerance
      'capW',         10,          ... % max weight onto each output cell (500 tried; no difference)
      'inhibCap',     10,          ... % 10 needed for synthetic data, 25 for real data; this matters
      'etaW',         0.05,        ... % originally 0.5; making it too high doesn't matter much
      'etaM',         0.05,        ... % originally 0.5; changing it doesn't matter much
      'maxM',         1,           ...
      'maxW',         1,           ...
      'sigmaThresh',  sigmaThresh, ...
      'learning',     1);              % updated by the learning gate in the main loop

  rng(seed)
  %% PCA - Setup
  % Initial weights: feedforward W (nCells_c x nCells_y) is random normal
  % scaled by initWeightMag; lateral M (nCells_c x nCells_c) starts at zero.
  % The rand() call is kept (multiplied by 0) so the random stream consumed
  % before updateC_v5 is unchanged -- TODO confirm nothing depends on it,
  % then it can become zeros(nCells_c).
  W_init = initWeightMag*randn(nCells_c, nCells_y);
  M_init = 0*initWeightMag*rand(nCells_c);
  M_init(1:nCells_c+1:end) = 0; % zero diagonal: neurons don't drive themselves
  D_init = zeros(nCells_c, 1);

  % bigC(:, t) records the winning cell's activation for sample t
  % (set here for t = 1; later columns are written by pcaIter_v6).
  pcaNet.bigC = zeros(nCells_c, iterations);
  pcaNet.W = W_init;
  pcaNet.M = M_init;
  pcaNet.D = D_init;

  % Run the PCA network on the first sample and record the winning cell.
  c = updateC_v5( ...
      pcaNet.W, pcaNet.M, bigY(:,1), pcaNet.changeThresh);
  [maxx, thisCluster] = max(c);
  if maxx > 0
      pcaNet.clusters = thisCluster;
      pcaNet.bigC(thisCluster, 1) = c(thisCluster);
  else
      pcaNet.clusters = 0; % no active cell: no cluster assigned
  end

  % Unary context signal for the first mismatch iteration, built exactly
  % like the main loop's pcaC (bigC > eps). Previously the raw activation
  % c(thisCluster) was passed, and it was set even when no cell was active.
  cT = pcaNet.bigC(:, 1) > eps;
  y = bigY(:,1);
  rng(seed);
  %% Mismatch Setup
  % Initialize the mismatch-network weight matrices. Draws happen in the
  % order yE, yI, yR, cE, cI, cR so the random stream matches earlier runs.
  %   we_yn : nCells_Ny x nCells_y, scaled by 1/nCells_y
  %   wi_yn : nCells_Nc x nCells_y, scaled by 1/nCells_y
  %   r_yn  : nCells_y  x nCells_y
  %   we_cn : nCells_Nc x nCells_c
  %   wi_cn : nCells_Ny x nCells_c
  %   r_cn  : nCells_c  x nCells_c
  % 'rand' = uniform on (0,1), 'randnorm' = standard normal,
  % 'direct' = identity. An unrecognized *_type now raises an error here
  % instead of leaving the field undefined and failing later.
  switch yE_type
      case 'rand'
          mmNet.we_yn = (rand(nCells_Ny, nCells_y))./nCells_y;
      case 'randnorm'
          mmNet.we_yn = (randn(nCells_Ny, nCells_y))./nCells_y;
      otherwise
          error('yE_type must be ''rand'' or ''randnorm'' (got ''%s'').', yE_type);
  end
  switch yI_type
      case 'rand'
          mmNet.wi_yn = (rand(nCells_Nc, nCells_y))./nCells_y;
      case 'randnorm'
          mmNet.wi_yn = (randn(nCells_Nc, nCells_y))./nCells_y;
      otherwise
          error('yI_type must be ''rand'' or ''randnorm'' (got ''%s'').', yI_type);
  end
  switch yR_type
      case 'rand'
          mmNet.r_yn = rand(nCells_y);
      case 'direct'
          mmNet.r_yn = eye(nCells_y);
      otherwise
          error('yR_type must be ''rand'' or ''direct'' (got ''%s'').', yR_type);
  end
  switch cE_type
      case 'rand'
          mmNet.we_cn = (rand(nCells_Nc, nCells_c)); % ./nCells_c was tried
      case 'randnorm'
          mmNet.we_cn = (randn(nCells_Nc, nCells_c)); % ./nCells_c was tried
      otherwise
          error('cE_type must be ''rand'' or ''randnorm'' (got ''%s'').', cE_type);
  end
  switch cI_type
      case 'rand'
          mmNet.wi_cn = (rand(nCells_Ny, nCells_c)); % ./nCells_c was tried
      case 'randnorm'
          mmNet.wi_cn = (randn(nCells_Ny, nCells_c)); % ./nCells_c was tried
      otherwise
          error('cI_type must be ''rand'' or ''randnorm'' (got ''%s'').', cI_type);
  end
  switch cR_type
      case 'rand'
          mmNet.r_cn = rand(nCells_c);
      case 'direct'
          mmNet.r_cn = eye(nCells_c);
      otherwise
          error('cR_type must be ''rand'' or ''direct'' (got ''%s'').', cR_type);
  end
  % Number of samples (columns of bigY).
  timesteps = numel(bigY(1,:));

  % Optional history buffers, only when trackVars == 1:
  %   Vs_* / Fs_*   : one column per timestep for the Ny and Nc cells
  %   *Ws_*         : weight snapshots with timesteps+1 slices; slice 1
  %                   holds the initial weights
  %   w*Changes_*   : one weight-change slice per timestep
  if trackVars == 1
      mmNet.Vs_y = zeros(nCells_Ny, timesteps);
      mmNet.Vs_c = zeros(nCells_Nc, timesteps);
      mmNet.Fs_y = zeros(nCells_Ny, timesteps);
      mmNet.Fs_c = zeros(nCells_Nc, timesteps);

      mmNet.yWs_e = zeros(nCells_Ny, nCells_y, timesteps+1);
      mmNet.yWs_e(:,:,1) = mmNet.we_yn;
      mmNet.cWs_e = zeros(nCells_Nc, nCells_c, timesteps+1);
      mmNet.cWs_e(:,:,1) = mmNet.we_cn;
      mmNet.yWs_i = zeros(nCells_Nc, nCells_y, timesteps+1);
      mmNet.yWs_i(:,:,1) = mmNet.wi_yn;
      mmNet.cWs_i = zeros(nCells_Ny, nCells_c, timesteps+1);
      mmNet.cWs_i(:,:,1) = mmNet.wi_cn;

      mmNet.wyChanges_e = zeros(nCells_Ny, nCells_y, timesteps);
      mmNet.wcChanges_e = zeros(nCells_Nc, nCells_c, timesteps);
      mmNet.wyChanges_i = zeros(nCells_Nc, nCells_y, timesteps);
      mmNet.wcChanges_i = zeros(nCells_Ny, nCells_c, timesteps);
  end

  % Per-timestep error traces (always recorded).
  mmNet.errors_y  = zeros(timesteps, 1);
  mmNet.errors_c  = zeros(timesteps, 1);
  mmNet.allErrors = zeros(timesteps, 1);

  % The PCA network has already processed sample 1, so run the matching
  % first mismatch iteration on its context signal cT.
  mmNet = mismatchIter_v2(cT, mmY, 1, mmNet, trackVars);
  disp('setup complete')
  %% Run through algorithm (pcaIter_v6 / mismatchIter_v2 per sample)
  % For each remaining sample:
  %   1. Run one PCA-network iteration.
  %   2. Run one mismatch-network iteration on the unary context signal.
  %   3. Update the learning gate: learningSig moves up by deltaLearn when the
  %      total mismatch error exceeds sigmaThresh and down otherwise. It is
  %      clamped to [0, 1], and the PCA network learns while
  %      learningSig >= learningThresh.
  for ts_idx = 2:iterations
      pcaNet = pcaIter_v6(bigY, ts_idx, pcaNet);
      % Binarize the recorded winning activation into a unary context signal.
      pcaC = pcaNet.bigC(:, ts_idx) > eps;
      mmNet = mismatchIter_v2(pcaC, mmY, ts_idx, mmNet, trackVars);

      sigmaNode = mmNet.allErrors(ts_idx);
      if sigmaNode > sigmaThresh
          learningSig = min(learningSig + deltaLearn, 1);
      else
          learningSig = max(learningSig - deltaLearn, 0);
      end
      pcaNet.learning = double(learningSig >= learningThresh);
      % agnoPlotting_2D(bigY, mmNet, pcaNet, ts_idx, iterations)
      % agnoPlotting_100D(bigY, mmNet, pcaNet, ts_idx, iterations)
  end
  % Plot the final state at the last sample. Pass iterations rather than
  % the loop variable: when iterations == 1 the loop body never runs, and
  % ts_idx does not hold a valid sample index.
  agnoPlotting_2D(bigY, mmNet, pcaNet, iterations, iterations)
  212. %% Plotting
  213. %
  214. % % Cluster ID Distribution
  215. % figure(2)
  216. % subplot(2,2,[1 2])
  217. % histogram(pcaNet.clusters(1:iterations), [-0.5:1:nCells_c+0.5])
  218. % title('Inputs per Cluster - Total Dataset')
  219. % xlabel('Cluster ID')
  220. % ylabel('Number of Inputs')
  221. %
  222. % subplot(2,2,3)
  223. % histogram(pcaNet.clusters(1:iterations/2), [-0.5:1:nCells_c+0.5])
  224. % title('Inputs per Cluster - First Half')
  225. % xlabel('Cluster ID')
  226. % ylabel('Number of Inputs')
  227. %
  228. % subplot(2,2,4)
  229. % histogram(pcaNet.clusters((iterations/2)+1:end), [-0.5:1:nCells_c+0.5])
  230. % title('Inputs per Cluster - Second Half')
  231. % xlabel('Cluster ID')
  232. % ylabel('Number of Inputs')