// vnormalize.cpp

/*Copyright (C) 2005, 2006, 2007 Frank Michler, Philipps-University Marburg, Germany

This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
#include "sys.hpp"   // for libcwd
#include "debug.hpp" // for libcwd
#include "vnormalize.hpp"

////////Normalize//////////////////////////////////
VecNormalize::VecNormalize()
    : AbstractNormalize(), Target(0), NTarget(0), NSource(0)
{
}
int VecNormalize::AddConnection(VecConnection* newcon)
{
    if (Target == 0)
    {
        Target = newcon->GetTargetLayer();
        NTarget = Target->N;
    }
    if (Target == newcon->GetTargetLayer()) ConList.push_back(newcon);
    else {
        cerr << "ERROR: Target-Layer not the same\n";
        exit(1);
    }
    NSource += newcon->GetNSource();
    if (RewiringOn) {
        // if the VecNormalize object handles the rewiring,
        // the connection object must not do it as well
        newcon->SetRewiringOff();
    }
    return 0;
}
int VecNormalize::WriteSimInfo(fstream &fw)
{
    stringstream sstr;
    sstr << "<Target id=\"" << Target->IdNumber << "\"/> \n";
    SimElement::WriteSimInfo(fw, sstr.str());
    return 0;
}
int VecNormalize::WriteSimInfo(fstream &fw, const string &ChildInfo)
{
    stringstream sstr;
    sstr << "<Target id=\"" << Target->IdNumber << "\"/> \n";
    sstr << "<RewiringOn value=\"" << RewiringOn << "\"/> \n";
    sstr << "<IncommingConnectivity value=\"" << IncommingConnectivity << "\"/> \n";
    sstr << "<SynDelThreshold value=\"" << SynDelThreshold << "\"/> \n";
    sstr << "<InitialWeights value=\"" << InitialWeights << "\"/>\n";
    sstr << ChildInfo;
    SimElement::WriteSimInfo(fw, sstr.str());
    return 0;
}
void VecNormalize::SetRewiring(float _SynDelThreshold, float _IncommingConnectivity, float _InitialWeights)
{
    RewiringOn = true;
    IncommingConnectivity = _IncommingConnectivity;
    SynDelThreshold = _SynDelThreshold;
    InitialWeights = _InitialWeights;

    // turn off rewiring in the VecConnection objects,
    // because now the normalization object handles the rewiring
    for (vector<VecConnection*>::iterator it = ConList.begin(); it != ConList.end(); ++it)
    {
        (*it)->SetRewiringOff();
    }
}
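
// Rewire(): delete synapses whose weight has dropped below SynDelThreshold,
// then create new synapses (with weight InitialWeights) so that each target
// neuron again receives about IncommingConnectivity*NSource inputs. The new
// synapses are distributed over the registered connections in proportion to
// the free synapse slots of each connection; the rounding remainder is
// assigned to randomly chosen connections.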
int VecNormalize::Rewire()
{
    int tar = 0;
    int NCon = ConList.size();

    // delete synapses with low weights
    int NDeletedSynapses = 0;
    for (vector<VecConnection*>::iterator it = ConList.begin(); it != ConList.end(); ++it)
    {
        NDeletedSynapses += (*it)->DeleteLowWeights(SynDelThreshold);
    }
    cout << "VecNormalize::Rewire() TotalDeletedSynapses=" << NDeletedSynapses << "\n";

    // target number of incoming synapses per target neuron
    int NMaxWeights = int(round(IncommingConnectivity*NSource));
    vector<int> NNewWeights(NTarget);
    for (tar = 0; tar < NTarget; ++tar) {
        NNewWeights[tar] = NMaxWeights;
    }
    vector<vector<int> > NFreeWeights(NCon, vector<int>(NTarget));
    vector<vector<int> > NSglConNewWeights(NCon, vector<int>(NTarget));

    // count incoming weights and free synapse slots per connection
    vector<int> TotalFreeWeights(NTarget);
    for (int ConNr = 0; ConNr < NCon; ++ConNr) {
        int CurNs = (ConList[ConNr])->ns;
        for (int tar = 0; tar < NTarget; ++tar) {
            int PreSynSize = ((ConList[ConNr])->PreSynNr[tar]).size();
            NNewWeights[tar] -= PreSynSize;
            NFreeWeights[ConNr][tar] = CurNs - PreSynSize;
            TotalFreeWeights[tar] += NFreeWeights[ConNr][tar];
        }
    }
    if (NDeletedSynapses > 0) {
        for (tar = 0; tar < NTarget; ++tar) {
            if (NNewWeights[tar] > 0) {
                cout << "NNewWeights[" << tar << "]=" << NNewWeights[tar] << "\n";
                cout << "TotalFreeWeights[" << tar << "]=" << TotalFreeWeights[tar] << "\n";
            }
        }
    }

    // distribute the new synapses among the connections,
    // proportionally to the free slots of each connection
    for (tar = 0; tar < NTarget; ++tar) {
        if (NNewWeights[tar] > 0) {
            float CurNewWeights = float(NNewWeights[tar])/TotalFreeWeights[tar];
            for (int ConNr = 0; ConNr < NCon; ++ConNr) {
                NSglConNewWeights[ConNr][tar] = int(CurNewWeights*NFreeWeights[ConNr][tar]);
                NNewWeights[tar] -= NSglConNewWeights[ConNr][tar];
            }
            if (NNewWeights[tar] >= NCon) {
                cerr << "fatal ERROR: NNewWeights[tar] should be less than NConnections! (exiting)\n";
                exit(1);
            }
            // assign the rounding remainder to randomly chosen connections
            while (NNewWeights[tar] > 0) {
                int Winner = gsl_rng_uniform_int(gslr, NCon);
                ++NSglConNewWeights[Winner][tar];
                --NNewWeights[tar];
            }
        }
    }

    // create the new synapses with weight InitialWeights
    for (int ConNr = 0; ConNr < NCon; ++ConNr) {
        ConList[ConNr]->SetNewWeights(&(NSglConNewWeights[ConNr]), InitialWeights);
    }
    return 0;
}
int VecNormalize::prepare(int Step)
{
    if (RewiringOn) Rewire();
    return 0;
}
//////////////////////////////

/*! \brief Normalize synaptic weights when the postsynaptic firing rate exceeds a threshold.

NormFrequency: normalization sets in above this spike frequency.
At the threshold frequency the weights are multiplied by (1-NormFactor);
the higher the spike frequency, the stronger the weight reduction,
up to a maximum of (1-NormFactor*MaxNormFactor).
A look-up table determines the current normalization factor from the
time difference DeltaT between the current spike and the previous spike:
NormLut(DeltaT) = 1 - MaxNormFactor*NormFactor*exp(-DeltaT/Tau),
and the weight is multiplied by NormLut(DeltaT).
With Tau = NormDeltaT/log(MaxNormFactor) this equals exactly (1-NormFactor)
at DeltaT = NormDeltaT, i.e. when the neuron fires at NormFrequency.
@param _NormFrequency threshold frequency; normalization occurs above this frequency
@param _NormFactor weight normalization factor;
if the postsynaptic neuron fires at _NormFrequency, the synaptic weight is multiplied by (1-_NormFactor)
@param _MaxNormFactor maximal normalization: (1-NormFactor*MaxNormFactor).
If the postsynaptic neuron fires faster than _NormFrequency, the normalization is stronger;
for an infinite firing rate the normalization strength approaches _MaxNormFactor times the base value.
@author (fm)
*/
VecFiringRateNormalize2::VecFiringRateNormalize2(
    float _NormFrequency, float _NormFactor, float _MaxNormFactor)
    : PostSynLastFirings(0),
      MaxNormFactor(_MaxNormFactor),
      NormFactor(_NormFactor),
      NormFrequency(_NormFrequency)
{
    NormDeltaT = 1000./(NormFrequency*dt);
    Tau = NormDeltaT/log(MaxNormFactor);
    NormLut = ExpDecayLut(NormLutN, Tau, -MaxNormFactor*NormFactor, 1, dt, NormDeltaT/Tau);
    cout << "VecFiringRateNormalization2\n";
    cout << "NormLut=" << "\n";
    for (int i = 0; i < NormLutN; ++i) cout << NormLut[i] << "\n";
    cout << " Factor=" << NormFactor << " MaxNormFactor=" << MaxNormFactor << " Tau=" << Tau*dt << " ms\n";
}
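
// Usage sketch: a VecFiringRateNormalize2 object is constructed with the
// three parameters above and is then fed all VecConnection objects that
// project onto the same target layer; the connection pointers and parameter
// values below are hypothetical placeholders.
//
//   VecFiringRateNormalize2* norm = new VecFiringRateNormalize2(50., 0.1, 2.);
//   norm->AddConnection(ConFfToTarget);   // hypothetical VecConnection*
//   norm->AddConnection(ConRecToTarget);  // hypothetical VecConnection*
//   // proceede(TotalTime) and prepare(Step) are presumably driven by the
//   // surrounding simulation loop, once per time step and once per
//   // MacroTimeStep respectively.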
int VecFiringRateNormalize2::WriteSimInfo(fstream &fw)
{
    stringstream sstr;
    sstr << "<MaxNormFactor value=\"" << MaxNormFactor << "\"/>\n";
    sstr << "<NormFrequency value=\"" << NormFrequency << "\"/>\n";
    sstr << "<NormFactor value=\"" << NormFactor << "\"/>\n";
    VecNormalize::WriteSimInfo(fw, sstr.str());
    return 0;
}
int VecFiringRateNormalize2::AddConnection(VecConnection* newcon)
{
    VecNormalize::AddConnection(newcon);
    if (PostSynLastFirings == 0)
    {
        PostSynLastFirings = new int[NTarget];
        for (int i = 0; i < NTarget; ++i)
        {
            PostSynLastFirings[i] = 0;
        }
    }
    return 0;
}
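
// For every spike of the target layer in the current time step: look up the
// time since that neuron's previous spike and, if it lies within the range of
// the look-up table, multiply all of the neuron's incoming weights by
// NormLut[TDiff].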
int VecFiringRateNormalize2::proceede(int TotalTime)
{
    int t = int(TotalTime % MacroTimeStep);
    int spike = Target->last_N_firings;
    int CurTarget;
    float NormFactor = 1;
    while (spike < Target->N_firings) {
        CurTarget = Target->firings[spike][1];
        int TDiff = t - PostSynLastFirings[CurTarget];
        if (TDiff < NormLutN) {
            NormFactor = NormLut[TDiff];
            for (vector<VecConnection*>::iterator it = ConList.begin(); it != ConList.end(); ++it)
            {
                (*it)->MultiplyTargetWeights(CurTarget, NormFactor);
            }
        }
        PostSynLastFirings[CurTarget] = t;
        ++spike;
    }
    return 0;
}
int VecFiringRateNormalize2::prepare(int Step)
{
    VecNormalize::prepare(Step);
    cout << "shifting lastspikes\n"; fflush(stdout);
    for (int i = 0; i < NTarget; ++i) PostSynLastFirings[i] -= MacroTimeStep;
    // FixMe: what if a neuron never fires? prevent negative integer overflow?
    cout << "vnorm::prepared\n"; fflush(stdout);
    return 0;
}
//////////////////////////////
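
// VecConstSumNormalize: whenever a target neuron fires, all of its incoming
// weights are rescaled multiplicatively so that their sum equals WeightSum.
// The quadratic variant (keeping the sum of squared weights constant) is
// declared but not implemented yet.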
VecConstSumNormalize::VecConstSumNormalize(float _WeightSum, bool _quadratic)
    : WeightSum(_WeightSum), quadratic(_quadratic)
{
    cout << " VecConstSumNormalization\n";
    cout << " WeightSum=" << WeightSum << " quadratic=" << quadratic << " \n";
}

int VecConstSumNormalize::prepare(int Step)
{
    return 0;
}
int VecConstSumNormalize::proceede(int TotalTime)
{
    int t = int(TotalTime % MacroTimeStep);
    int spike = Target->last_N_firings;
    int CurTarget;
    int i, j;
    float CurWeightSum, NormFactor;
    float tmpweight;
    if (quadratic)
    {
        cerr << " VecConstSumNormalize::proceede/QUADRATIC not yet implemented\n";
        exit(1);
        // while (spike < Target->N_firings) {
        //     CurTarget = Target->firings[spike][1];
        //     // calculate WeightSum
        //     CurWeightSum = 0;
        //     for (vector<connection*>::iterator it = ConList.begin(); it != ConList.end(); ++it)
        //     {
        //         for (i = 0; i < (*it)->N_pre[CurTarget]; ++i) {
        //             tmpweight = (*((*it)->s_pre[CurTarget][i]));
        //             if (tmpweight < 0) { // delete this thread
        //                 cout << "ERROR, weight deletion didn't work\n";
        //                 fflush(stdout);
        //                 exit(2);
        //             }
        //             CurWeightSum += (*((*it)->s_pre[CurTarget][i]))*(*((*it)->s_pre[CurTarget][i]));
        //         }
        //     }
        //     // DEBUG
        //     if (CurWeightSum > 100) {
        //         for (vector<connection*>::iterator it = ConList.begin(); it != ConList.end(); ++it)
        //         {
        //             cout << "N_pre=" << (*it)->N_pre[CurTarget] << "\n";
        //             for (i = 0; i < (*it)->N_pre[CurTarget]; ++i) {
        //                 tmpweight = (*((*it)->s_pre[CurTarget][i]));
        //                 cout << "w" << i << "=" << tmpweight << " I_pre=" << (*it)->I_pre[CurTarget][i] << "\n";
        //             }
        //         }
        //         fflush(stdout);
        //         exit(2);
        //     }
        //     // END DEBUG
        //     NormFactor = WeightSum/sqrt(CurWeightSum);
        //     cout << "NormFactor=" << NormFactor << " WeightSum=" << CurWeightSum << "\n";
        //     for (vector<connection*>::iterator it = ConList.begin(); it != ConList.end(); ++it)
        //     {
        //         for (i = 0; i < (*it)->N_pre[CurTarget]; ++i) *((*it)->s_pre[CurTarget][i]) *= NormFactor;
        //     }
        //     ++spike;
        // }
    } else {
        while (spike < Target->N_firings) {
            CurTarget = Target->firings[spike][1];
            // calculate the current weight sum
            CurWeightSum = 0;
            int Npre;
            for (TVecConnectionList::iterator it = ConList.begin(); it != ConList.end(); ++it)
            {
                // Npre = ((*it)->PreSynNr)[CurTarget].size();
                // for (i = 0; i < Npre; ++i) CurWeightSum += ((*it)->SynWeights)[((*it)->PreSynNr)[CurTarget][i]];
                CurWeightSum += (*it)->GetWeightSum(CurTarget, false);
            }
            NormFactor = WeightSum/CurWeightSum;
            // multiplicatively normalize weights
            for (TVecConnectionList::iterator it = ConList.begin(); it != ConList.end(); ++it)
            {
                (*it)->MultiplyTargetWeights(CurTarget, NormFactor);
            }
            ++spike;
        }
    }
    return 0;
}
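
// Rescale the incoming weights of every target neuron once, independent of
// firing activity, so that each weight sum equals WeightSum.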
int VecConstSumNormalize::NormalizeAll()
{
    cout << "Normalizing All ...";
    int CurTarget;
    int i, j;
    float CurWeightSum, NormFactor;
    if (quadratic)
    {
        cerr << " VecConstSumNormalize::NormalizeAll/QUADRATIC not yet implemented\n";
        exit(1);
        // for (CurTarget = 0; CurTarget < Target->N; ++CurTarget)
        // {
        //     CurWeightSum = 0; // as above
        //     for (vector<connection*>::iterator it = ConList.begin(); it != ConList.end(); ++it)
        //     {
        //         for (i = 0; i < (*it)->N_pre[CurTarget]; ++i) CurWeightSum += (*((*it)->s_pre[CurTarget][i]))*(*((*it)->s_pre[CurTarget][i]));
        //     }
        //     NormFactor = WeightSum/sqrt(CurWeightSum);
        //     for (vector<connection*>::iterator it = ConList.begin(); it != ConList.end(); ++it)
        //     {
        //         for (i = 0; i < (*it)->N_pre[CurTarget]; ++i) *((*it)->s_pre[CurTarget][i]) *= NormFactor;
        //     }
        // }
    }
    else {
        for (CurTarget = 0; CurTarget < Target->N; ++CurTarget)
        {
            CurWeightSum = 0;
            // calculate the weight sum
            for (TVecConnectionList::iterator it = ConList.begin(); it != ConList.end(); ++it)
            {
                CurWeightSum += (*it)->GetWeightSum(CurTarget, false);
            }
            NormFactor = WeightSum/CurWeightSum;
            // multiplicatively normalize weights
            for (TVecConnectionList::iterator it = ConList.begin(); it != ConList.end(); ++it)
            {
                (*it)->MultiplyTargetWeights(CurTarget, NormFactor);
            }
        }
    }
    cout << " [done]\n";
    return 0;
}
// ToDo: test this function
// sets the weight sum according to the average of all weights
void VecConstSumNormalize::CalcInitWeightSum()
{
    cout << "Calculating Initial Weight Sum ...";
    cerr << "VecConstSumNormalize::CalcInitWeightSum() is not tested\n";
    exit(1); // abort here: the code below is untested
    int CurTarget;
    int i, j;
    float CurWeightSum, NormFactor;
    float TmpWeightSum = 0;
    if (quadratic)
    {
        cerr << " VecConstSumNormalize::CalcInitWeightSum/QUADRATIC not yet implemented\n";
        exit(1);
        // for (CurTarget = 0; CurTarget < Target->N; ++CurTarget)
        // {
        //     CurWeightSum = 0; // as above
        //     for (vector<connection*>::iterator it = ConList.begin(); it != ConList.end(); ++it)
        //     {
        //         for (i = 0; i < (*it)->N_pre[CurTarget]; ++i) CurWeightSum += (*((*it)->s_pre[CurTarget][i]))*(*((*it)->s_pre[CurTarget][i]));
        //     }
        //     TmpWeightSum += sqrt(CurWeightSum);
        // }
        // WeightSum = TmpWeightSum/Target->N;
    }
    else {
        for (CurTarget = 0; CurTarget < Target->N; ++CurTarget)
        {
            CurWeightSum = 0;
            // calculate the weight sum
            for (TVecConnectionList::iterator it = ConList.begin(); it != ConList.end(); ++it)
            {
                CurWeightSum += (*it)->GetWeightSum(CurTarget, false);
            }
            TmpWeightSum += CurWeightSum;
        }
        WeightSum = TmpWeightSum/Target->N;
    }
    cout << " [done]\n";
}
float VecConstSumNormalize::GetWeightSum()
{
    return WeightSum;
}

int VecConstSumNormalize::SetWeightSum(float NewWeightSum)
{
    WeightSum = NewWeightSum;
    return 0;
}
int VecConstSumNormalize::WriteSimInfo(fstream &fw)
{
    stringstream sstr;
    sstr << "<WeightSum Value=\"" << WeightSum << "\"/> \n";
    sstr << "<Quadratic Value=\"" << quadratic << "\"/> \n";
    VecNormalize::WriteSimInfo(fw, sstr.str());
    return 0;
}

//////////////////////////////