load_data.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. """
  2. Load a nix data file from the data set of Henke et al. and make a basic
  3. plot of its contents.
  4. The nix files contain the spike times of a cell recorded in a recording session, stimulus
  5. onset and durations, reproduction onset times and durations, reward information,
  6. feedback range widths, and distance covered by the animal throughout the session.
  7. Use with Python 3.
  8. """
  9. __author__ = ("Kay Thurley")
  10. __version__ = "1.0, November 2021"
  11. import sys
  12. import nixio
  13. import numpy as np
  14. ###################################################### commandline parameters
  15. file_name = sys.argv[1] # second argument should be the name of the file to load
  16. ###################################################### load data
  17. file = nixio.File.open(file_name, nixio.FileMode.ReadOnly)
  18. session = file.blocks[0]
  19. metadata = file.sections[0]
  20. for key, m in metadata.items():
  21. print(key+':', metadata[key])
  22. ###################################################### plotting initialization
  23. import matplotlib as mpl
  24. import matplotlib.pyplot as plt
  25. fontsize = 12.0
  26. markersize = 6
  27. mpl.rcParams['lines.linewidth'] = 2
  28. mpl.rcParams['lines.markersize'] = markersize
  29. mpl.rcParams['font.size'] = fontsize
  30. mpl.rcParams['xtick.direction'] = 'out'
  31. mpl.rcParams['ytick.direction'] = 'out'
  32. ###################################################### plotting
  33. x_axis = session.data_arrays['spike times'].dimensions[0]
  34. spikes = session.data_arrays['spike times'].data[:]
  35. fig = plt.figure(figsize=(11, 3))
  36. ax = fig.add_axes([0.075, 0.25, 0.85, 0.625])
  37. # -- plot spikes
  38. ax.eventplot(spikes, colors=[[0, 0, 0]], lineoffsets=1, linelengths=.25, linewidth=.5, alpha=.5)
  39. # -- mark stimuli / measurement phase
  40. stimuli = session.multi_tags['stimuli']
  41. stimulus_onsets = stimuli.positions[:]
  42. stimulus_durations = stimuli.extents[:]
  43. for i, onset in enumerate(stimulus_onsets):
  44. ax.axvspan(onset, onset+stimulus_durations[i],
  45. color='silver', alpha=0.5, zorder=0, label="measurement phase")
  46. ax.text(onset, 1.2, stimuli.features[0].data[i][0], color="slategray", fontsize=8)
  47. # -- mark reproduction phase
  48. reproductions = session.multi_tags['reproductions']
  49. reproduction_onsets = reproductions.positions[:]
  50. reproduction_durations = reproductions.extents[:]
  51. rewards = session.data_arrays['rewards'].data[:]
  52. for i, onset in enumerate(reproduction_onsets):
  53. # green color for reward, red for no reward
  54. if rewards[i]:
  55. col = 'green'
  56. else:
  57. col = 'red'
  58. ax.axvspan(onset, onset + reproduction_durations[i],
  59. color=col, alpha=0.5, zorder=0, label="reproduction phase")
  60. # -- add speed information
  61. ax2 = ax.twinx()
  62. distances = session.data_arrays['distance covered']
  63. x_axis = distances.dimensions[0]
  64. x = x_axis.axis(distances.data.shape[0])
  65. speed = np.diff(distances.data[:])/np.diff(x)
  66. line_speed, = ax2.plot(x[:-1], speed, '.-', linewidth=1, alpha=0.5, color=np.ones(3)*.5,
  67. markersize=1)
  68. # -- annotate plot
  69. ax.set_yticks([])
  70. ax.set_xlabel(x_axis.label.capitalize() + " (" + x_axis.unit + ")")
  71. ax2.set_ylabel('Speed (' + distances.unit + '/' + x_axis.unit + ')')
###################################################### finishing
# Close the nix file before showing the figure: every value drawn above has
# already been read from it, so the figure no longer needs the open file.
file.close()
plt.show()