
generate_datasets.py

# -*- coding: utf-8 -*-
'''
Generate datasets for testing
'''

# needed for python 3 compatibility
from __future__ import absolute_import

from datetime import datetime

import numpy as np
from numpy.random import rand

import quantities as pq

from neo.core import (AnalogSignal, Block, Epoch, Event, IrregularlySampledSignal,
                      ChannelIndex, Segment, SpikeTrain, Unit, class_by_name)
from neo.core.baseneo import _container_name
from neo.core.dataobject import DataObject

TEST_ANNOTATIONS = [1, 0, 1.5, "this is a test", datetime.fromtimestamp(424242424), None]


def generate_one_simple_block(block_name='block_0', nb_segment=3, supported_objects=[], **kws):
    if supported_objects and Block not in supported_objects:
        raise ValueError('Block must be in supported_objects')
    bl = Block()  # name=block_name)

    objects = supported_objects
    if Segment in objects:
        for s in range(nb_segment):
            seg = generate_one_simple_segment(seg_name="seg" + str(s),
                                              supported_objects=objects, **kws)
            bl.segments.append(seg)

    # if RecordingChannel in objects:
    #     populate_RecordingChannel(bl)

    bl.create_many_to_one_relationship()
    return bl


def generate_one_simple_segment(seg_name='segment 0', supported_objects=[], nb_analogsignal=4,
                                t_start=0. * pq.s, sampling_rate=10 * pq.kHz, duration=6. * pq.s,
                                nb_spiketrain=6, spikerate_range=[.5 * pq.Hz, 12 * pq.Hz],
                                event_types={'stim': ['a', 'b', 'c', 'd'],
                                             'enter_zone': ['one', 'two'],
                                             'color': ['black', 'yellow', 'green']},
                                event_size_range=[5, 20],
                                epoch_types={'animal state': ['Sleep', 'Freeze', 'Escape'],
                                             'light': ['dark', 'lighted']},
                                epoch_duration_range=[.5, 3.],
                                array_annotations={'valid': np.array([True, False]),
                                                   'number': np.array(range(5))}):
    if supported_objects and Segment not in supported_objects:
        raise ValueError('Segment must be in supported_objects')
    seg = Segment(name=seg_name)

    if AnalogSignal in supported_objects:
        for a in range(nb_analogsignal):
            anasig = AnalogSignal(rand(int(sampling_rate * duration)),
                                  sampling_rate=sampling_rate,
                                  t_start=t_start, units=pq.mV, channel_index=a,
                                  name='sig %d for segment %s' % (a, seg.name))
            seg.analogsignals.append(anasig)

    if SpikeTrain in supported_objects:
        for s in range(nb_spiketrain):
            spikerate = rand() * np.diff(spikerate_range)
            spikerate += spikerate_range[0].magnitude
            # spikedata = rand(int((spikerate * duration).simplified)) * duration
            # sptr = SpikeTrain(spikedata,
            #                   t_start=t_start, t_stop=t_start + duration)
            #                   # , name='spiketrain %d' % s)
            spikes = rand(int((spikerate * duration).simplified))
            spikes.sort()  # spikes are supposed to be an ascending sequence
            sptr = SpikeTrain(spikes * duration, t_start=t_start, t_stop=t_start + duration)
            sptr.annotations['channel_index'] = s
            # Randomly generate array_annotations from given options
            arr_ann = {key: value[(rand(len(spikes)) * len(value)).astype('i')]
                       for (key, value) in array_annotations.items()}
            sptr.array_annotate(**arr_ann)
            seg.spiketrains.append(sptr)

    if Event in supported_objects:
        for name, labels in event_types.items():
            evt_size = rand() * np.diff(event_size_range)
            evt_size += event_size_range[0]
            evt_size = int(evt_size)
            labels = np.array(labels, dtype='S')
            labels = labels[(rand(evt_size) * len(labels)).astype('i')]
            evt = Event(times=rand(evt_size) * duration, labels=labels)
            # Randomly generate array_annotations from given options
            arr_ann = {key: value[(rand(evt_size) * len(value)).astype('i')]
                       for (key, value) in array_annotations.items()}
            evt.array_annotate(**arr_ann)
            seg.events.append(evt)

    if Epoch in supported_objects:
        for name, labels in epoch_types.items():
            t = 0
            times = []
            durations = []
            while t < duration:
                times.append(t)
                dur = rand() * np.diff(epoch_duration_range)
                dur += epoch_duration_range[0]
                durations.append(dur)
                t = t + dur
            labels = np.array(labels, dtype='S')
            labels = labels[(rand(len(times)) * len(labels)).astype('i')]
            assert len(times) == len(durations)
            assert len(times) == len(labels)
            epc = Epoch(times=pq.Quantity(times, units=pq.s),
                        durations=pq.Quantity([x[0] for x in durations], units=pq.s),
                        labels=labels)
            # Randomly generate array_annotations from given options
            arr_ann = {key: value[(rand(len(times)) * len(value)).astype('i')]
                       for (key, value) in array_annotations.items()}
            epc.array_annotate(**arr_ann)
            seg.epochs.append(epc)

    # TODO: Spike, Event

    seg.create_many_to_one_relationship()
    return seg
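
# Example (illustrative sketch, not part of the original module): building a
# small test Block with fully populated segments. The supported_objects list
# below is an assumption; pass whichever classes the test at hand exercises.
#
#     bl = generate_one_simple_block(
#         nb_segment=2,
#         supported_objects=[Block, Segment, AnalogSignal, SpikeTrain, Event, Epoch])
#     assert len(bl.segments) == 2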
def generate_from_supported_objects(supported_objects):
    # ~ create_many_to_one_relationship
    if not supported_objects:
        raise ValueError('No objects specified')
    objects = supported_objects
    if Block in supported_objects:
        higher = generate_one_simple_block(supported_objects=objects)
        # Chris: we do not create RC and RCG if they are not in objects; there is
        # a test in generate_one_simple_block, so finalize_block(higher) was removed
    elif Segment in objects:
        higher = generate_one_simple_segment(supported_objects=objects)
    else:
        # TODO
        return None

    higher.create_many_to_one_relationship()
    return higher


def get_fake_value(name, datatype, dim=0, dtype='float', seed=None, units=None, obj=None, n=None,
                   shape=None):
    """
    Returns a default value for a given attribute based on neo.core.

    If seed is not None, use the seed to set the random number generator.
    """
    if not obj:
        obj = 'TestObject'
    elif not hasattr(obj, 'lower'):
        obj = obj.__name__

    if name in ['name', 'file_origin', 'description'] and (datatype != str or dim):
        raise ValueError('%s must be str, not a %sD %s' % (name, dim, datatype))

    if name == 'file_origin':
        return 'test_file.txt'
    if name == 'name':
        return '%s%s' % (obj, get_fake_value('', datatype, seed=seed))
    if name == 'description':
        return 'test %s %s' % (obj, get_fake_value('', datatype, seed=seed))

    if seed is not None:
        np.random.seed(seed)

    if datatype == str:
        return str(np.random.randint(100000))
    if datatype == int:
        return np.random.randint(100)
    if datatype == float:
        return 1000. * np.random.random()
    if datatype == datetime:
        return datetime.fromtimestamp(1000000000 * np.random.random())

    if name in ['t_start', 't_stop', 'sampling_rate'] and (datatype != pq.Quantity or dim):
        raise ValueError('%s must be a 0D Quantity, not a %sD %s' % (name, dim, datatype))

    # only put array types below here
    if units is not None:
        pass
    elif name in ['t_start', 't_stop', 'time', 'times', 'duration', 'durations']:
        units = pq.millisecond
    elif name == 'sampling_rate':
        units = pq.Hz
    elif datatype == pq.Quantity:
        units = np.random.choice(['nA', 'mA', 'A', 'mV', 'V'])
        units = getattr(pq, units)

    if name == 'sampling_rate':
        data = np.array(10000.0)
    elif name == 't_start':
        data = np.array(0.0)
    elif name == 't_stop':
        data = np.array(1.0)
    elif n and name == 'channel_indexes':
        data = np.arange(n)
    elif n and name == 'channel_names':
        data = np.array(["ch%d" % i for i in range(n)])
    elif n and obj == 'AnalogSignal':
        if name == 'signal':
            size = []
            for _ in range(int(dim)):
                size.append(np.random.randint(5) + 1)
            size[1] = n
            data = np.random.random(size) * 1000.
    else:
        size = []
        for _ in range(int(dim)):
            if shape is None:
                # To ensure consistency, times, labels and durations need to have the same size
                if name in ["times", "labels", "durations"]:
                    size.append(5)
                else:
                    size.append(np.random.randint(5) + 1)
            else:
                size.append(shape)
        data = np.random.random(size)
        if name not in ['time', 'times']:
            data *= 1000.

    if np.dtype(dtype) != np.float64:
        data = data.astype(dtype)

    if datatype == np.ndarray:
        return data
    if datatype == list:
        return data.tolist()
    if datatype == pq.Quantity:
        return data * units  # set the units

    # Array annotations need to be a dict containing arrays
    if name == 'array_annotations' and datatype == dict:
        # Make sure that array annotations have the correct length
        if obj in ['AnalogSignal', 'IrregularlySampledSignal']:
            length = n if n is not None else 1
        elif obj in ['IrregularlySampledSignal', 'SpikeTrain', 'Epoch', 'Event']:
            length = n
        else:
            raise ValueError("This object cannot have array annotations")

        # Generate array annotations
        valid = np.array([True, False])
        number = np.arange(5)
        arr_ann = {'valid': valid[(rand(length) * len(valid)).astype('i')],
                   'number': number[(rand(length) * len(number)).astype('i')]}
        return arr_ann

    # we have gone through everything we know, so it must be something invalid
    raise ValueError('Unknown name/datatype combination %s %s' % (name, datatype))
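
# Example (illustrative sketch, not part of the original module): get_fake_value
# is normally driven by an attribute spec tuple (name, datatype, dim); the calls
# and seed below are arbitrary assumptions for demonstration.
#
#     get_fake_value('t_start', pq.Quantity, seed=0)        # 0D Quantity, 0.0 ms
#     get_fake_value('times', pq.Quantity, dim=1, seed=0)   # 1D ms Quantity, 5 values
#     get_fake_value('name', str, obj='Block', seed=0)      # 'Block' + random int string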
def get_fake_values(cls, annotate=True, seed=None, n=None):
    """
    Returns a dict containing the default values for all attributes of
    a class from neo.core.

    If seed is not None, use the seed to set the random number generator.
    The seed is incremented by 1 for each successive object.

    If annotate is True (default), also add annotations to the values.
    """
    if hasattr(cls, 'lower'):
        # is this a test that cls is a string? better to use isinstance(cls, basestring), no?
        cls = class_by_name[cls]

    # iseed is needed below for generation of array annotations
    iseed = None

    kwargs = {}  # assign attributes
    for i, attr in enumerate(cls._necessary_attrs + cls._recommended_attrs):
        if seed is not None:
            iseed = seed + i
        kwargs[attr[0]] = get_fake_value(*attr, seed=iseed, obj=cls, n=n)

    if 'waveforms' in kwargs:
        # everything here is to force the kwargs to have
        # len(kwargs["times"]) == kwargs["waveforms"].shape[0]
        if len(kwargs["times"]) != kwargs["waveforms"].shape[0]:
            if len(kwargs["times"]) < kwargs["waveforms"].shape[0]:
                dif = kwargs["waveforms"].shape[0] - len(kwargs["times"])
                new_times = []
                for i in kwargs["times"].magnitude:
                    new_times.append(i)
                np.random.seed(0)
                new_times = np.concatenate([new_times, np.random.random(dif)])
                kwargs["times"] = pq.Quantity(new_times, units=pq.ms)
            else:
                kwargs['times'] = kwargs['times'][:kwargs["waveforms"].shape[0]]

    # IrregularlySampledSignal
    if 'times' in kwargs and 'signal' in kwargs:
        kwargs['times'] = kwargs['times'][:len(kwargs['signal'])]
        kwargs['signal'] = kwargs['signal'][:len(kwargs['times'])]

    if annotate:
        kwargs.update(get_annotations())
        # Make sure that array annotations have the right length
        if cls in [IrregularlySampledSignal, AnalogSignal]:
            try:
                n = len(kwargs['signal'][0])
            # If there is only 1 signal, len(int) is called, which raises a TypeError
            except TypeError:
                n = 1
        elif cls in [SpikeTrain, Event, Epoch]:
            n = len(kwargs['times'])
        # Array annotate any DataObject
        if issubclass(cls, DataObject):
            new_seed = iseed + 1 if iseed is not None else iseed
            kwargs['array_annotations'] = get_fake_value('array_annotations', dict,
                                                         seed=new_seed, obj=cls, n=n)
        kwargs['seed'] = seed

    return kwargs
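
# Example (illustrative sketch, not part of the original module): the returned
# dict can be splatted straight into the class constructor, which is what
# fake_neo does below. The seed is an arbitrary assumption.
#
#     kwargs = get_fake_values(SpikeTrain, annotate=False, seed=42)
#     st = SpikeTrain(**kwargs)  # reproducible fake SpikeTrain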
def get_annotations():
    '''
    Returns a dict containing the default values for annotations for
    a class from neo.core.
    '''
    return dict([(str(i), ann) for i, ann in enumerate(TEST_ANNOTATIONS)])


def fake_epoch(seed=None, n=1):
    """
    Create a fake Epoch.

    We use this separate function because the attributes of
    Epoch are not independent (they must all have the same size).
    """
    kwargs = get_annotations()
    if seed is not None:
        np.random.seed(seed)
    size = np.random.randint(5, 15)
    for i, attr in enumerate(Epoch._necessary_attrs + Epoch._recommended_attrs):
        if seed is not None:
            iseed = seed + i
        else:
            iseed = None
        if attr[0] in ('times', 'durations', 'labels'):
            kwargs[attr[0]] = get_fake_value(*attr, seed=iseed, obj=Epoch, shape=size)
        else:
            kwargs[attr[0]] = get_fake_value(*attr, seed=iseed, obj=Epoch, n=n)
    kwargs['seed'] = seed
    obj = Epoch(**kwargs)
    return obj


def fake_neo(obj_type="Block", cascade=True, seed=None, n=1):
    '''
    Create a fake NEO object of a given type. Follows one-to-many
    and many-to-many relationships if cascade.

    n (default=1) is the number of child objects of each type that will be
    created. In cases like segment.spiketrains, there will be more than this
    number because there will be n for each unit, of which there will be n for
    each channelindex, of which there will be n.
    '''
    if hasattr(obj_type, 'lower'):
        cls = class_by_name[obj_type]
    else:
        cls = obj_type
        obj_type = obj_type.__name__

    if cls is Epoch:
        obj = fake_epoch(seed=seed, n=n)
    else:
        kwargs = get_fake_values(obj_type, annotate=True, seed=seed, n=n)
        obj = cls(**kwargs)

    # if not cascading, we don't need to do any of the stuff after this
    if not cascade:
        return obj

    # this is used to signal other containers that they shouldn't duplicate
    # data
    if obj_type == 'Block':
        cascade = 'block'

    for i, childname in enumerate(getattr(obj, '_child_objects', [])):
        # we create a few of each class
        for j in range(n):
            if seed is not None:
                iseed = 10 * seed + 100 * i + 1000 * j
            else:
                iseed = None
            child = fake_neo(obj_type=childname, cascade=cascade, seed=iseed, n=n)
            child.annotate(i=i, j=j)

            # if we are creating a block and this is the object's primary
            # parent, don't create the object, we will import it from secondary
            # containers later
            if (cascade == 'block' and len(child._parent_objects) > 0
                    and obj_type != child._parent_objects[-1]):
                continue
            getattr(obj, _container_name(childname)).append(child)

    # need to manually create 'implicit' connections
    if obj_type == 'Block':
        # connect data objects to segment
        for i, chx in enumerate(obj.channel_indexes):
            for k, sigarr in enumerate(chx.analogsignals):
                obj.segments[k].analogsignals.append(sigarr)
            for k, sigarr in enumerate(chx.irregularlysampledsignals):
                obj.segments[k].irregularlysampledsignals.append(sigarr)
            for j, unit in enumerate(chx.units):
                for k, train in enumerate(unit.spiketrains):
                    obj.segments[k].spiketrains.append(train)
    # elif obj_type == 'ChannelIndex':
    #     inds = []
    #     names = []
    #     chinds = np.array([unit.channel_indexes[0] for unit in obj.units])
    #     obj.indexes = np.array(inds, dtype='i')
    #     obj.channel_names = np.array(names).astype('S')

    if hasattr(obj, 'create_many_to_one_relationship'):
        obj.create_many_to_one_relationship()

    return obj


def clone_object(obj, n=None):
    '''
    Generate a new object and new objects with the same rules as the original.
    '''
    if hasattr(obj, '__iter__') and not hasattr(obj, 'ndim'):
        return [clone_object(iobj, n=n) for iobj in obj]

    cascade = hasattr(obj, 'children') and len(obj.children)
    if n is not None:
        pass
    elif cascade:
        n = min(len(getattr(obj, cont)) for cont in obj._child_containers)
    else:
        n = 0
    seed = obj.annotations.get('seed', None)

    newobj = fake_neo(obj.__class__, cascade=cascade, seed=seed, n=n)
    if 'i' in obj.annotations:
        newobj.annotate(i=obj.annotations['i'], j=obj.annotations['j'])
    return newobj
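

# Minimal smoke test (illustrative sketch, not part of the original module).
# It uses only functions defined above; the supported_objects list and seed
# are arbitrary assumptions.
if __name__ == '__main__':
    # a simple block with two fully populated segments
    blk = generate_one_simple_block(
        nb_segment=2,
        supported_objects=[Block, Segment, AnalogSignal, SpikeTrain, Event, Epoch])
    print('simple block: %d segments, %d spiketrains in segment 0'
          % (len(blk.segments), len(blk.segments[0].spiketrains)))

    # a reproducible fake Block built through the attribute-spec machinery
    fake = fake_neo('Block', cascade=True, seed=0, n=1)
    print('fake block: %d segments' % len(fake.segments))

    # clone it with the same rules; the 'seed' annotation drives reproducibility
    copy = clone_object(fake)
    print('clone is a %s' % type(copy).__name__)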