'''
Generate datasets for testing
'''

from datetime import datetime

import numpy as np
from numpy.random import rand

import quantities as pq

from neo.core import (AnalogSignal, Block, Epoch, Event, IrregularlySampledSignal, ChannelIndex,
                      Segment, SpikeTrain, Unit, ImageSequence, CircularRegionOfInterest,
                      RectangularRegionOfInterest, PolygonRegionOfInterest, class_by_name)
from neo.core.baseneo import _container_name
from neo.core.dataobject import DataObject

TEST_ANNOTATIONS = [1, 0, 1.5, "this is a test", datetime.fromtimestamp(424242424), None]


def generate_one_simple_block(block_name='block_0', nb_segment=3, supported_objects=[], **kws):
    if supported_objects and Block not in supported_objects:
        raise ValueError('Block must be in supported_objects')
    bl = Block(name=block_name)
    objects = supported_objects
    if Segment in objects:
        for s in range(nb_segment):
            seg = generate_one_simple_segment(seg_name="seg" + str(s), supported_objects=objects,
                                              **kws)
            bl.segments.append(seg)

    # if RecordingChannel in objects:
    #     populate_RecordingChannel(bl)

    bl.create_many_to_one_relationship()
    return bl
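

# Example usage (a sketch; supported_objects lists the classes the caller can
# handle, so nothing outside that list is generated):
#
#     bl = generate_one_simple_block(supported_objects=[Block, Segment, AnalogSignal, SpikeTrain])
#     len(bl.segments)                  # -> 3 (nb_segment default)
#     len(bl.segments[0].spiketrains)   # -> 6 (nb_spiketrain default)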


def generate_one_simple_segment(seg_name='segment 0', supported_objects=[], nb_analogsignal=4,
                                t_start=0. * pq.s, sampling_rate=10 * pq.kHz, duration=6. * pq.s,
                                nb_spiketrain=6, spikerate_range=[.5 * pq.Hz, 12 * pq.Hz],
                                event_types={'stim': ['a', 'b', 'c', 'd'],
                                             'enter_zone': ['one', 'two'],
                                             'color': ['black', 'yellow', 'green'], },
                                event_size_range=[5, 20],
                                epoch_types={'animal state': ['Sleep', 'Freeze', 'Escape'],
                                             'light': ['dark', 'lighted']},
                                epoch_duration_range=[.5, 3.],
                                # this should be multiplied by pq.s, no?
                                array_annotations={'valid': np.array([True, False]),
                                                   'number': np.array(range(5))}
                                ):
    if supported_objects and Segment not in supported_objects:
        raise ValueError('Segment must be in supported_objects')
    seg = Segment(name=seg_name)
    if AnalogSignal in supported_objects:
        for a in range(nb_analogsignal):
            anasig = AnalogSignal(rand(int((sampling_rate * duration).simplified)),
                                  sampling_rate=sampling_rate,
                                  t_start=t_start, units=pq.mV, channel_index=a,
                                  name='sig %d for segment %s' % (a, seg.name))
            seg.analogsignals.append(anasig)

    if SpikeTrain in supported_objects:
        for s in range(nb_spiketrain):
            spikerate = rand() * np.diff(spikerate_range)
            spikerate += spikerate_range[0].magnitude
            # spikedata = rand(int((spikerate*duration).simplified))*duration
            # sptr = SpikeTrain(spikedata,
            #                   t_start=t_start, t_stop=t_start+duration)
            # #, name = 'spiketrain %d'%s)
            spikes = rand(int((spikerate * duration).simplified))
            spikes.sort()  # spikes are supposed to be an ascending sequence
            sptr = SpikeTrain(spikes * duration, t_start=t_start, t_stop=t_start + duration)
            sptr.annotations['channel_index'] = s
            # Randomly generate array_annotations from given options
            arr_ann = {key: value[(rand(len(spikes)) * len(value)).astype('i')] for (key, value) in
                       array_annotations.items()}
            sptr.array_annotate(**arr_ann)
            seg.spiketrains.append(sptr)

    if Event in supported_objects:
        for name, labels in event_types.items():
            evt_size = rand() * np.diff(event_size_range)
            evt_size += event_size_range[0]
            evt_size = int(evt_size)
            labels = np.array(labels, dtype='U')
            labels = labels[(rand(evt_size) * len(labels)).astype('i')]
            evt = Event(times=rand(evt_size) * duration, labels=labels)
            # Randomly generate array_annotations from given options
            arr_ann = {key: value[(rand(evt_size) * len(value)).astype('i')] for (key, value) in
                       array_annotations.items()}
            evt.array_annotate(**arr_ann)
            seg.events.append(evt)

    if Epoch in supported_objects:
        for name, labels in epoch_types.items():
            t = 0
            times = []
            durations = []
            while t < duration:
                times.append(t)
                dur = rand() * (epoch_duration_range[1] - epoch_duration_range[0])
                dur += epoch_duration_range[0]
                durations.append(dur)
                t = t + dur
            labels = np.array(labels, dtype='U')
            labels = labels[(rand(len(times)) * len(labels)).astype('i')]
            assert len(times) == len(durations)
            assert len(times) == len(labels)
            epc = Epoch(times=pq.Quantity(times, units=pq.s),
                        durations=pq.Quantity(durations, units=pq.s),
                        labels=labels,)
            assert epc.times.dtype == 'float'
            # Randomly generate array_annotations from given options
            arr_ann = {key: value[(rand(len(times)) * len(value)).astype('i')] for (key, value) in
                       array_annotations.items()}
            epc.array_annotate(**arr_ann)
            seg.epochs.append(epc)

    # TODO : Spike, Event

    seg.create_many_to_one_relationship()
    return seg
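

# Example usage (a sketch): one Event per event_types key and one Epoch per
# epoch_types key are generated, so the defaults give 3 events and 2 epochs.
#
#     seg = generate_one_simple_segment(supported_objects=[Segment, Event, Epoch])
#     len(seg.events)   # -> 3
#     len(seg.epochs)   # -> 2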


def generate_from_supported_objects(supported_objects):
    # ~ create_many_to_one_relationship
    if not supported_objects:
        raise ValueError('No objects specified')
    objects = supported_objects
    if Block in supported_objects:
        higher = generate_one_simple_block(supported_objects=objects)
        # Chris: we do not create RC and RCG if it is not in objects; there is
        # a test in generate_one_simple_block, so I removed
        # finalize_block(higher)
    elif Segment in objects:
        higher = generate_one_simple_segment(supported_objects=objects)
    else:
        # TODO
        return None

    higher.create_many_to_one_relationship()
    return higher
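

# Example usage (a sketch): the highest-level container in supported_objects
# determines what is returned.
#
#     top = generate_from_supported_objects([Block, Segment, AnalogSignal])
#     isinstance(top, Block)   # -> True; a bare Segment is returned only if Block is absent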


def get_fake_value(name, datatype, dim=0, dtype='float', seed=None, units=None, obj=None, n=None,
                   shape=None):
    """
    Returns default value for a given attribute based on neo.core

    If seed is not None, use the seed to set the random number generator.
    """
    if not obj:
        obj = 'TestObject'
    elif not hasattr(obj, 'lower'):
        obj = obj.__name__

    if (name in ['name', 'file_origin', 'description'] and (datatype != str or dim)):
        raise ValueError('{} must be str, not a {}D {}'.format(name, dim, datatype))

    if name == 'file_origin':
        return 'test_file.txt'
    if name == 'name':
        return '{}{}'.format(obj, get_fake_value('', datatype, seed=seed))
    if name == 'description':
        return 'test {} {}'.format(obj, get_fake_value('', datatype, seed=seed))

    if seed is not None:
        np.random.seed(seed)

    if datatype == str:
        return str(np.random.randint(100000))
    if datatype == int:
        return np.random.randint(100)
    if datatype == float:
        return 1000. * np.random.random()
    if datatype == datetime:
        return datetime.fromtimestamp(1000000000 * np.random.random())

    if (name in ['t_start', 't_stop', 'sampling_rate'] and (datatype != pq.Quantity or dim)):
        raise ValueError('{} must be a 0D Quantity, not a {}D {}'.format(name, dim, datatype))

    # only put array types below here
    if units is not None:
        pass
    elif name in ['t_start', 't_stop', 'time', 'times', 'duration', 'durations']:
        units = pq.millisecond
    elif name == 'sampling_rate':
        units = pq.Hz
    elif datatype == pq.Quantity:
        units = np.random.choice(['nA', 'mA', 'A', 'mV', 'V'])
        units = getattr(pq, units)

    if name == 'sampling_rate':
        data = np.array(10000.0)
    elif name == 't_start':
        data = np.array(0.0)
    elif name == 't_stop':
        data = np.array(1.0)
    elif n and name in ['channel_indexes', 'channel_ids']:
        data = np.arange(n)
    elif n and name == 'coordinates':
        data = np.arange(0, 2 * n).reshape((n, 2))
    elif n and name == 'channel_names':
        data = np.array(["ch%d" % i for i in range(n)])
    elif n and name == 'index':  # ChannelIndex.index
        data = np.random.randint(0, n, n)
    elif n and obj == 'AnalogSignal':
        if name == 'signal':
            size = []
            for _ in range(int(dim)):
                size.append(np.random.randint(5) + 1)
            size[1] = n
            data = np.random.random(size) * 1000.
    else:
        size = []
        for _ in range(int(dim)):
            if shape is None:
                # To ensure consistency, times, labels and durations need to have the same size
                if name in ["times", "labels", "durations"]:
                    size.append(5)
                else:
                    size.append(np.random.randint(5) + 1)
            else:
                size.append(shape)
        data = np.random.random(size)
        if name not in ['time', 'times']:
            data *= 1000.
    if np.dtype(dtype) != np.float64:
        data = data.astype(dtype)

    if datatype == np.ndarray:
        return data
    if datatype == list:
        return data.tolist()
    if datatype == pq.Quantity:
        return data * units  # set the units

    # Array annotations need to be a dict containing arrays
    if name == 'array_annotations' and datatype == dict:
        # Make sure that array annotations have the correct length
        if obj in ['AnalogSignal', 'IrregularlySampledSignal']:
            length = n if n is not None else 1
        elif obj in ['SpikeTrain', 'Epoch', 'Event']:
            length = n
        else:
            raise ValueError("This object cannot have array annotations")
        # Generate array annotations
        valid = np.array([True, False])
        number = np.arange(5)
        arr_ann = {'valid': valid[(rand(length) * len(valid)).astype('i')],
                   'number': number[(rand(length) * len(number)).astype('i')]}
        return arr_ann

    # we have gone through everything we know, so it must be something invalid
    raise ValueError('Unknown name/datatype combination {} {}'.format(name, datatype))
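

# Example usage (a sketch; the (name, datatype, dim) triples normally come from
# a neo class's _necessary_attrs, e.g. ('signal', pq.Quantity, 2) for AnalogSignal):
#
#     sig = get_fake_value('signal', pq.Quantity, dim=2, obj='AnalogSignal', n=4, seed=0)
#     sig.shape[1]   # -> 4; the second axis is forced to n channels
#     get_fake_value('t_start', pq.Quantity, seed=0)   # -> array(0.) * ms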


def get_fake_values(cls, annotate=True, seed=None, n=None):
    """
    Returns a dict containing the default values for all attributes for
    a class from neo.core.

    If seed is not None, use the seed to set the random number generator.
    The seed is incremented by 1 for each successive object.

    If annotate is True (default), also add annotations to the values.
    """
    if hasattr(cls, 'lower'):  # is this a test that cls is a string?
                               # better to use isinstance(cls, str), no?
        cls = class_by_name[cls]
    # iseed is needed below for generation of array annotations
    iseed = None

    kwargs = {}  # assign attributes
    for i, attr in enumerate(cls._necessary_attrs + cls._recommended_attrs):
        if seed is not None:
            iseed = seed + i
        kwargs[attr[0]] = get_fake_value(*attr, seed=iseed, obj=cls, n=n)

    if 'waveforms' in kwargs:
        # everything here is to force the kwargs to have
        # len(kwargs["times"]) == kwargs["waveforms"].shape[0]
        if len(kwargs["times"]) != kwargs["waveforms"].shape[0]:
            if len(kwargs["times"]) < kwargs["waveforms"].shape[0]:
                dif = kwargs["waveforms"].shape[0] - len(kwargs["times"])
                new_times = []
                for i in kwargs["times"].magnitude:
                    new_times.append(i)
                np.random.seed(0)
                new_times = np.concatenate([new_times, np.random.random(dif)])
                kwargs["times"] = pq.Quantity(new_times, units=pq.ms)
            else:
                kwargs['times'] = kwargs['times'][:kwargs["waveforms"].shape[0]]

    # IrregularlySampledSignal
    if 'times' in kwargs and 'signal' in kwargs:
        kwargs['times'] = kwargs['times'][:len(kwargs['signal'])]
        kwargs['signal'] = kwargs['signal'][:len(kwargs['times'])]

    if annotate:
        kwargs.update(get_annotations())

        # Make sure that array annotations have the right length
        if cls in [IrregularlySampledSignal, AnalogSignal]:
            try:
                n = len(kwargs['signal'][0])
            # If only 1 signal, len(int) is called, this raises a TypeError
            except TypeError:
                n = 1
        elif cls in [SpikeTrain, Event, Epoch]:
            n = len(kwargs['times'])

        # Array annotate any DataObject except ImageSequence
        if issubclass(cls, DataObject) and cls is not ImageSequence:
            new_seed = iseed + 1 if iseed is not None else iseed
            kwargs['array_annotations'] = get_fake_value('array_annotations', dict, seed=new_seed,
                                                         obj=cls, n=n)
        kwargs['seed'] = seed

    return kwargs
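

# Example usage (a sketch; with annotate=True the returned dict also carries
# numbered annotation keys and 'seed', which neo constructors absorb as
# annotations):
#
#     kwargs = get_fake_values('SpikeTrain', annotate=True, seed=0, n=3)
#     st = SpikeTrain(**kwargs)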


def get_annotations():
    '''
    Returns a dict containing the default values for annotations for
    a class from neo.core.
    '''
    return {str(i): ann for i, ann in enumerate(TEST_ANNOTATIONS)}


def fake_epoch(seed=None, n=1):
    """
    Create a fake Epoch.

    We use this separate function because the attributes of
    Epoch are not independent (must all have the same size)
    """
    kwargs = get_annotations()
    if seed is not None:
        np.random.seed(seed)
    size = np.random.randint(5, 15)
    for i, attr in enumerate(Epoch._necessary_attrs + Epoch._recommended_attrs):
        if seed is not None:
            iseed = seed + i
        else:
            iseed = None
        if attr[0] in ('times', 'durations', 'labels'):
            kwargs[attr[0]] = get_fake_value(*attr, seed=iseed, obj=Epoch, shape=size)
        else:
            kwargs[attr[0]] = get_fake_value(*attr, seed=iseed, obj=Epoch, n=n)
    kwargs['seed'] = seed
    obj = Epoch(**kwargs)
    return obj
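

# Example usage (a sketch): times, durations and labels share a single random
# size, so the resulting Epoch is internally consistent.
#
#     epc = fake_epoch(seed=42)
#     len(epc.times) == len(epc.durations) == len(epc.labels)   # -> True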


def fake_neo(obj_type="Block", cascade=True, seed=None, n=1):
    '''
    Create a fake NEO object of a given type. Follows one-to-many
    and many-to-many relationships if cascade.

    n (default=1) is the number of child objects of each type that will be
    created. In cases like segment.spiketrains, there will be more than this
    number, because there will be n for each unit, of which there will be n
    for each channelindex, of which there will be n.
    '''
    if hasattr(obj_type, 'lower'):
        cls = class_by_name[obj_type]
    else:
        cls = obj_type
        obj_type = obj_type.__name__

    if cls is Epoch:
        obj = fake_epoch(seed=seed, n=n)
    else:
        kwargs = get_fake_values(obj_type, annotate=True, seed=seed, n=n)
        obj = cls(**kwargs)

    # if not cascading, we don't need to do any of the stuff after this
    if not cascade:
        return obj

    # this is used to signal other containers that they shouldn't duplicate
    # data
    if obj_type == 'Block':
        cascade = 'block'

    for i, childname in enumerate(getattr(obj, '_child_objects', [])):
        # we create a few of each class
        if childname == 'Group':
            continue  # avoid infinite recursion, since Groups can contain Groups
        for j in range(n):
            if seed is not None:
                iseed = 10 * seed + 100 * i + 1000 * j
            else:
                iseed = None
            child = fake_neo(obj_type=childname, cascade=cascade, seed=iseed, n=n)
            child.annotate(i=i, j=j)

            # if we are creating a block and this is the object's primary
            # parent, don't create the object, we will import it from secondary
            # containers later
            if (cascade == 'block' and len(child._parent_objects) > 0
                    and obj_type != child._parent_objects[-1]):
                continue
            getattr(obj, _container_name(childname)).append(child)

    # need to manually create 'implicit' connections
    if obj_type == 'Block':
        # connect data objects to segment
        for i, chx in enumerate(obj.channel_indexes):
            for k, sigarr in enumerate(chx.analogsignals):
                obj.segments[k].analogsignals.append(sigarr)
            for k, sigarr in enumerate(chx.irregularlysampledsignals):
                obj.segments[k].irregularlysampledsignals.append(sigarr)
            for j, unit in enumerate(chx.units):
                for k, train in enumerate(unit.spiketrains):
                    obj.segments[k].spiketrains.append(train)
    # elif obj_type == 'ChannelIndex':
    #     inds = []
    #     names = []
    #     chinds = np.array([unit.channel_indexes[0] for unit in obj.units])
    #     obj.indexes = np.array(inds, dtype='i')
    #     obj.channel_names = np.array(names).astype('S')

    if hasattr(obj, 'create_many_to_one_relationship'):
        obj.create_many_to_one_relationship()

    return obj
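

# Example usage (a sketch):
#
#     bl = fake_neo('Block', seed=0, n=2)               # full tree, 2 children per container
#     st = fake_neo(SpikeTrain, cascade=False, seed=0)  # a single bare SpikeTrain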


def clone_object(obj, n=None):
    '''
    Generate a new object, and new child objects, following the same rules
    as the original.
    '''
    if hasattr(obj, '__iter__') and not hasattr(obj, 'ndim'):
        return [clone_object(iobj, n=n) for iobj in obj]

    cascade = hasattr(obj, 'children') and len(obj.children)
    if n is not None:
        pass
    elif cascade:
        n = min(len(getattr(obj, cont)) for cont in obj._child_containers)
    else:
        n = 0
    seed = obj.annotations.get('seed', None)

    newobj = fake_neo(obj.__class__, cascade=cascade, seed=seed, n=n)
    if 'i' in obj.annotations:
        newobj.annotate(i=obj.annotations['i'], j=obj.annotations['j'])
    return newobj
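

if __name__ == '__main__':
    # Minimal smoke test (a sketch; assumes neo and quantities are installed):
    # build a fully cascaded fake Block, then clone it with the same seeds.
    blk = fake_neo('Block', seed=0, n=2)
    print(blk)
    print(clone_object(blk))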