pcaPlusMN_HD_v4.m 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. %%% pca + mismatch network
  2. % y comes into pca network, which runs one iter. (immediately clusters)
  3. % output of pca network is unary context signal
  4. % y and context signal are then fed into mismatch network
  5. % pca network starts with relatively low excitability, and excitability
  6. % increases every time total mismatch is over some threshold
  7. %%% Important notes:
  8. %%% -this specific implementation could fail if we're unlucky with the
  9. %%% initial directions of the weights (i.e., if no weight vector is close
  10. %%% enough to the input vector, we'd get into trouble)
  11. %%% -this version has been edited to work with very high dim inputs!
  12. %%% -v2 tries to introduce normalizations to mm weights
  13. %%%
  14. %%%
  15. %%% Things to check!
  16. %%% -Mismatch error should be decreasing over time, although that may
  17. %%% take a while
  18. %%% -Rank of W should initially be close to 100 (or whatever
  19. %%% initialized), but should end up close to number of clusters (maybe)
  20. %%% -Weight magnitudes (for W and M) shouldn't be too crazy
  21. %%% -Cluster IDs should be consistent - after running network once,
  22. %%% re-running on same points should give same clusters (unless network
  23. %%% takes very long to converge)
  24. %%% -Mismatch error should be much lower if same data run on trained network
  25. %%% -Context signal should be relatively sparse!
  %% Workspace reset
  % Wipe the workspace and any open figures, open figure 1, then wait for a
  % keypress (PAUSE with no argument blocks until a key is pressed).
  clear
  close all
  figure(1)
  % figure(5)
  % figure(10)
  % figure(25)
  pause

  %% Configs
  % ---- Run-level switches ----
  seed          = 1211;   % RNG seed (author note: 23, 10 for 100D 10 clusters)
  trackVars     = 0;      % 1 -> keep per-timestep histories (turning on slows things WAY DOWN)
  errorPlotting = 0;      % 1 -> call agnoPlotting_100D on every timestep of the main loop

  % ---- Dataset ----
  % Alternative data files, with the author's notes on how each behaved:
  % load('clusteredSong_150D_10.mat')   % very bad when settings same as below
  % load('genData100D_10_big.mat')      % - MAYBE WORKS WELL?
  % load('genData100D_10_small.mat')    % excellent
  % load('genData100D_25c_100ppc.mat')
  % load('genData_realSpec_try2.mat')
  % load('genData_realSpec_5c.mat')
  load('genData100D_10c_250ppc.mat')    % must supply allPts_dm and allPts (read just below)

  % Transposed so that each COLUMN is one sample: bigY drives the PCA network,
  % mmY drives the mismatch network. Earlier experiments used the *_unshuff
  % arrays, concatenated copies, or only rows 1:500 of these matrices.
  bigY = allPts_dm';
  mmY  = allPts';

  trueIDs = 0;            % placeholder; clustIDs was used here previously

  % ---- PCA network ----
  initWeightMag = 0.01;   % scale of the initial random weights
  pcaNet = struct( ...
      'changeThresh', 1e-4, ...  % output-convergence tolerance while the PCA net is running
      'capW',         25,   ...  % max weight to each output cell (500 was also tried)
      'inhibCap',     25,   ...  % author note: NEED 10 FOR SYNTHETIC, 25 FOR REAL...
      'etaW',         0.01, ...
      'etaM',         0.01, ...
      'maxM',         1,    ...
      'maxW',         1);

  % ---- Mismatch network: learning + structure ----
  mmNet = struct( ...
      'eta',             0.1,  ...
      'thresh',          0.05, ...  % below this, mm activity is zero (single cell)
      'signed_synapses', 1,    ...  % force positive weight coeffs
      'c_plastic',       1,    ...
      'y_plastic',       0);

  % ---- Mismatch network: architecture ----
  % E/I weight initialisers accept 'rand' or 'randnorm';
  % the R matrices accept 'rand' or 'direct'.
  yE_type = 'rand';
  cE_type = 'rand';
  yI_type = 'rand';
  cI_type = 'rand';
  yR_type = 'direct';
  cR_type = 'direct';

  % ---- Network sizes ----
  nCells_y   = size(mmY, 1);   % one input cell per input dimension
  nCells_c   = 25;             % THIS EQUALS SIZE OF PCA NETWORK (500 works on synthetic;
                               % 10 clusters in the 100D case, nb 100d is a lot)
  nCells_Ny  = 100;            % not super sure how high or low this needs to be,
  nCells_Nc  = 100;            % but 100 and 100 should be enough to give us convex cone
  iterations = size(bigY, 2);  % one iteration per sample (column of bigY)

  % ---- Learning gate ----
  sigmaThresh    = 25;         % total-mismatch level that pushes learningSig up (1 worked well)
  deltaLearn     = 0.25;       % step applied to learningSig each timestep (0.1 also tried);
                               % author note: no buffer -> incorrect clustering
  learningThresh = 0;          % author note: always learning -> incorrect clustering
  pcaNet.sigmaThresh = sigmaThresh;

  % Last things - STARTING THESE AT 0 IS ACTUALLY IMPORTANT
  pcaNet.learning = 0;         % set to 0 if mismatch is low, 1 if high
  learningSig     = 0;

  rng(seed)
  %% PCA - Setup
  % Feedforward weights: small Gaussian values, one row per context cell and
  % one column per input dimension.
  W_init = initWeightMag * randn(nCells_c, nCells_y);

  % Lateral weights start at exactly zero. The rand() draw is kept (and scaled
  % by 0) so the random stream advances exactly as it always has.
  M_init = 0 * initWeightMag * rand(nCells_c);
  M_init(1:nCells_c+1:end) = 0;      % diagonal entries: nrns don't drive selves

  D_init = zeros(nCells_c, 1);

  % Populate the network state; bigC holds one context column per timestep.
  pcaNet.trueIDs = trueIDs;
  pcaNet.bigC    = zeros(nCells_c, iterations);
  pcaNet.W       = W_init;
  pcaNet.M       = M_init;
  pcaNet.D       = D_init;

  % Push the first sample through the PCA network and record its output.
  y = bigY(:, 1);
  c = updateC_v5(pcaNet.W, pcaNet.M, y, pcaNet.changeThresh);
  pcaNet.bigC(:, 1) = c;

  % The first cluster label is the index of the strongest output cell, or 0
  % when no cell has positive output.
  [peakVal, peakIdx] = max(c);
  pcaNet.clusters = 0;
  if peakVal > 0
      pcaNet.clusters = peakIdx;
  end

  cT = c;        % c at timestep t, handed to the first mismatch iteration

  rng(seed);     % re-seed so the mismatch weights below start from the same stream
  %% Mismatch Setup
  % Builds the initial mismatch-network weights from the *_type settings,
  % preallocates the optional history arrays and the error traces, and runs the
  % first mismatch iteration on the output of the first PCA step.
  %
  % Every switch now ends in an OTHERWISE that raises an error: previously a
  % mistyped *_type string silently left the corresponding matrix undefined and
  % the script failed later with an unrelated message.
  % The order of the rand/randn calls is unchanged, so results for a given seed
  % are the same as before.

  % y -> Ny excitatory weights (nCells_Ny x nCells_y), scaled by input count
  switch yE_type
      case 'rand'
          mmNet.we_yn = rand(nCells_Ny, nCells_y) ./ nCells_y;
      case 'randnorm'
          mmNet.we_yn = randn(nCells_Ny, nCells_y) ./ nCells_y;
      otherwise
          error('pcaPlusMN:badConfig', ...
              'yE_type must be ''rand'' or ''randnorm'' (got ''%s'').', yE_type);
  end

  % y -> Nc inhibitory weights (nCells_Nc x nCells_y), scaled by input count
  switch yI_type
      case 'rand'
          mmNet.wi_yn = rand(nCells_Nc, nCells_y) ./ nCells_y;
      case 'randnorm'
          mmNet.wi_yn = randn(nCells_Nc, nCells_y) ./ nCells_y;
      otherwise
          error('pcaPlusMN:badConfig', ...
              'yI_type must be ''rand'' or ''randnorm'' (got ''%s'').', yI_type);
  end

  % y relay matrix (nCells_y x nCells_y): random, or identity for 'direct'
  switch yR_type
      case 'rand'
          mmNet.r_yn = rand(nCells_y);
      case 'direct'
          mmNet.r_yn = eye(nCells_y);
      otherwise
          error('pcaPlusMN:badConfig', ...
              'yR_type must be ''rand'' or ''direct'' (got ''%s'').', yR_type);
  end

  % c -> Nc excitatory weights (nCells_Nc x nCells_c); the ./nCells_c scaling
  % was disabled by the author and stays disabled
  switch cE_type
      case 'rand'
          mmNet.we_cn = rand(nCells_Nc, nCells_c);    %./nCells_c;
      case 'randnorm'
          mmNet.we_cn = randn(nCells_Nc, nCells_c);   %./nCells_c;
      otherwise
          error('pcaPlusMN:badConfig', ...
              'cE_type must be ''rand'' or ''randnorm'' (got ''%s'').', cE_type);
  end

  % c -> Ny inhibitory weights (nCells_Ny x nCells_c); scaling disabled as above
  switch cI_type
      case 'rand'
          mmNet.wi_cn = rand(nCells_Ny, nCells_c);    %./nCells_c;
      case 'randnorm'
          mmNet.wi_cn = randn(nCells_Ny, nCells_c);   %./nCells_c;
      otherwise
          error('pcaPlusMN:badConfig', ...
              'cI_type must be ''rand'' or ''randnorm'' (got ''%s'').', cI_type);
  end

  % c relay matrix (nCells_c x nCells_c): random, or identity for 'direct'
  switch cR_type
      case 'rand'
          mmNet.r_cn = rand(nCells_c);
      case 'direct'
          mmNet.r_cn = eye(nCells_c);
      otherwise
          error('pcaPlusMN:badConfig', ...
              'cR_type must be ''rand'' or ''direct'' (got ''%s'').', cR_type);
  end

  % One timestep per sample (column of bigY).
  timesteps = size(bigY, 2);

  % Optional full histories. These are 3-D arrays with one slice per timestep,
  % which is why enabling trackVars is expensive.
  if trackVars == 1
      mmNet.Vs_y = zeros(nCells_Ny, timesteps);
      mmNet.Vs_c = zeros(nCells_Nc, timesteps);
      mmNet.Fs_y = zeros(nCells_Ny, timesteps);
      mmNet.Fs_c = zeros(nCells_Nc, timesteps);

      % weight snapshots: slice 1 is the initial state, hence timesteps+1 slices
      mmNet.yWs_e = zeros(nCells_Ny, nCells_y, timesteps+1);
      mmNet.cWs_e = zeros(nCells_Nc, nCells_c, timesteps+1);
      mmNet.yWs_i = zeros(nCells_Nc, nCells_y, timesteps+1);
      mmNet.cWs_i = zeros(nCells_Ny, nCells_c, timesteps+1);

      % per-timestep weight changes
      mmNet.wyChanges_e = zeros(nCells_Ny, nCells_y, timesteps);
      mmNet.wcChanges_e = zeros(nCells_Nc, nCells_c, timesteps);
      mmNet.wyChanges_i = zeros(nCells_Nc, nCells_y, timesteps);
      mmNet.wcChanges_i = zeros(nCells_Ny, nCells_c, timesteps);

      mmNet.yWs_e(:,:,1) = mmNet.we_yn;
      mmNet.cWs_e(:,:,1) = mmNet.we_cn;
      mmNet.yWs_i(:,:,1) = mmNet.wi_yn;
      mmNet.cWs_i(:,:,1) = mmNet.wi_cn;
  end

  % Error traces, one entry per timestep.
  mmNet.errors_y  = zeros(timesteps, 1);
  mmNet.errors_c  = zeros(timesteps, 1);
  mmNet.allErrors = zeros(timesteps, 1);

  % We've already done 1 iter of pca, so here do 1 iter of mm.
  % NOTE(review): this first call receives the raw PCA output cT, while the main
  % loop passes the thresholded signal (c > eps). Confirm against
  % mismatchIter_v2 whether timestep 1 should be thresholded as well.
  mmNet = mismatchIter_v2(cT, mmY, 1, mmNet, trackVars);

  disp('setup complete')
  %% Run through algorithm (make functions for PCAIter and MNIter)
  % For every remaining sample: one PCA step, one mismatch step on the
  % thresholded PCA output, then update the learning gate from the total
  % mismatch error of this timestep.
  for ts_idx = 2:iterations
      % PCA step; its output for this sample lands in column ts_idx of bigC
      pcaNet = pcaIter_v6(bigY, ts_idx, pcaNet);

      % context signal: 1 for every cell whose output exceeds eps, else 0
      ctx = pcaNet.bigC(:, ts_idx) > eps;

      % mismatch step on the matching column of mmY
      mmNet = mismatchIter_v2(ctx, mmY, ts_idx, mmNet, trackVars);
      mismatchNow = mmNet.allErrors(ts_idx);

      % Learning signal moves by deltaLearn per timestep and is kept in [0, 1]:
      % up while the mismatch exceeds sigmaThresh, down otherwise.
      if mismatchNow > sigmaThresh
          learningSig = min(learningSig + deltaLearn, 1);
      else
          learningSig = max(learningSig - deltaLearn, 0);
      end

      % PCA learning is on (1) once the signal reaches learningThresh, else off (0)
      pcaNet.learning = double(learningSig >= learningThresh);

      if errorPlotting == 1
          agnoPlotting_100D(bigY, mmNet, pcaNet, ts_idx, iterations)
      end
  end
  %% Plotting
  % Figure 2:   histogram of cluster labels (whole run, first half, second half)
  % Figure 81:  total mismatch error over time
  % Figure 999: silhouette plot of the samples under the assigned labels
  %
  % Changes from the previous version:
  %  - the halfway index is floor(iterations/2); iterations/2 is not a valid
  %    index when the number of samples is odd
  %  - the error plot and its title use `iterations` instead of the loop
  %    variable ts_idx, which is empty when the loop never runs (iterations < 2)
  %    and otherwise equals iterations
  %  - the histogram bin edges are built once and shared by all three panels

  binEdges = -0.5:1:nCells_c+0.5;      % one bin per label 0..nCells_c (0 = no winner)
  halfway  = floor(iterations/2);      % last sample counted in the "first half"

  % Cluster ID Distribution
  figure(2)
  subplot(2,2,[1 2])
  histogram(pcaNet.clusters(1:iterations), binEdges)
  % hold on;
  % rl = refline(0, 250);
  % rl.Color = 'r';
  % rl.LineWidth = 2;
  % xlim([0 500])
  title('Inputs per Cluster - Total Dataset')
  xlabel('Cluster ID')
  ylabel('Number of Inputs')
  % set(gca, 'fontsize', 18)
  % hold off;

  subplot(2,2,3)
  histogram(pcaNet.clusters(1:halfway), binEdges)
  title('Inputs per Cluster - First Half')
  xlabel('Cluster ID')
  ylabel('Number of Inputs')

  subplot(2,2,4)
  histogram(pcaNet.clusters(halfway+1:end), binEdges)
  title('Inputs per Cluster - Second Half')
  xlabel('Cluster ID')
  ylabel('Number of Inputs')

  % Total mismatch error per timestep
  figure(81)
  plot(mmNet.allErrors(1:iterations), 'r-');
  % hold on;
  % refline(0, pcaNet.sigmaThresh);
  % hold off;
  title(['Total Error over Time - ts: ' num2str(iterations)])
  xlabel('Timestep')
  ylabel('Error')
  xlim([0 iterations])
  ylim([0 1000])
  pause(1e-5)                          % brief pause so the figure gets drawn

  % Silhouette plot: one row of bigY' per sample, labelled by assigned cluster
  figure(999)
  silhouette(bigY', pcaNet.clusters);

  % Check Cluster Accuracy
  % this is actually not horrible... it would be nice to make it prettier or
  % something
  % toCheck = unique(pcaNet.clusters);
  %
  % for checkIdx = 1:length(toCheck)
  %     thisCluster = toCheck(checkIdx);
  %     memIDs = find(pcaNet.clusters == thisCluster);
  %     memTruth = trueIDs(memIDs);
  %     figure(1000 + checkIdx);
  %     histogram(memTruth)
  % end