# H5_ISI.py (2.0 KB)

  # The script calculates the Inter Spike Intervals (ISIs) of the spike trains
  # read from an H5 dataset and plots the results in a histogram. The
  # proportion of ISIs violating a threshold (ISI < `threshold`) is also
  # calculated and indicated in the histogram.
  # authors: Stefano Diomedi
  # date: 10/2022
import os
import tkinter as tk
from tkinter import filedialog

import matplotlib.pyplot as plt
import numpy as np

from supportFunctions import get_all_data_from_level_h5_rec


def _compute_isis(spike_trains):
    """Return the inter-spike intervals of every spike train as one flat array.

    Each element of *spike_trains* is converted to an array, flattened and
    differenced with ``np.diff``; the per-element results are concatenated.
    An empty array is returned when *spike_trains* is empty, because
    ``np.concatenate`` raises ``ValueError`` on an empty list.
    """
    per_train = [np.diff(np.asarray(train).ravel()) for train in spike_trains]
    if not per_train:
        return np.empty(0)
    return np.concatenate(per_train)


def _violation_percentage(isis, threshold):
    """Return the percentage of *isis* strictly below *threshold*.

    Returns NaN when *isis* is empty (the percentage is undefined), without
    going through NumPy's 0/0 division and its RuntimeWarning.
    """
    isis = np.asarray(isis)
    if isis.size == 0:
        return float("nan")
    return np.count_nonzero(isis < threshold) / isis.size * 100


# Path to open the file dialog in the data branch: the "data" folder two
# levels above the current working directory.
current_dir = os.getcwd()
path_folder_data = os.path.join(
    os.path.abspath(os.path.join(current_dir, "..", "..")), "data"
)

# Select H5 dataset file. The hidden Tk root only hosts the dialog, so it is
# destroyed right after the selection instead of lingering for the whole run.
root = tk.Tk()
root.withdraw()
filename = filedialog.askopenfilename(initialdir=path_folder_data)
root.destroy()
if not filename:
    # askopenfilename returns an empty value when the dialog is cancelled.
    raise SystemExit("No H5 file selected.")

group_name = '/DATA'  # it can be 'unit_XXX', 'unit_XXX/condition_YY' or 'unit_XXX/condition_YY/trial_ZZ'
threshold = 1  # threshold for violations in ms

# The reader is given four initially empty lists and returns the filled
# markers / marker labels / spikes / spike labels found under `group_name`.
markers = []
all_strings_mk = []
spikes = []
all_strings_sp = []
markers, all_strings_mk, spikes, all_strings_sp = (
    get_all_data_from_level_h5_rec.get_all_data_from_level_h5_rec(
        filename, group_name, markers, all_strings_mk, spikes, all_strings_sp
    )
)

# ISIs are in the units of the stored spike times; `threshold` and the
# histogram axis below assume those units are milliseconds.
ISIs = _compute_isis(spikes)
violations = _violation_percentage(ISIs, threshold)

# Creation of Histogram: 0.5 ms bins from 0 to 100 ms. Longer ISIs are not
# drawn but still count in the denominator of the violation percentage.
fig, ax = plt.subplots()
time = np.arange(0, 100.5, 0.5)
ax.hist(ISIs, bins=time)
ax.axvline(threshold, linestyle='--')
xLimits = ax.get_xlim()
yLimits = ax.get_ylim()
ax.text(
    xLimits[1] / 4,
    yLimits[1] * 2 / 3,
    f'Violations (ISI below {threshold}ms threshold) = {violations:.2f}%',
    fontsize=12,
)
ax.set_xlabel('ISI duration (ms)')
ax.set_ylabel('events num.')
ax.set_title('ISI')
plt.show()