Scheduled service maintenance on November 22


On Friday, November 22, 2024, between 06:00 CET and 18:00 CET, GIN services will undergo planned maintenance. Extended service interruptions should be expected. We will try to keep downtimes to a minimum, but recommend that users avoid critical tasks, large data uploads, or DOI requests during this time.

We apologize for any inconvenience.

nestio.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753
  1. # -*- coding: utf-8 -*-
  2. """
  3. Class for reading output files from NEST simulations
  4. ( http://www.nest-simulator.org/ ).
  5. Tested with NEST2.10.0
  6. Depends on: numpy, quantities
  7. Supported: Read
  8. Authors: Julia Sprenger, Maximilian Schmidt, Johanna Senk
  9. """
  10. # needed for Python3 compatibility
  11. from __future__ import absolute_import
  12. import os.path
  13. import warnings
  14. from datetime import datetime
  15. import numpy as np
  16. import quantities as pq
  17. from neo.io.baseio import BaseIO
  18. from neo.core import Block, Segment, SpikeTrain, AnalogSignal
  # Default unit per NEST value type prefix. NestIO looks a unit up by the
  # part of the value type before the first underscore (e.g. 'V' for 'V_m');
  # 'g' maps to 10^-9 S (nS).
  value_type_dict = dict([
      ('V', pq.mV),
      ('I', pq.pA),
      ('g', pq.CompoundUnit("10^-9*S")),
      ('no type', pq.dimensionless),
  ])
  class NestIO(BaseIO):
      """
      Class for reading NEST output files. GDF files for the spike data and
      DAT files for analog signals are possible.

      Usage:
          from neo.io.nestio import NestIO

          files = ['membrane_voltages-1261-0.dat',
                   'spikes-1258-0.gdf']
          r = NestIO(filenames=files)
          seg = r.read_segment(gid_list=[], t_start=400 * pq.ms,
                               t_stop=600 * pq.ms,
                               id_column_gdf=0, time_column_gdf=1,
                               id_column_dat=0, time_column_dat=1,
                               value_columns_dat=2)

      Column indices passed by the user (id, time and value columns) refer to
      the columns of the files. The data is loaded with
      ColumnIO.get_columns() rearranged into the internal order
      [id, time, value_1, value_2, ...], so all indexing into the loaded
      array uses the fixed positions _ID_POS, _TIME_POS and _VALUE_POS.
      """

      is_readable = True  # class supports reading, but not writing
      is_writable = False

      supported_objects = [SpikeTrain, AnalogSignal, Segment, Block]
      readable_objects = [SpikeTrain, AnalogSignal, Segment, Block]

      has_header = False
      is_streameable = False

      write_params = None  # writing is not supported

      name = 'nest'
      extensions = ['gdf', 'dat']
      mode = 'file'

      # Positions of the id, time and first value column in the arrays
      # returned by ColumnIO.get_columns() (internal column order).
      _ID_POS = 0
      _TIME_POS = 1
      _VALUE_POS = 2

      def __init__(self, filenames=None):
          """
          Parameters
          ----------
          filenames: string or list of strings, default=None
              The filename or list of filenames to load. Files whose
              extension is listed in ``extensions`` ('gdf', 'dat') are read
              with ColumnIO; other files are only stored in
              ``self.filenames``.

          Raises
          ------
          ValueError
              If more than one file with the same supported extension is
              given.
          """
          if isinstance(filenames, str):
              filenames = [filenames]

          self.filenames = filenames
          self.avail_formats = {}  # extension -> file path without extension
          self.avail_IOs = {}  # extension -> ColumnIO with the file content
          for filename in filenames:
              path, ext = os.path.splitext(filename)
              ext = ext.strip('.')
              if ext in self.extensions:
                  if ext in self.avail_IOs:
                      raise ValueError('Received multiple files with "%s" '
                                       'extention. Can only load single file of '
                                       'this type.' % ext)
                  self.avail_IOs[ext] = ColumnIO(filename)
                  self.avail_formats[ext] = path

      def __read_analogsignals(self, gid_list, time_unit, t_start=None,
                               t_stop=None, sampling_period=None,
                               id_column=0, time_column=1,
                               value_columns=2, value_types=None,
                               value_units=None, lazy=False):
          """
          Internal function called by read_analogsignal() and read_segment().

          Returns a list with one AnalogSignal per neuron ID and requested
          value column, annotated with 'id' and 'type'. The list is empty if
          lazy is True.
          """
          if 'dat' not in self.avail_formats:
              raise ValueError('Can not load analogsignals. No DAT file '
                               'provided.')

          # checking gid input parameters
          gid_list, id_column = self._check_input_gids(gid_list, id_column)
          # checking time input parameters
          t_start, t_stop = self._check_input_times(t_start, t_stop,
                                                    mandatory=False)

          # checking value input parameters
          (value_columns, value_types, value_units) = \
              self._check_input_values_parameters(value_columns, value_types,
                                                  value_units)

          # standard column order for internal usage:
          # [id_column, time_column, value_column1, value_column2, ...]
          column_list = [id_column, time_column] + list(value_columns)

          # assert that no single column is assigned twice
          column_list_no_None = [c for c in column_list if c is not None]
          if len(np.unique(column_list_no_None)) < len(column_list_no_None):
              raise ValueError(
                  'One or more columns have been specified to contain '
                  'the same data. Columns were specified to %s.'
                  '' % column_list_no_None)

          # A missing (None) id/time column is filled with file column -1 as
          # a placeholder so the internal positions stay fixed; the
          # placeholder is never read, because id/time handling below is
          # skipped for None columns.
          column_ids = [-1 if cid is None else cid for cid in column_list]

          # extracting condition and sorting parameters for raw data loading
          (condition, condition_column,
           sorting_column) = self._get_conditions_and_sorting(id_column,
                                                              time_column,
                                                              gid_list,
                                                              t_start,
                                                              t_stop)
          # loading raw data columns (rearranged to the internal order)
          data = self.avail_IOs['dat'].get_columns(
              column_ids=column_ids,
              condition=condition,
              condition_column=condition_column,
              sorting_columns=sorting_column)

          sampling_period = self._check_input_sampling_period(sampling_period,
                                                              time_column,
                                                              time_unit,
                                                              data)
          analogsignal_list = []
          if lazy:
              return analogsignal_list

          # extracting complete gid list for anasig generation
          if gid_list == [] and id_column is not None:
              gid_list = np.unique(data[:, self._ID_POS])

          # generate analogsignals for each neuron ID
          for gid in gid_list:
              first, last = self._get_selected_ids(
                  gid, id_column, time_column, t_start, t_stop, time_unit,
                  data)

              # extract starting time of analogsignal
              if (time_column is not None) and data.size:
                  anasig_start_time = data[first, self._TIME_POS] * time_unit
              else:
                  # set t_start equal to sampling_period because NEST starts
                  # recording only after 1 sampling_period
                  anasig_start_time = 1. * sampling_period

              # create one analogsignal per value column requested; the value
              # columns follow the id and time columns in the loaded data
              for v_id in range(len(value_columns)):
                  signal = data[first:last, self._VALUE_POS + v_id]

                  # create AnalogSignal objects and annotate them with
                  # the neuron ID
                  analogsignal_list.append(AnalogSignal(
                      signal * value_units[v_id],
                      sampling_period=sampling_period,
                      t_start=anasig_start_time,
                      id=gid,
                      type=value_types[v_id]))

                  # check for correct length of analogsignal
                  assert (analogsignal_list[-1].t_stop ==
                          anasig_start_time + len(signal) * sampling_period)
          return analogsignal_list

      def __read_spiketrains(self, gdf_id_list, time_unit,
                             t_start, t_stop, id_column,
                             time_column, **args):
          """
          Internal function for reading multiple spiketrains at once.
          This function is called by read_spiketrain() and read_segment().

          Returns one SpikeTrain per neuron ID in gdf_id_list (all IDs found
          in the file if gdf_id_list is []), or a single SpikeTrain with
          id=None if the file has no ID column. Only spikes with
          t_start <= t < t_stop are included. Extra keyword arguments are
          passed on to SpikeTrain.
          """
          if 'gdf' not in self.avail_IOs:
              raise ValueError('Can not load spiketrains. No GDF file provided.')

          # assert that the file contains spike times
          if time_column is None:
              raise ValueError('Time column is None. No spike times to '
                               'be read in.')

          gdf_id_list, id_column = self._check_input_gids(gdf_id_list, id_column)

          t_start, t_stop = self._check_input_times(t_start, t_stop,
                                                    mandatory=True)

          # assert that no single column is assigned twice
          if id_column == time_column:
              raise ValueError('One or more columns have been specified to '
                               'contain the same data.')

          # standard column order for internal usage: [id_column, time_column];
          # a missing id column is replaced by the placeholder -1 so that the
          # spike times always end up at position _TIME_POS
          column_ids = [-1 if cid is None else cid
                        for cid in (id_column, time_column)]

          (condition, condition_column, sorting_column) = \
              self._get_conditions_and_sorting(id_column, time_column,
                                               gdf_id_list, t_start, t_stop)

          data = self.avail_IOs['gdf'].get_columns(
              column_ids=column_ids,
              condition=condition,
              condition_column=condition_column,
              sorting_columns=sorting_column)

          if id_column is None:
              # all spike times are collected in one spike train with id=None
              gdf_id_list = [None]
          elif gdf_id_list == []:
              # load all neuron IDs present in the file
              gdf_id_list = np.unique(data[:, self._ID_POS])

          spiketrain_list = []
          for nid in gdf_id_list:
              first, last = self._get_selected_ids(nid, id_column, time_column,
                                                   t_start, t_stop, time_unit,
                                                   data)
              times = data[first:last, self._TIME_POS]
              spiketrain_list.append(SpikeTrain(
                  times, units=time_unit,
                  t_start=t_start, t_stop=t_stop,
                  id=nid, **args))
          return spiketrain_list

      def _check_input_times(self, t_start, t_stop, mandatory=True):
          """
          Checks input times for existence and setting default values if
          necessary.

          t_start: pq.quantity.Quantity, start time of the time range to load.
          t_stop: pq.quantity.Quantity, stop time of the time range to load.
          mandatory: bool, if True times can not be None and an error will be
                     raised. if False, time values of None will be replaced by
                     -infinity or infinity, respectively. default: True.

          Returns
          (t_start, t_stop)

          Raises
          ValueError if a time is missing and mandatory is True; TypeError if
          a time is not a quantity.
          """
          if t_stop is None:
              if mandatory:
                  raise ValueError('No t_stop specified.')
              t_stop = np.inf * pq.s
          if t_start is None:
              if mandatory:
                  raise ValueError('No t_start specified.')
              t_start = -np.inf * pq.s

          for time in (t_start, t_stop):
              if not isinstance(time, pq.quantity.Quantity):
                  raise TypeError('Time value (%s) is not a quantity.' % time)
          return t_start, t_stop

      def _check_input_values_parameters(self, value_columns, value_types,
                                         value_units):
          """
          Checks value parameters for consistency.

          value_columns: int or list of int, column id(s) containing the
                         values to load.
          value_types: str or list of strings, type of values.
          value_units: unit or list of units of the value columns. If None,
                       units are derived from value_types via value_type_dict.

          Returns
          adjusted list of [value_columns, value_types, value_units]
          """
          if value_columns is None:
              raise ValueError('No value column provided.')
          if isinstance(value_columns, (int, np.integer)):
              value_columns = [value_columns]
          else:
              value_columns = list(value_columns)
          if value_types is None:
              value_types = ['no type'] * len(value_columns)
          elif isinstance(value_types, str):
              value_types = [value_types]
          # a single unit (as passed on by read_analogsignal()) is wrapped
          if isinstance(value_units, pq.UnitQuantity):
              value_units = [value_units]

          # translating value types into units as far as possible
          if value_units is None:
              short_value_types = [vtype.split('_')[0] for vtype in value_types]
              if not all([svt in value_type_dict for svt in short_value_types]):
                  raise ValueError('Can not interpret value types '
                                   '"%s"' % value_types)
              value_units = [value_type_dict[svt] for svt in short_value_types]

          # checking for same number of value types, units and columns
          if not (len(value_types) == len(value_units) == len(value_columns)):
              raise ValueError('Length of value types, units and columns does '
                               'not match (%i,%i,%i)' % (len(value_types),
                                                         len(value_units),
                                                         len(value_columns)))
          if not all([isinstance(vunit, pq.UnitQuantity) for vunit in
                      value_units]):
              raise ValueError('No value unit or standard value type specified.')

          return value_columns, value_types, value_units

      def _check_input_gids(self, gid_list, id_column):
          """
          Checks gid values and column for consistency.

          gid_list: list (or other iterable) of int or None, gid to load.
          id_column: int, id of the column containing the gids.

          Returns
          adjusted list of [gid_list, id_column]; gid_list is returned as a
          new plain list ([None] if None was given).
          """
          if gid_list is None:
              gid_list = [None]
          # a plain list keeps the comparisons below unambiguous also for
          # ranges and numpy arrays
          gid_list = list(gid_list)

          if None in gid_list and id_column is not None:
              raise ValueError('No neuron IDs specified but file contains '
                               'neuron IDs in column %s. Specify empty list to '
                               'retrieve spiketrains of all neurons.'
                               '' % str(id_column))

          if id_column is None and gid_list != [None]:
              raise ValueError('Specified neuron IDs to be %s, but no ID column '
                               'specified.' % gid_list)
          return gid_list, id_column

      def _check_input_sampling_period(self, sampling_period, time_column,
                                       time_unit, data):
          """
          Checks sampling period, times and time unit for consistency.

          sampling_period: pq.UnitQuantity, sampling period of data to load.
          time_column: int, column id of times in data to load.
          time_unit: pq.quantity.Quantity, unit of time used in the data to load.
          data: numpy array, the loaded data in the internal column order
                (times at position _TIME_POS).

          If sampling_period is None it is estimated from the distinct time
          stamps in data. Steps that differ only by floating point noise (as
          produced by parsing decimal time stamps like 0.1, 0.2, 0.3) are
          accepted and averaged over the recording.

          Returns
          pq.quantities.Quantity object, the updated sampling period.

          Raises
          ValueError if the period can not be estimated or is not a unit.
          """
          if sampling_period is None:
              if time_column is None:
                  raise ValueError('Can not estimate sampling rate without time '
                                   'column id provided.')
              times = np.unique(data[:, self._TIME_POS])
              steps = np.diff(times)
              if len(steps) == 0:
                  raise ValueError('Can not estimate sampling rate from less '
                                   'than two distinct time stamps.')
              data_sampling = np.unique(steps)
              if len(data_sampling) == 1:
                  dt = data_sampling[0]
              elif np.allclose(steps, steps[0], rtol=1e-6, atol=0):
                  # steps only differ by float noise: use the mean step
                  dt = (times[-1] - times[0]) / (len(times) - 1)
              else:
                  raise ValueError('Different sampling distances found in '
                                   'data set (%s)' % data_sampling)
              sampling_period = pq.CompoundUnit(str(dt) + '*'
                                                + time_unit.units.u_symbol)
          elif not isinstance(sampling_period, pq.UnitQuantity):
              raise ValueError("sampling_period is not specified as a unit.")
          return sampling_period

      def _get_conditions_and_sorting(self, id_column, time_column, gid_list,
                                      t_start, t_stop):
          """
          Calculates the condition, condition_column and sorting_column based
          on other parameters supplied for loading the data.

          id_column: int, id of the column containing gids.
          time_column: int, id of the column containing times.
          gid_list: list of int, gid to be loaded.
          t_start: pq.quantity.Quantity, start of the time range to be loaded.
          t_stop: pq.quantity.Quantity, stop of the time range to be loaded.

          Returns
          updated [condition, condition_column, sorting_column].
          condition_column and the sorting_column entries are column ids of
          the file, as expected by ColumnIO.get_columns(). sorting_column is
          ordered for np.lexsort (last entry = primary key): rows are sorted
          by gid first and by time second; it is None if nothing is sorted.
          """
          condition, condition_column = None, None
          sorting_column = []

          if (gid_list is not None and list(gid_list) != [None]
                  and id_column is not None):
              if len(gid_list) > 0:
                  try:
                      gid_lookup = set(gid_list)
                  except TypeError:
                      # unhashable entries: fall back to list membership
                      gid_lookup = list(gid_list)
                  condition = lambda x: x in gid_lookup
                  condition_column = id_column
              sorting_column.append(id_column)  # sorting according to gids first

          if time_column is not None:
              sorting_column.append(time_column)  # then according to time
          elif t_start != -np.inf or t_stop != np.inf:
              warnings.warn('Ignoring t_start and t_stop parameters, because no '
                            'time column id is provided.')

          if not sorting_column:
              return condition, condition_column, None
          # np.lexsort treats the last key as the primary one
          return condition, condition_column, sorting_column[::-1]

      def _get_selected_ids(self, gid, id_column, time_column, t_start, t_stop,
                            time_unit, data):
          """
          Calculates the data range to load depending on the selected gid
          and the provided time range (t_start, t_stop)

          gid: int, gid to be loaded.
          id_column: int, id of the column containing gids.
          time_column: int, id of the column containing times.
          t_start: pq.quantity.Quantity, start of the time range to load.
          t_stop: pq.quantity.Quantity, stop of the time range to load.
          time_unit: pq.quantity.Quantity, time unit of the data to load.
          data: numpy array, data in the internal column order, sorted by gid
                and then by time.

          Returns
          numpy array [start, stop): row range of data belonging to gid with
          t_start <= time < t_stop.
          """
          gid_ids = np.array([0, data.shape[0]])
          if id_column is not None:
              id_values = data[:, self._ID_POS]
              gid_ids = np.array([np.searchsorted(id_values, gid, side='left'),
                                  np.searchsorted(id_values, gid, side='right')])
          gid_data = data[gid_ids[0]:gid_ids[1], :]

          # select only requested time range
          id_shifts = np.array([0, 0])
          if time_column is not None:
              gid_times = gid_data[:, self._TIME_POS]
              id_shifts[0] = np.searchsorted(
                  gid_times, t_start.rescale(time_unit).magnitude, side='left')
              id_shifts[1] = (np.searchsorted(
                  gid_times, t_stop.rescale(time_unit).magnitude, side='left')
                  - gid_data.shape[0])

          return gid_ids + id_shifts

      def read_block(self, gid_list=None, time_unit=pq.ms, t_start=None,
                     t_stop=None, sampling_period=None, id_column_dat=0,
                     time_column_dat=1, value_columns_dat=2,
                     id_column_gdf=0, time_column_gdf=1, value_types=None,
                     value_units=None, lazy=False, cascade=True):
          """
          Reads a Block containing the single Segment returned by
          read_segment(); see read_segment() for the parameters.
          """
          seg = self.read_segment(gid_list, time_unit, t_start,
                                  t_stop, sampling_period, id_column_dat,
                                  time_column_dat, value_columns_dat,
                                  id_column_gdf, time_column_gdf, value_types,
                                  value_units, lazy, cascade)
          blk = Block(file_origin=seg.file_origin, file_datetime=seg.file_datetime)
          blk.segments.append(seg)
          seg.block = blk
          return blk

      def read_segment(self, gid_list=None, time_unit=pq.ms, t_start=None,
                       t_stop=None, sampling_period=None, id_column_dat=0,
                       time_column_dat=1, value_columns_dat=2,
                       id_column_gdf=0, time_column_gdf=1, value_types=None,
                       value_units=None, lazy=False, cascade=True):
          """
          Reads a Segment which contains SpikeTrain(s) with specified neuron
          IDs from the GDF data and AnalogSignal(s) from the DAT data.

          Arguments
          ----------
          gid_list : list or tuple, default: None
              A list of GDF IDs of which to return SpikeTrain(s) and
              AnalogSignal(s). gid_list must be specified if the files contain
              neuron IDs, the default None then raises an error. Specify an
              empty list [] to retrieve the data of all neurons. A tuple
              (first, last) selects all IDs from first to last (inclusive).
          time_unit : Quantity (time), optional, default: quantities.ms
              The time unit of recorded time stamps in DAT as well as GDF
              files.
          t_start : Quantity (time), default: None
              Start time of the data to load. Must be specified if a GDF file
              is loaded; for DAT data None means no lower bound.
          t_stop : Quantity (time), default: None
              Stop time of the data to load. Must be specified if a GDF file
              is loaded; for DAT data None means no upper bound.
          sampling_period : UnitQuantity (time), optional, default: None
              Sampling period of the recorded data. If None, it is estimated
              from the time stamps in the DAT file.
          id_column_dat : int, optional, default: 0
              Column index of neuron IDs in the DAT file.
          time_column_dat : int, optional, default: 1
              Column index of time stamps in the DAT file.
          value_columns_dat : int or list of int, optional, default: 2
              Column index of the analog values recorded in the DAT file.
          id_column_gdf : int, optional, default: 0
              Column index of neuron IDs in the GDF file.
          time_column_gdf : int, optional, default: 1
              Column index of time stamps in the GDF file.
          value_types : str or list of str, optional, default: None
              Nest data type of the analog values recorded, eg.'V_m', 'I', 'g_e'
          value_units : UnitQuantity or list of them, default: None
              The physical unit of the recorded signal values.
          lazy : bool, optional, default: False
              If True, no AnalogSignals are created.
          cascade : bool, optional, default: True
              If False, an empty Segment is returned.

          Returns
          -------
          seg : Segment
              The Segment contains one SpikeTrain and one AnalogSignal per
              value column for each ID in gid_list.
          """
          if isinstance(gid_list, tuple):
              if gid_list[0] > gid_list[1]:
                  raise ValueError('The second entry in gid_list must be '
                                   'greater or equal to the first entry.')
              gid_list = range(gid_list[0], gid_list[1] + 1)

          # __read_xxx() needs a list of IDs
          if gid_list is None:
              gid_list = [None]

          # create an empty Segment
          seg = Segment(file_origin=",".join(self.filenames))
          # TODO: rather than take the first file for the timestamp, we should
          # take the oldest; in practice, there won't be much difference
          seg.file_datetime = datetime.fromtimestamp(os.stat(self.filenames[0]).st_mtime)

          if cascade:
              # Load analogsignals and attach to Segment
              if 'dat' in self.avail_formats:
                  seg.analogsignals = self.__read_analogsignals(
                      gid_list,
                      time_unit,
                      t_start,
                      t_stop,
                      sampling_period=sampling_period,
                      id_column=id_column_dat,
                      time_column=time_column_dat,
                      value_columns=value_columns_dat,
                      value_types=value_types,
                      value_units=value_units,
                      lazy=lazy)
              if 'gdf' in self.avail_formats:
                  seg.spiketrains = self.__read_spiketrains(
                      gid_list,
                      time_unit,
                      t_start,
                      t_stop,
                      id_column=id_column_gdf,
                      time_column=time_column_gdf)

          return seg

      def read_analogsignal(self, gid=None, time_unit=pq.ms, t_start=None,
                            t_stop=None, sampling_period=None, id_column=0,
                            time_column=1, value_column=2, value_type=None,
                            value_unit=None, lazy=False):
          """
          Reads an AnalogSignal with specified neuron ID from the DAT data.

          Arguments
          ----------
          gid : int, default: None
              The neuron ID of the returned AnalogSignal. gid must be
              specified if the DAT file contains neuron IDs (id_column is not
              None), the default None then raises an error.
          time_unit : Quantity (time), optional, default: quantities.ms
              The time unit of recorded time stamps.
          t_start : Quantity (time), optional, default: None
              Start time of the data to load; None means no lower bound.
          t_stop : Quantity (time), optional, default: None
              Stop time of the data to load; None means no upper bound.
          sampling_period : UnitQuantity (time), optional, default: None
              Sampling period of the recorded data. If None, it is estimated
              from the time stamps.
          id_column : int, optional, default: 0
              Column index of neuron IDs.
          time_column : int, optional, default: 1
              Column index of time stamps.
          value_column : int, optional, default: 2
              Column index of the analog values recorded.
          value_type : str, optional, default: None
              Nest data type of the analog values recorded, eg.'V_m', 'I', 'g_e'.
          value_unit : UnitQuantity, default: None
              The physical unit of the recorded signal values.
          lazy : bool, optional, default: False
              Passed on to the internal reader; with lazy=True no
              AnalogSignal objects are created.

          Returns
          -------
          analogsignal : AnalogSignal
              The requested AnalogSignal object with annotations 'id'
              (corresponding to the gid parameter) and 'type'.
          """
          # __read_analogsignals() needs a list of IDs
          return self.__read_analogsignals([gid], time_unit,
                                           t_start, t_stop,
                                           sampling_period=sampling_period,
                                           id_column=id_column,
                                           time_column=time_column,
                                           value_columns=value_column,
                                           value_types=value_type,
                                           value_units=value_unit,
                                           lazy=lazy)[0]

      def read_spiketrain(
              self, gdf_id=None, time_unit=pq.ms, t_start=None, t_stop=None,
              id_column=0, time_column=1, lazy=False, cascade=True, **args):
          """
          Reads a SpikeTrain with specified neuron ID from the GDF data.

          Arguments
          ----------
          gdf_id : int, default: None
              The GDF ID of the returned SpikeTrain. gdf_id must be specified
              if the GDF file contains neuron IDs (id_column is not None).
          time_unit : Quantity (time), optional, default: quantities.ms
              The time unit of recorded time stamps.
          t_start : Quantity (time), default: None
              Start time of SpikeTrain. t_start must be specified.
          t_stop : Quantity (time), default: None
              Stop time of SpikeTrain. t_stop must be specified.
          id_column : int, optional, default: 0
              Column index of neuron IDs.
          time_column : int, optional, default: 1
              Column index of time stamps.
          lazy : bool, optional, default: False
              Not used by this method.
          cascade : bool, optional, default: True
              Not used by this method.
          **args
              Passed on to the SpikeTrain constructor.

          Returns
          -------
          spiketrain : SpikeTrain
              The requested SpikeTrain object with an annotation 'id'
              corresponding to the gdf_id parameter.
          """
          if gdf_id is not None and not isinstance(gdf_id, (int, np.integer)):
              raise ValueError('gdf_id has to be of type int or None.')

          if gdf_id is None and id_column is not None:
              raise ValueError('No neuron ID specified but file contains '
                               'neuron IDs in column ' + str(id_column) + '.')

          return self.__read_spiketrains([gdf_id], time_unit,
                                         t_start, t_stop,
                                         id_column, time_column,
                                         **args)[0]
  class ColumnIO:
      '''
      Class for reading an ASCII file containing multiple columns of data.
      '''

      def __init__(self, filename):
          """
          filename: string, path to ASCII file to read.

          The whole file is loaded into ``self.data`` as a 2D numpy array of
          shape (rows, columns). If the first line contains no '.', the data
          is read as int32, otherwise with numpy's default float dtype.
          """
          self.filename = filename

          # read the first line to check the data type (int or float) of the
          # data; the file handle is closed right after
          with open(self.filename) as f:
              line = f.readline()
          additional_parameters = {}
          if '.' not in line:
              additional_parameters['dtype'] = np.int32

          # ndmin=2 keeps the (rows, columns) layout also for files with a
          # single row and/or a single column
          self.data = np.loadtxt(self.filename, ndmin=2,
                                 **additional_parameters)

      def get_columns(self, column_ids='all', condition=None,
                      condition_column=None, sorting_columns=None):
          """
          column_ids : 'all' or list of int, the ids of columns to
                      extract. An empty list also selects all columns;
                      negative ids count from the last column.
          condition : None or function, which is applied to each row to
                      evaluate if it should be included in the result.
                      Needs to return a bool value.
          condition_column : int, id of the column on which the condition
                      function is applied to
          sorting_columns : int or list of int, column ids to sort by.
                      List entries have to be ordered by increasing sorting
                      priority!

          All column ids refer to the columns of the file. Rows are filtered
          first, then sorted, and the requested columns are selected last.

          Returns
          -------
          numpy array containing the requested data.

          Raises
          ------
          ValueError
              If a requested or sorting column does not exist, or if a
              condition is given without a condition_column.
          """
          n_columns = self.data.shape[1]

          if isinstance(column_ids, str) and column_ids == 'all':
              column_ids = range(n_columns)
          elif isinstance(column_ids, (int, float, np.integer)):
              column_ids = [column_ids]
          column_ids = np.asarray(column_ids)
          if column_ids.size == 0:
              column_ids = np.arange(n_columns)

          # valid ids lie in [-n_columns, n_columns)
          invalid = (column_ids >= n_columns) | (column_ids < -n_columns)
          if np.any(invalid):
              raise ValueError('Can not load column ID %i. File contains '
                               'only %i columns' % (column_ids[invalid][0],
                                                    n_columns))

          if sorting_columns is not None:
              if isinstance(sorting_columns, (int, np.integer)):
                  sorting_columns = [sorting_columns]
              if max(sorting_columns) >= n_columns:
                  raise ValueError('Can not sort by column ID %i. File contains '
                                   'only %i columns' % (max(sorting_columns),
                                                        n_columns))

          # Starting with whole dataset being selected for return
          selected_data = self.data

          # Apply filter condition to rows
          if condition is not None and condition_column is None:
              raise ValueError('Filter condition provided, but no '
                               'condition_column ID provided')
          elif condition is None and condition_column is not None:
              warnings.warn('Condition column ID provided, but no condition '
                            'given. No filtering will be performed.')
          elif condition is not None:
              # otypes is required for np.vectorize to accept empty input
              condition_function = np.vectorize(condition, otypes=[bool])
              mask = condition_function(selected_data[:, condition_column])
              selected_data = selected_data[mask, :]

          # Apply sorting if requested; np.lexsort uses the last key (row of
          # the transposed array) as the primary one
          if sorting_columns is not None:
              ordered_ids = np.lexsort(selected_data[:, sorting_columns].T)
              selected_data = selected_data[ordered_ids, :]

          # Select only requested columns
          return selected_data[:, column_ids]