
nestio.py

# -*- coding: utf-8 -*-
"""
Class for reading output files from NEST simulations
( http://www.nest-simulator.org/ ).
Tested with NEST2.10.0

Depends on: numpy, quantities

Supported: Read

Authors: Julia Sprenger, Maximilian Schmidt, Johanna Senk
"""

# needed for Python3 compatibility
from __future__ import absolute_import

import os.path
import warnings
from datetime import datetime

import numpy as np
import quantities as pq

from neo.io.baseio import BaseIO
from neo.core import Block, Segment, SpikeTrain, AnalogSignal

value_type_dict = {'V': pq.mV,
                   'I': pq.pA,
                   'g': pq.CompoundUnit("10^-9*S"),
                   'no type': pq.dimensionless}
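
# Example (illustrative): _check_input_values_parameters below derives the
# unit of a recorded value from its NEST type by keeping only the part before
# the first underscore and looking it up in value_type_dict, e.g.
#     'V_m' -> 'V' -> pq.mV
#     'g_e' -> 'g' -> pq.CompoundUnit("10^-9*S")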


class NestIO(BaseIO):
    """
    Class for reading NEST output files. GDF files for the spike data and DAT
    files for analog signals are possible.

    Usage:
        >>> from neo.io.nestio import NestIO

        >>> files = ['membrane_voltages-1261-0.dat',
                     'spikes-1258-0.gdf']
        >>> r = NestIO(filenames=files)
        >>> seg = r.read_segment(gid_list=[], t_start=400 * pq.ms,
                                 t_stop=600 * pq.ms,
                                 id_column_gdf=0, time_column_gdf=1,
                                 id_column_dat=0, time_column_dat=1,
                                 value_columns_dat=2)
    """

    is_readable = True  # class supports reading, but not writing
    is_writable = False

    supported_objects = [SpikeTrain, AnalogSignal, Segment, Block]
    readable_objects = [SpikeTrain, AnalogSignal, Segment, Block]

    has_header = False
    is_streameable = False

    write_params = None  # writing is not supported

    name = 'nest'
    extensions = ['gdf', 'dat']
    mode = 'file'

    def __init__(self, filenames=None):
        """
        Parameters
        ----------
        filenames: string or list of strings, default=None
            The filename or list of filenames to load.
        """
        if isinstance(filenames, str):
            filenames = [filenames]

        self.filenames = filenames
        self.avail_formats = {}
        self.avail_IOs = {}

        for filename in filenames:
            path, ext = os.path.splitext(filename)
            ext = ext.strip('.')
            if ext in self.extensions:
                if ext in self.avail_IOs:
                    raise ValueError('Received multiple files with "%s" '
                                     'extension. Can only load a single file '
                                     'of this type.' % ext)
                self.avail_IOs[ext] = ColumnIO(filename)
                self.avail_formats[ext] = path

    def __read_analogsignals(self, gid_list, time_unit, t_start=None,
                             t_stop=None, sampling_period=None,
                             id_column=0, time_column=1,
                             value_columns=2, value_types=None,
                             value_units=None):
        """
        Internal function called by read_analogsignal() and read_segment().
        """
        if 'dat' not in self.avail_formats:
            raise ValueError('Can not load analogsignals. No DAT file '
                             'provided.')

        # checking gid input parameters
        gid_list, id_column = self._check_input_gids(gid_list, id_column)
        # checking time input parameters
        t_start, t_stop = self._check_input_times(t_start, t_stop,
                                                  mandatory=False)
        # checking value input parameters
        (value_columns, value_types, value_units) = \
            self._check_input_values_parameters(value_columns, value_types,
                                                value_units)

        # defining standard column order for internal usage
        # [id_column, time_column, value_column1, value_column2, ...]
        column_ids = [id_column, time_column] + value_columns
        for i, cid in enumerate(column_ids):
            if cid is None:
                column_ids[i] = -1

        # assert that no single column is assigned twice
        column_list = [id_column, time_column] + value_columns
        column_list_no_None = [c for c in column_list if c is not None]
        if len(np.unique(column_list_no_None)) < len(column_list_no_None):
            raise ValueError(
                'One or more columns have been specified to contain '
                'the same data. Columns were specified to %s.'
                '' % column_list_no_None)

        # extracting condition and sorting parameters for raw data loading
        (condition, condition_column,
         sorting_column) = self._get_conditions_and_sorting(id_column,
                                                            time_column,
                                                            gid_list,
                                                            t_start,
                                                            t_stop)
        # loading raw data columns
        data = self.avail_IOs['dat'].get_columns(
            column_ids=column_ids,
            condition=condition,
            condition_column=condition_column,
            sorting_columns=sorting_column)

        sampling_period = self._check_input_sampling_period(sampling_period,
                                                            time_column,
                                                            time_unit,
                                                            data)
        analogsignal_list = []

        # extracting complete gid list for anasig generation
        if (gid_list == []) and id_column is not None:
            gid_list = np.unique(data[:, id_column])

        # generate analogsignals for each neuron ID
        for i in gid_list:
            selected_ids = self._get_selected_ids(
                i, id_column, time_column, t_start, t_stop, time_unit,
                data)

            # extract starting time of analogsignal
            if (time_column is not None) and data.size:
                anasig_start_time = data[selected_ids[0], 1] * time_unit
            else:
                # set t_start equal to sampling_period because NEST starts
                # recording only after 1 sampling_period
                anasig_start_time = 1. * sampling_period

            # create one analogsignal per value column requested
            for v_id, value_column in enumerate(value_columns):
                signal = data[
                    selected_ids[0]:selected_ids[1], value_column]

                # create AnalogSignal objects and annotate them with
                # the neuron ID
                analogsignal_list.append(AnalogSignal(
                    signal * value_units[v_id],
                    sampling_period=sampling_period,
                    t_start=anasig_start_time,
                    id=i,
                    type=value_types[v_id]))
                # check for correct length of analogsignal
                assert (analogsignal_list[-1].t_stop
                        == anasig_start_time + len(signal) * sampling_period)
        return analogsignal_list

    def __read_spiketrains(self, gdf_id_list, time_unit,
                           t_start, t_stop, id_column,
                           time_column, **args):
        """
        Internal function for reading multiple spiketrains at once.
        This function is called by read_spiketrain() and read_segment().
        """
        if 'gdf' not in self.avail_IOs:
            raise ValueError('Can not load spiketrains. No GDF file provided.')

        # assert that the file contains spike times
        if time_column is None:
            raise ValueError('Time column is None. No spike times to '
                             'be read in.')

        gdf_id_list, id_column = self._check_input_gids(gdf_id_list, id_column)

        t_start, t_stop = self._check_input_times(t_start, t_stop,
                                                  mandatory=True)

        # assert that no single column is assigned twice
        if id_column == time_column:
            raise ValueError('One or more columns have been specified to '
                             'contain the same data.')

        # defining standard column order for internal usage
        # [id_column, time_column, value_column1, value_column2, ...]
        column_ids = [id_column, time_column]
        for i, cid in enumerate(column_ids):
            if cid is None:
                column_ids[i] = -1

        (condition, condition_column, sorting_column) = \
            self._get_conditions_and_sorting(id_column, time_column,
                                             gdf_id_list, t_start, t_stop)

        data = self.avail_IOs['gdf'].get_columns(
            column_ids=column_ids,
            condition=condition,
            condition_column=condition_column,
            sorting_columns=sorting_column)

        # create a list of SpikeTrains for all neuron IDs in gdf_id_list
        # assign spike times to neuron IDs if id_column is given
        if id_column is not None:
            if (gdf_id_list == []) and id_column is not None:
                gdf_id_list = np.unique(data[:, id_column])

            spiketrain_list = []
            for nid in gdf_id_list:
                selected_ids = self._get_selected_ids(nid, id_column,
                                                      time_column, t_start,
                                                      t_stop, time_unit, data)
                times = data[selected_ids[0]:selected_ids[1], time_column]
                spiketrain_list.append(SpikeTrain(
                    times, units=time_unit,
                    t_start=t_start, t_stop=t_stop,
                    id=nid, **args))

        # if id_column is not given, all spike times are collected in one
        # spike train with id=None
        else:
            train = data[:, time_column]
            spiketrain_list = [SpikeTrain(train, units=time_unit,
                                          t_start=t_start, t_stop=t_stop,
                                          id=None, **args)]
        return spiketrain_list

    def _check_input_times(self, t_start, t_stop, mandatory=True):
        """
        Checks input times for existence and sets default values if
        necessary.

        t_start: pq.quantity.Quantity, start time of the time range to load.
        t_stop: pq.quantity.Quantity, stop time of the time range to load.
        mandatory: bool, if True times can not be None and an error will be
                   raised. If False, time values of None will be replaced by
                   -infinity or infinity, respectively. default: True.
        """
        if t_stop is None:
            if mandatory:
                raise ValueError('No t_stop specified.')
            else:
                t_stop = np.inf * pq.s
        if t_start is None:
            if mandatory:
                raise ValueError('No t_start specified.')
            else:
                t_start = -np.inf * pq.s

        for time in (t_start, t_stop):
            if not isinstance(time, pq.quantity.Quantity):
                raise TypeError('Time value (%s) is not a quantity.' % time)
        return t_start, t_stop
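
    # Usage sketch (illustrative): with mandatory=False, missing boundaries
    # are widened to an unbounded range, e.g.
    #     t_start, t_stop = self._check_input_times(None, 100 * pq.ms,
    #                                               mandatory=False)
    #     # -> t_start == -inf * pq.s, t_stop == 100 * pq.ms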

    def _check_input_values_parameters(self, value_columns, value_types,
                                       value_units):
        """
        Checks value parameters for consistency.

        value_columns: int, column id containing the value to load.
        value_types: list of strings, type of values.
        value_units: list of units of the value columns.

        Returns
        adjusted list of [value_columns, value_types, value_units]
        """
        if value_columns is None:
            raise ValueError('No value column provided.')
        if isinstance(value_columns, int):
            value_columns = [value_columns]
        if value_types is None:
            value_types = ['no type'] * len(value_columns)
        elif isinstance(value_types, str):
            value_types = [value_types]

        # translating value types into units as far as possible
        if value_units is None:
            short_value_types = [vtype.split('_')[0] for vtype in value_types]
            if not all([svt in value_type_dict for svt in short_value_types]):
                raise ValueError('Can not interpret value types '
                                 '"%s"' % value_types)
            value_units = [value_type_dict[svt] for svt in short_value_types]

        # checking for same number of value types, units and columns
        if not (len(value_types) == len(value_units) == len(value_columns)):
            raise ValueError('Length of value types, units and columns does '
                             'not match (%i,%i,%i)' % (len(value_types),
                                                       len(value_units),
                                                       len(value_columns)))
        if not all([isinstance(vunit, pq.UnitQuantity) for vunit in
                    value_units]):
            raise ValueError('No value unit or standard value type specified.')

        return value_columns, value_types, value_units
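
    # Usage sketch (illustrative): a single value column with a known NEST
    # type is expanded into parallel lists, e.g.
    #     self._check_input_values_parameters(2, 'V_m', None)
    #     # -> ([2], ['V_m'], [pq.mV])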

    def _check_input_gids(self, gid_list, id_column):
        """
        Checks gid values and column for consistency.

        gid_list: list of int or None, gid to load.
        id_column: int, id of the column containing the gids.

        Returns
        adjusted list of [gid_list, id_column].
        """
        if gid_list is None:
            gid_list = [gid_list]

        if None in gid_list and id_column is not None:
            raise ValueError('No neuron IDs specified but file contains '
                             'neuron IDs in column %s. Specify empty list to '
                             'retrieve spiketrains of all neurons.'
                             '' % str(id_column))

        if gid_list != [None] and id_column is None:
            raise ValueError('Specified neuron IDs to be %s, but no ID column '
                             'specified.' % gid_list)
        return gid_list, id_column

    def _check_input_sampling_period(self, sampling_period, time_column,
                                     time_unit, data):
        """
        Checks sampling period, times and time unit for consistency.

        sampling_period: pq.quantity.Quantity, sampling period of data to
                         load.
        time_column: int, column id of times in data to load.
        time_unit: pq.quantity.Quantity, unit of time used in the data to
                   load.
        data: numpy array, the data to be loaded / interpreted.

        Returns
        pq.quantities.Quantity object, the updated sampling period.
        """
        if sampling_period is None:
            if time_column is not None:
                data_sampling = np.unique(
                    np.diff(sorted(np.unique(data[:, 1]))))
                if len(data_sampling) > 1:
                    raise ValueError('Different sampling distances found in '
                                     'data set (%s)' % data_sampling)
                else:
                    dt = data_sampling[0]
            else:
                raise ValueError('Can not estimate sampling rate without time '
                                 'column id provided.')
            sampling_period = pq.CompoundUnit(str(dt) + '*'
                                              + time_unit.units.u_symbol)
        elif not isinstance(sampling_period, pq.UnitQuantity):
            raise ValueError("sampling_period is not specified as a unit.")
        return sampling_period
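
    # Usage sketch (illustrative): if sampling_period is None, it is estimated
    # from the unique time differences in the data; recordings at 0.1, 0.2,
    # 0.3 ms, with time_unit=pq.ms, yield pq.CompoundUnit('0.1*ms').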

    def _get_conditions_and_sorting(self, id_column, time_column, gid_list,
                                    t_start, t_stop):
        """
        Calculates the condition, condition_column and sorting_column based on
        other parameters supplied for loading the data.

        id_column: int, id of the column containing gids.
        time_column: int, id of the column containing times.
        gid_list: list of int, gid to be loaded.
        t_start: pq.quantity.Quantity, start of the time range to be loaded.
        t_stop: pq.quantity.Quantity, stop of the time range to be loaded.

        Returns
        updated [condition, condition_column, sorting_column].
        """
        condition, condition_column = None, None
        sorting_column = []
        curr_id = 0
        if ((gid_list != [None]) and (gid_list is not None)):
            if gid_list != []:
                def condition(x):
                    return x in gid_list

                condition_column = id_column
            sorting_column.append(curr_id)  # Sorting according to gids first
            curr_id += 1
        if time_column is not None:
            sorting_column.append(curr_id)  # Sorting according to time
            curr_id += 1
        elif t_start != -np.inf and t_stop != np.inf:
            warnings.warn('Ignoring t_start and t_stop parameters, because no '
                          'time column id is provided.')
        if sorting_column == []:
            sorting_column = None
        else:
            sorting_column = sorting_column[::-1]
        return condition, condition_column, sorting_column
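
    # Usage sketch (illustrative): for gid_list=[3, 5] with the default
    # id_column=0 and time_column=1, the method returns a membership test on
    # the gid column as the condition, condition_column=0, and
    # sorting_column=[1, 0], so that np.lexsort in ColumnIO.get_columns
    # orders rows by gid first and by time within each gid.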

    def _get_selected_ids(self, gid, id_column, time_column, t_start, t_stop,
                          time_unit, data):
        """
        Calculates the data range to load depending on the selected gid
        and the provided time range (t_start, t_stop).

        gid: int, gid to be loaded.
        id_column: int, id of the column containing gids.
        time_column: int, id of the column containing times.
        t_start: pq.quantity.Quantity, start of the time range to load.
        t_stop: pq.quantity.Quantity, stop of the time range to load.
        time_unit: pq.quantity.Quantity, time unit of the data to load.
        data: numpy array, data to load.

        Returns
        list with the start and stop index of the data rows selected for this
        gid and time range.
        """
        gid_ids = np.array([0, data.shape[0]])
        if id_column is not None:
            gid_ids = np.array([np.searchsorted(data[:, 0], gid, side='left'),
                                np.searchsorted(data[:, 0], gid,
                                                side='right')])
        gid_data = data[gid_ids[0]:gid_ids[1], :]

        # select only requested time range
        id_shifts = np.array([0, 0])
        if time_column is not None:
            id_shifts[0] = np.searchsorted(gid_data[:, 1],
                                           t_start.rescale(
                                               time_unit).magnitude,
                                           side='left')
            id_shifts[1] = (np.searchsorted(gid_data[:, 1],
                                            t_stop.rescale(
                                                time_unit).magnitude,
                                            side='left') - gid_data.shape[0])

        selected_ids = gid_ids + id_shifts
        return selected_ids

    def read_block(self, gid_list=None, time_unit=pq.ms, t_start=None,
                   t_stop=None, sampling_period=None, id_column_dat=0,
                   time_column_dat=1, value_columns_dat=2,
                   id_column_gdf=0, time_column_gdf=1, value_types=None,
                   value_units=None, lazy=False):
        assert not lazy, 'Do not support lazy'

        seg = self.read_segment(gid_list, time_unit, t_start,
                                t_stop, sampling_period, id_column_dat,
                                time_column_dat, value_columns_dat,
                                id_column_gdf, time_column_gdf, value_types,
                                value_units)
        blk = Block(file_origin=seg.file_origin,
                    file_datetime=seg.file_datetime)
        blk.segments.append(seg)
        seg.block = blk
        return blk

    def read_segment(self, gid_list=None, time_unit=pq.ms, t_start=None,
                     t_stop=None, sampling_period=None, id_column_dat=0,
                     time_column_dat=1, value_columns_dat=2,
                     id_column_gdf=0, time_column_gdf=1, value_types=None,
                     value_units=None, lazy=False):
        """
        Reads a Segment which contains SpikeTrain(s) with specified neuron IDs
        from the GDF data.

        Arguments
        ----------
        gid_list : list, default: None
            A list of GDF IDs of which to return SpikeTrain(s). gid_list must
            be specified if the GDF file contains neuron IDs, the default None
            then raises an error. Specify an empty list [] to retrieve the
            spike trains of all neurons.
        time_unit : Quantity (time), optional, default: quantities.ms
            The time unit of recorded time stamps in DAT as well as GDF files.
        t_start : Quantity (time), optional, default: None
            Start time of SpikeTrain.
        t_stop : Quantity (time), default: None
            Stop time of SpikeTrain. t_stop must be specified, the default
            None raises an error.
        sampling_period : Quantity (time), optional, default: None
            Sampling period of the recorded data.
        id_column_dat : int, optional, default: 0
            Column index of neuron IDs in the DAT file.
        time_column_dat : int, optional, default: 1
            Column index of time stamps in the DAT file.
        value_columns_dat : int, optional, default: 2
            Column index of the analog values recorded in the DAT file.
        id_column_gdf : int, optional, default: 0
            Column index of neuron IDs in the GDF file.
        time_column_gdf : int, optional, default: 1
            Column index of time stamps in the GDF file.
        value_types : str, optional, default: None
            Nest data type of the analog values recorded, e.g. 'V_m', 'I',
            'g_e'.
        value_units : Quantity (amplitude), default: None
            The physical unit of the recorded signal values.
        lazy : bool, optional, default: False

        Returns
        -------
        seg : Segment
            The Segment contains one SpikeTrain and one AnalogSignal for
            each ID in gid_list.
        """
        assert not lazy, 'Do not support lazy'

        if isinstance(gid_list, tuple):
            if gid_list[0] > gid_list[1]:
                raise ValueError('The second entry in gid_list must be '
                                 'greater or equal to the first entry.')
            gid_list = range(gid_list[0], gid_list[1] + 1)

        # __read_xxx() needs a list of IDs
        if gid_list is None:
            gid_list = [None]

        # create an empty Segment
        seg = Segment(file_origin=",".join(self.filenames))
        seg.file_datetime = datetime.fromtimestamp(
            os.stat(self.filenames[0]).st_mtime)
        # todo: rather than take the first file for the timestamp, we should
        #       take the oldest; in practice, there won't be much difference

        # Load analogsignals and attach to Segment
        if 'dat' in self.avail_formats:
            seg.analogsignals = self.__read_analogsignals(
                gid_list,
                time_unit,
                t_start,
                t_stop,
                sampling_period=sampling_period,
                id_column=id_column_dat,
                time_column=time_column_dat,
                value_columns=value_columns_dat,
                value_types=value_types,
                value_units=value_units)
        if 'gdf' in self.avail_formats:
            seg.spiketrains = self.__read_spiketrains(
                gid_list,
                time_unit,
                t_start,
                t_stop,
                id_column=id_column_gdf,
                time_column=time_column_gdf)

        return seg

    def read_analogsignal(self, gid=None, time_unit=pq.ms, t_start=None,
                          t_stop=None, sampling_period=None, id_column=0,
                          time_column=1, value_column=2, value_type=None,
                          value_unit=None, lazy=False):
        """
        Reads an AnalogSignal with specified neuron ID from the DAT data.

        Arguments
        ----------
        gid : int, default: None
            The ID of the neuron whose AnalogSignal is returned. gid must be
            specified if the DAT file contains neuron IDs; the default None
            then raises an error.
        time_unit : Quantity (time), optional, default: quantities.ms
            The time unit of recorded time stamps.
        t_start : Quantity (time), optional, default: None
            Start time of the AnalogSignal. If None, the signal is loaded
            from the beginning of the recording.
        t_stop : Quantity (time), optional, default: None
            Stop time of the AnalogSignal. If None, the signal is loaded up
            to the end of the recording.
        sampling_period : Quantity (time), optional, default: None
            Sampling period of the recorded data.
        id_column : int, optional, default: 0
            Column index of neuron IDs.
        time_column : int, optional, default: 1
            Column index of time stamps.
        value_column : int, optional, default: 2
            Column index of the analog values recorded.
        value_type : str, optional, default: None
            Nest data type of the analog values recorded, e.g. 'V_m', 'I',
            'g_e'.
        value_unit : Quantity (amplitude), default: None
            The physical unit of the recorded signal values.
        lazy : bool, optional, default: False

        Returns
        -------
        anasig : AnalogSignal
            The requested AnalogSignal object with an annotation 'id'
            corresponding to the gid parameter.
        """
        assert not lazy, 'Do not support lazy'

        # __read_analogsignals() needs a list of IDs
        return self.__read_analogsignals([gid], time_unit,
                                         t_start, t_stop,
                                         sampling_period=sampling_period,
                                         id_column=id_column,
                                         time_column=time_column,
                                         value_columns=value_column,
                                         value_types=value_type,
                                         value_units=value_unit)[0]

    def read_spiketrain(
            self, gdf_id=None, time_unit=pq.ms, t_start=None, t_stop=None,
            id_column=0, time_column=1, lazy=False, **args):
        """
        Reads a SpikeTrain with specified neuron ID from the GDF data.

        Arguments
        ----------
        gdf_id : int, default: None
            The GDF ID of the returned SpikeTrain. gdf_id must be specified if
            the GDF file contains neuron IDs. Providing [] loads all available
            IDs.
        time_unit : Quantity (time), optional, default: quantities.ms
            The time unit of recorded time stamps.
        t_start : Quantity (time), default: None
            Start time of SpikeTrain. t_start must be specified.
        t_stop : Quantity (time), default: None
            Stop time of SpikeTrain. t_stop must be specified.
        id_column : int, optional, default: 0
            Column index of neuron IDs.
        time_column : int, optional, default: 1
            Column index of time stamps.
        lazy : bool, optional, default: False

        Returns
        -------
        spiketrain : SpikeTrain
            The requested SpikeTrain object with an annotation 'id'
            corresponding to the gdf_id parameter.
        """
        assert not lazy, 'Do not support lazy'

        if (not isinstance(gdf_id, int)) and gdf_id is not None:
            raise ValueError('gdf_id has to be of type int or None.')

        if gdf_id is None and id_column is not None:
            raise ValueError('No neuron ID specified but file contains '
                             'neuron IDs in column ' + str(id_column) + '.')

        return self.__read_spiketrains([gdf_id], time_unit,
                                       t_start, t_stop,
                                       id_column, time_column,
                                       **args)[0]


class ColumnIO:
    '''
    Class for reading an ASCII file containing multiple columns of data.
    '''

    def __init__(self, filename):
        """
        filename: string, path to ASCII file to read.
        """
        self.filename = filename

        # read the first line to check the data type (int or float) of the
        # data
        with open(self.filename) as f:
            line = f.readline()

        additional_parameters = {}
        if '.' not in line:
            additional_parameters['dtype'] = np.int32

        self.data = np.loadtxt(self.filename, **additional_parameters)

        if len(self.data.shape) == 1:
            self.data = self.data[:, np.newaxis]

    def get_columns(self, column_ids='all', condition=None,
                    condition_column=None, sorting_columns=None):
        """
        column_ids : 'all' or list of int, the ids of columns to
                     extract.
        condition : None or function, which is applied to each row to evaluate
                    if it should be included in the result.
                    Needs to return a bool value.
        condition_column : int, id of the column on which the condition
                           function is applied to.
        sorting_columns : int or list of int, column ids to sort by.
                          List entries have to be ordered by increasing
                          sorting priority!

        Returns
        -------
        numpy array containing the requested data.
        """
        if column_ids == [] or column_ids == 'all':
            column_ids = range(self.data.shape[-1])

        if isinstance(column_ids, (int, float)):
            column_ids = [column_ids]
        column_ids = np.array(column_ids)

        if column_ids is not None:
            if max(column_ids) >= self.data.shape[1]:
                raise ValueError('Can not load column ID %i. File contains '
                                 'only %i columns' % (max(column_ids),
                                                      self.data.shape[1]))

        if sorting_columns is not None:
            if isinstance(sorting_columns, int):
                sorting_columns = [sorting_columns]
            if (max(sorting_columns) >= self.data.shape[1]):
                raise ValueError('Can not sort by column ID %i. File contains '
                                 'only %i columns' % (max(sorting_columns),
                                                      self.data.shape[1]))

        # Starting with whole dataset being selected for return
        selected_data = self.data

        # Apply filter condition to rows
        if condition and (condition_column is None):
            raise ValueError('Filter condition provided, but no '
                             'condition_column ID provided')
        elif (condition_column is not None) and (condition is None):
            warnings.warn('Condition column ID provided, but no condition '
                          'given. No filtering will be performed.')
        elif (condition is not None) and (condition_column is not None):
            condition_function = np.vectorize(condition)
            mask = condition_function(
                selected_data[:, condition_column]).astype(bool)
            selected_data = selected_data[mask, :]

        # Apply sorting if requested
        if sorting_columns is not None:
            values_to_sort = selected_data[:, sorting_columns].T
            ordered_ids = np.lexsort(tuple(values_to_sort[i] for i in
                                           range(len(values_to_sort))))
            selected_data = selected_data[ordered_ids, :]

        # Select only requested columns
        selected_data = selected_data[:, column_ids]

        return selected_data
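

# Usage sketch for ColumnIO (illustrative; the file name is the example from
# the NestIO docstring above):
#
#     col_io = ColumnIO('spikes-1258-0.gdf')
#     # rows of [gid, time] for neurons 3 and 5, sorted by spike time
#     rows = col_io.get_columns(column_ids=[0, 1],
#                               condition=lambda gid: gid in (3, 5),
#                               condition_column=0,
#                               sorting_columns=1)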