
generate_datasets.py

# -*- coding: utf-8 -*-
'''
Generate datasets for testing
'''

# needed for python 3 compatibility
from __future__ import absolute_import

from datetime import datetime

import numpy as np
from numpy.random import rand

import quantities as pq

from neo.core import (AnalogSignal, Block, Epoch, Event,
                      IrregularlySampledSignal, ChannelIndex,
                      Segment, SpikeTrain, Unit, class_by_name)
from neo.io.tools import iteritems
from neo.core.baseneo import _container_name

TEST_ANNOTATIONS = [1, 0, 1.5, "this is a test",
                    datetime.fromtimestamp(424242424), None]


def generate_one_simple_block(block_name='block_0', nb_segment=3,
                              supported_objects=[], **kws):
    if supported_objects and Block not in supported_objects:
        raise ValueError('Block must be in supported_objects')
    bl = Block()  # name = block_name)

    objects = supported_objects
    if Segment in objects:
        for s in range(nb_segment):
            seg = generate_one_simple_segment(seg_name="seg" + str(s),
                                              supported_objects=objects, **kws)
            bl.segments.append(seg)

    #if RecordingChannel in objects:
    #    populate_RecordingChannel(bl)

    bl.create_many_to_one_relationship()
    return bl


def generate_one_simple_segment(seg_name='segment 0',
                                supported_objects=[],
                                nb_analogsignal=4,
                                t_start=0.*pq.s,
                                sampling_rate=10*pq.kHz,
                                duration=6.*pq.s,
                                nb_spiketrain=6,
                                spikerate_range=[.5*pq.Hz, 12*pq.Hz],
                                event_types={'stim': ['a', 'b', 'c', 'd'],
                                             'enter_zone': ['one', 'two'],
                                             'color': ['black', 'yellow',
                                                       'green'],
                                             },
                                event_size_range=[5, 20],
                                epoch_types={'animal state': ['Sleep', 'Freeze',
                                                              'Escape'],
                                             'light': ['dark', 'lighted']
                                             },
                                epoch_duration_range=[.5, 3.],
                                ):
    if supported_objects and Segment not in supported_objects:
        raise ValueError('Segment must be in supported_objects')
    seg = Segment(name=seg_name)

    if AnalogSignal in supported_objects:
        for a in range(nb_analogsignal):
            anasig = AnalogSignal(rand(int(sampling_rate * duration)),
                                  sampling_rate=sampling_rate, t_start=t_start,
                                  units=pq.mV, channel_index=a,
                                  name='sig %d for segment %s' % (a, seg.name))
            seg.analogsignals.append(anasig)

    if SpikeTrain in supported_objects:
        for s in range(nb_spiketrain):
            spikerate = rand()*np.diff(spikerate_range)
            spikerate += spikerate_range[0].magnitude
            #spikedata = rand(int((spikerate*duration).simplified))*duration
            #sptr = SpikeTrain(spikedata,
            #                  t_start=t_start, t_stop=t_start+duration)
            #                  #, name = 'spiketrain %d'%s)
            spikes = rand(int((spikerate*duration).simplified))
            spikes.sort()  # spikes are supposed to be an ascending sequence
            sptr = SpikeTrain(spikes*duration,
                              t_start=t_start, t_stop=t_start+duration)
            sptr.annotations['channel_index'] = s
            seg.spiketrains.append(sptr)

    if Event in supported_objects:
        for name, labels in iteritems(event_types):
            evt_size = rand()*np.diff(event_size_range)
            evt_size += event_size_range[0]
            evt_size = int(evt_size)

            labels = np.array(labels, dtype='S')
            labels = labels[(rand(evt_size)*len(labels)).astype('i')]
            evt = Event(times=rand(evt_size)*duration, labels=labels)
            seg.events.append(evt)

    if Epoch in supported_objects:
        for name, labels in iteritems(epoch_types):
            t = 0
            times = []
            durations = []
            while t < duration:
                times.append(t)
                dur = rand()*np.diff(epoch_duration_range)
                dur += epoch_duration_range[0]
                durations.append(dur)
                t = t + dur
            labels = np.array(labels, dtype='S')
            labels = labels[(rand(len(times))*len(labels)).astype('i')]
            epc = Epoch(times=pq.Quantity(times, units=pq.s),
                        durations=pq.Quantity([x[0] for x in durations],
                                              units=pq.s),
                        labels=labels,
                        )
            seg.epochs.append(epc)

    # TODO : Spike, Event

    seg.create_many_to_one_relationship()
    return seg


def generate_from_supported_objects(supported_objects):
    #~ create_many_to_one_relationship
    if not supported_objects:
        raise ValueError('No objects specified')
    objects = supported_objects
    if Block in supported_objects:
        higher = generate_one_simple_block(supported_objects=objects)

        # RecordingChannel objects are only created when they are in
        # supported_objects; generate_one_simple_block already checks this,
        # so no extra finalization is needed here.
        #finalize_block(higher)

    elif Segment in objects:
        higher = generate_one_simple_segment(supported_objects=objects)
    else:
        #TODO
        return None

    higher.create_many_to_one_relationship()
    return higher
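
# Example (illustrative sketch, not part of the test suite): generating a
# block limited to the object types a given IO supports. The object set
# below is an arbitrary choice for demonstration; the sizes follow the
# defaults of generate_one_simple_block/segment above.
#
#     objs = [Block, Segment, AnalogSignal, SpikeTrain]
#     bl = generate_from_supported_objects(objs)
#     assert len(bl.segments) == 3                    # nb_segment default
#     assert len(bl.segments[0].analogsignals) == 4   # nb_analogsignal default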


def get_fake_value(name, datatype, dim=0, dtype='float', seed=None,
                   units=None, obj=None, n=None, shape=None):
    """
    Returns default value for a given attribute based on neo.core

    If seed is not None, use the seed to set the random number generator.
    """
    if not obj:
        obj = 'TestObject'
    elif not hasattr(obj, 'lower'):
        obj = obj.__name__

    if (name in ['name', 'file_origin', 'description'] and
            (datatype != str or dim)):
        raise ValueError('%s must be str, not a %sD %s' % (name, dim,
                                                           datatype))

    if name == 'file_origin':
        return 'test_file.txt'
    if name == 'name':
        return '%s%s' % (obj, get_fake_value('', datatype, seed=seed))
    if name == 'description':
        return 'test %s %s' % (obj, get_fake_value('', datatype, seed=seed))

    if seed is not None:
        np.random.seed(seed)

    if datatype == str:
        return str(np.random.randint(100000))
    if datatype == int:
        return np.random.randint(100)
    if datatype == float:
        return 1000. * np.random.random()
    if datatype == datetime:
        return datetime.fromtimestamp(1000000000*np.random.random())

    if (name in ['t_start', 't_stop', 'sampling_rate'] and
            (datatype != pq.Quantity or dim)):
        raise ValueError('%s must be a 0D Quantity, not a %sD %s' % (name, dim,
                                                                     datatype))

    # only put array types below here

    if units is not None:
        pass
    elif name in ['t_start', 't_stop',
                  'time', 'times',
                  'duration', 'durations']:
        units = pq.millisecond
    elif name == 'sampling_rate':
        units = pq.Hz
    elif datatype == pq.Quantity:
        units = np.random.choice(['nA', 'mA', 'A', 'mV', 'V'])
        units = getattr(pq, units)

    if name == 'sampling_rate':
        data = np.array(10000.0)
    elif name == 't_start':
        data = np.array(0.0)
    elif name == 't_stop':
        data = np.array(1.0)
    elif n and name == 'channel_indexes':
        data = np.arange(n)
    elif n and name == 'channel_names':
        data = np.array(["ch%d" % i for i in range(n)])
    elif n and obj == 'AnalogSignal':
        if name == 'signal':
            size = []
            for _ in range(int(dim)):
                size.append(np.random.randint(5) + 1)
            size[1] = n
            data = np.random.random(size)*1000.
    else:
        size = []
        for _ in range(int(dim)):
            if shape is None:
                if name == "times":
                    size.append(5)
                else:
                    size.append(np.random.randint(5) + 1)
            else:
                size.append(shape)

        data = np.random.random(size)
        if name not in ['time', 'times']:
            data *= 1000.
    if np.dtype(dtype) != np.float64:
        data = data.astype(dtype)

    if datatype == np.ndarray:
        return data
    if datatype == list:
        return data.tolist()
    if datatype == pq.Quantity:
        return data * units  # set the units

    # we have gone through everything we know, so it must be something invalid
    raise ValueError('Unknown name/datatype combination %s %s' % (name,
                                                                  datatype))


def get_fake_values(cls, annotate=True, seed=None, n=None):
    """
    Returns a dict containing the default values for all attributes for
    a class from neo.core.

    If seed is not None, use the seed to set the random number generator.
    The seed is incremented by 1 for each successive object.

    If annotate is True (default), also add annotations to the values.
    """
    if hasattr(cls, 'lower'):  # cls may be given as a class name string
        cls = class_by_name[cls]

    kwargs = {}  # assign attributes
    for i, attr in enumerate(cls._necessary_attrs + cls._recommended_attrs):
        if seed is not None:
            iseed = seed + i
        else:
            iseed = None
        kwargs[attr[0]] = get_fake_value(*attr, seed=iseed, obj=cls, n=n)

    if 'waveforms' in kwargs:
        # force len(kwargs["times"]) == kwargs["waveforms"].shape[0]
        if len(kwargs["times"]) != kwargs["waveforms"].shape[0]:
            if len(kwargs["times"]) < kwargs["waveforms"].shape[0]:
                dif = kwargs["waveforms"].shape[0] - len(kwargs["times"])
                new_times = []
                for i in kwargs["times"].magnitude:
                    new_times.append(i)
                np.random.seed(0)
                new_times = np.concatenate([new_times, np.random.random(dif)])
                kwargs["times"] = pq.Quantity(new_times, units=pq.ms)
            else:
                kwargs['times'] = kwargs['times'][:kwargs["waveforms"].shape[0]]

    if 'times' in kwargs and 'signal' in kwargs:
        kwargs['times'] = kwargs['times'][:len(kwargs['signal'])]
        kwargs['signal'] = kwargs['signal'][:len(kwargs['times'])]

    if annotate:
        kwargs.update(get_annotations())
        kwargs['seed'] = seed

    return kwargs
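
# Example (illustrative sketch): fetching fake constructor arguments for a
# single class. The class name 'SpikeTrain' and the fixed seed are arbitrary
# choices for demonstration.
#
#     kwargs = get_fake_values('SpikeTrain', annotate=False, seed=0)
#     sptr = SpikeTrain(**kwargs)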


def get_annotations():
    '''
    Returns a dict containing the default values for annotations for
    a class from neo.core.
    '''
    return dict([(str(i), ann) for i, ann in enumerate(TEST_ANNOTATIONS)])


def fake_neo(obj_type="Block", cascade=True, seed=None, n=1):
    '''
    Create a fake NEO object of a given type. Follows one-to-many
    and many-to-many relationships if cascade.

    n (default=1) is the number of child objects of each type that will be
    created. In cases like segment.spiketrains, there will be more than this
    number because there will be n for each unit, of which there will be n
    for each channelindex, of which there will be n.
    '''
    if hasattr(obj_type, 'lower'):
        cls = class_by_name[obj_type]
    else:
        cls = obj_type
        obj_type = obj_type.__name__

    kwargs = get_fake_values(obj_type, annotate=True, seed=seed, n=n)
    obj = cls(**kwargs)

    # if not cascading, we don't need to do any of the stuff after this
    if not cascade:
        return obj

    # this is used to signal other containers that they shouldn't duplicate
    # data
    if obj_type == 'Block':
        cascade = 'block'

    for i, childname in enumerate(getattr(obj, '_child_objects', [])):
        # we create a few of each class
        for j in range(n):
            if seed is not None:
                iseed = 10*seed + 100*i + 1000*j
            else:
                iseed = None
            child = fake_neo(obj_type=childname, cascade=cascade,
                             seed=iseed, n=n)
            child.annotate(i=i, j=j)

            # if we are creating a block and this is the object's primary
            # parent, don't create the object, we will import it from secondary
            # containers later
            if (cascade == 'block' and len(child._parent_objects) > 0 and
                    obj_type != child._parent_objects[-1]):
                continue
            getattr(obj, _container_name(childname)).append(child)

    # need to manually create 'implicit' connections
    if obj_type == 'Block':
        # connect data objects to segment
        for i, chx in enumerate(obj.channel_indexes):
            for k, sigarr in enumerate(chx.analogsignals):
                obj.segments[k].analogsignals.append(sigarr)
            for k, sigarr in enumerate(chx.irregularlysampledsignals):
                obj.segments[k].irregularlysampledsignals.append(sigarr)
            for j, unit in enumerate(chx.units):
                for k, train in enumerate(unit.spiketrains):
                    obj.segments[k].spiketrains.append(train)
    #elif obj_type == 'ChannelIndex':
    #    inds = []
    #    names = []
    #    chinds = np.array([unit.channel_indexes[0] for unit in obj.units])
    #    obj.indexes = np.array(inds, dtype='i')
    #    obj.channel_names = np.array(names).astype('S')

    if hasattr(obj, 'create_many_to_one_relationship'):
        obj.create_many_to_one_relationship()

    return obj


def clone_object(obj, n=None):
    '''
    Generate a new object, and new child objects, following the same rules
    as the original.
    '''
    if hasattr(obj, '__iter__') and not hasattr(obj, 'ndim'):
        return [clone_object(iobj, n=n) for iobj in obj]

    cascade = hasattr(obj, 'children') and len(obj.children)
    if n is not None:
        pass
    elif cascade:
        n = min(len(getattr(obj, cont)) for cont in obj._child_containers)
    else:
        n = 0
    seed = obj.annotations.get('seed', None)

    newobj = fake_neo(obj.__class__, cascade=cascade, seed=seed, n=n)
    if 'i' in obj.annotations:
        newobj.annotate(i=obj.annotations['i'], j=obj.annotations['j'])
    return newobj
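

if __name__ == '__main__':
    # Minimal usage sketch, not part of the neo test suite: build a fake
    # Block with a fixed seed, then clone it. Whether these calls succeed
    # depends on the installed neo version matching the imports above
    # (ChannelIndex, Unit, class_by_name).
    blk = fake_neo(obj_type='Block', cascade=True, seed=0, n=1)
    print(blk)
    clone = clone_object(blk)
    print(clone.annotations.get('seed'))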