# behavior.py -- behavioral-state unit selection and embedding-space helpers
  1. from loading import load_session_data
  2. from scipy import signal
  3. import h5py
  4. import numpy as np
  5. import sys, os
  6. sys.path.append(os.path.join(os.getcwd(), '..'))
  7. sys.path.append(os.path.join(os.getcwd(), '..', '..'))
  8. sys.path.append(os.path.join(os.getcwd(), '..', '..', '..', 'pplSIT', 'workflow', 'utils'))
  9. from spatial import gaussian_kernel_2D, get_field_patches
  10. def get_behav_units(session):
  11. session_data = load_session_data(session, load_aeps=False)
  12. moseq_file = session_data['moseq_file']
  13. behav_units = []
  14. for i, (unit, i_rate) in enumerate(session_data['single_units'].items()):
  15. with h5py.File(moseq_file, 'r') as f:
  16. grp = f['units'][unit]
  17. corr_glm_fit_orig = np.array(grp['corr_glm_fit_orig'])
  18. corr_glm_fit_shuffled = np.array(grp['corr_glm_fit_shuffled'])
  19. corr_glm_fit_train_test = np.array(grp['corr_glm_fit_train_test'])
  20. coeffs_shuf = corr_glm_fit_shuffled[corr_glm_fit_shuffled[:, 1] < 0.95][:, 0]
  21. coeffs_chun = corr_glm_fit_train_test[corr_glm_fit_train_test[:, 1] < 0.95][:, 0]
  22. perc_shuf_high = np.percentile(coeffs_shuf, 95)
  23. perc_chun_low = np.percentile(coeffs_chun, 5)
  24. #corr = corr_glm_fit_orig[0]
  25. corr = coeffs_chun.mean()
  26. if perc_shuf_high < perc_chun_low and corr > 0.25:
  27. behav_units.append(unit)
  28. return behav_units
  29. def get_extent(fit, margin=5):
  30. x_range = fit[:, 0].max() - fit[:, 0].min()
  31. y_range = fit[:, 1].max() - fit[:, 1].min()
  32. max_range = np.max([x_range, y_range])
  33. x_min = (fit[:, 0].min() + x_range/2) - max_range*(0.5 + margin/100)
  34. x_max = (fit[:, 0].min() + x_range/2) + max_range*(0.5 + margin/100)
  35. y_min = (fit[:, 1].min() + y_range/2) - max_range*(0.5 + margin/100)
  36. y_max = (fit[:, 1].min() + y_range/2) + max_range*(0.5 + margin/100)
  37. return x_min, x_max, y_min, y_max
  38. def density_map(fit, extent, sigma=0.4, bin_count=100):
  39. pos_range = np.array([[extent[0], extent[1]], [extent[2], extent[3]]])
  40. d_map, xs_edges, ys_edges = np.histogram2d(fit[:, 0], fit[:, 1], bins=[bin_count, bin_count], range=pos_range)
  41. kernel = gaussian_kernel_2D(sigma)
  42. return signal.convolve2d(d_map, kernel, mode='same')
  43. def get_idxs_in_patches(fit, patches, extent, bin_count=100):
  44. x_bins = np.linspace(extent[0], extent[1], bin_count)
  45. y_bins = np.linspace(extent[2], extent[3], bin_count)
  46. idxs_in = []
  47. for i, (x, y) in enumerate(fit):
  48. x_idx = np.argmin(np.abs(x - x_bins))
  49. y_idx = np.argmin(np.abs(y - y_bins))
  50. if patches[x_idx][y_idx] > 0:
  51. idxs_in.append(i)
  52. return np.array(idxs_in)
  53. def get_idxs_behav_state(source, session, idxs_tl_sample, fit_type='tSNE', fit_parm=70, sigma=0.3, margin=10, bin_count=100):
  54. # returns idxs to timeline!
  55. animal = session.split('_')[0]
  56. meta_file = os.path.join(source, animal, session, 'meta.h5')
  57. moseq_class_file = os.path.join(source, animal, session, 'analysis', 'MoSeq_tSNE_UMAP.h5')
  58. with h5py.File(meta_file, 'r') as f:
  59. tl = np.array(f['processed']['timeline'])
  60. tgt_mx = np.array(f['processed']['target_matrix'])
  61. with h5py.File(moseq_class_file, 'r') as f:
  62. idxs_srm_tl = np.array(f['idxs_srm_tl'])
  63. fit = np.array(f[fit_type][str(fit_parm)])
  64. idxs_state = np.array([i for i, x in enumerate(idxs_srm_tl) if x in idxs_tl_sample], dtype=np.int32)
  65. extent = get_extent(fit, margin=margin)
  66. behav_map = density_map(fit[idxs_state], extent, sigma=sigma, bin_count=bin_count)
  67. state_patches = get_field_patches(behav_map, 0.2)
  68. idxs_srm_state = get_idxs_in_patches(fit, state_patches, extent, bin_count=bin_count)
  69. # convert to timeline idxs
  70. bins_to_fill = int((idxs_srm_tl[1] - idxs_srm_tl[0])/2)
  71. idxs_res = []
  72. for idx in idxs_srm_tl[idxs_srm_state]:
  73. idxs_res += list(range(idx - bins_to_fill, idx + bins_to_fill))
  74. idxs_res = np.array(idxs_res, dtype=np.int32)
  75. idxs_res = idxs_res[idxs_res > 0]
  76. idxs_res = idxs_res[idxs_res < len(tl) - 1]
  77. return idxs_res
  78. def get_idxs_neuro_state(source, session, idxs_ev_sample, fit_type='tSNE', fit_parm=70, sigma=0.3, margin=10, bin_count=100):
  79. # returns idxs in sound events space
  80. animal = session.split('_')[0]
  81. meta_file = os.path.join(source, animal, session, 'meta.h5')
  82. umap_file = os.path.join(source, animal, session, 'analysis', 'W1-W4_tSNE_UMAP.h5')
  83. with h5py.File(meta_file, 'r') as f:
  84. tl = np.array(f['processed']['timeline'])
  85. tgt_mx = np.array(f['processed']['target_matrix'])
  86. with h5py.File(umap_file, 'r') as f:
  87. fit = np.array(f[fit_type][str(fit_parm)]) # already in event sampling
  88. extent = get_extent(fit, margin=margin)
  89. selected_map = density_map(fit[idxs_ev_sample], extent, sigma=sigma, bin_count=bin_count)
  90. state_patches = get_field_patches(selected_map, 0.3)
  91. return get_idxs_in_patches(fit, state_patches, extent, bin_count=bin_count)