example.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import h5py
  2. import scipy.sparse
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. def load_hdf5_array(file_name, key=None, slice=slice(0, None)):
  6. """Function to load data from an hdf file.
  7. Parameters
  8. ----------
  9. file_name: string
  10. hdf5 file name.
  11. key: string
  12. Key name to load. If not provided, all keys will be loaded.
  13. slice: slice, or tuple of slices
  14. Load only a slice of the hdf5 array. It will load `array[slice]`.
  15. Use a tuple of slices to get a slice in multiple dimensions.
  16. Returns
  17. -------
  18. result : array or dictionary
  19. Array, or dictionary of arrays (if `key` is None).
  20. """
  21. with h5py.File(file_name, mode='r') as hf:
  22. if key is None:
  23. data = dict()
  24. for k in hf.keys():
  25. data[k] = hf[k][slice]
  26. return data
  27. else:
  28. return hf[key][slice]
  29. def load_hdf5_sparse_array(file_name, key):
  30. """Load a scipy sparse array from an hdf file
  31. Parameters
  32. ----------
  33. file_name : string
  34. File name containing array to be loaded.
  35. key : string
  36. Name of variable to be loaded.
  37. Notes
  38. -----
  39. This function relies on variables being stored with specific naming
  40. conventions, so cannot be used to load arbitrary sparse arrays.
  41. """
  42. with h5py.File(file_name, mode='r') as hf:
  43. data = (hf['%s_data' % key], hf['%s_indices' % key],
  44. hf['%s_indptr' % key])
  45. sparsemat = scipy.sparse.csr_matrix(data, shape=hf['%s_shape' % key])
  46. return sparsemat
  47. def save_hdf5_dataset(file_name, dataset, mode='w'):
  48. """Save a dataset of arrays and sparse arrays.
  49. Parameters
  50. ----------
  51. file_name : str
  52. Full name of the file.
  53. dataset : dict of arrays
  54. Mappers to save.
  55. mode : str
  56. File opening model.
  57. Use 'w' to write from scratch, 'a' to add to existing file.
  58. """
  59. print("Saving... ", end="", flush=True)
  60. with h5py.File(file_name, mode=mode) as hf:
  61. for name, array in dataset.items():
  62. if scipy.sparse.issparse(array): # sparse array
  63. array = array.tocsr()
  64. hf.create_dataset(name + '_indices', data=array.indices,
  65. compression='gzip')
  66. hf.create_dataset(name + '_data', data=array.data,
  67. compression='gzip')
  68. hf.create_dataset(name + '_indptr', data=array.indptr,
  69. compression='gzip')
  70. hf.create_dataset(name + '_shape', data=array.shape,
  71. compression='gzip')
  72. else: # dense array
  73. hf.create_dataset(name, data=array, compression='gzip')
  74. print("Saved %s" % file_name)
  75. def map_voxels_to_flatmap(voxels, mapper_file):
  76. """Generate flatmap image from voxel array using a mapper.
  77. This function maps an array of voxels into a flattened representation
  78. of an individual subject's brain.
  79. Parameters
  80. ----------
  81. voxels: array of shape (n_voxels, )
  82. Voxel values to be mapped.
  83. mapper_file: string
  84. File containing mapping arrays for a particular subject.
  85. Returns
  86. -------
  87. image : array of shape (width, height)
  88. Flatmap image.
  89. """
  90. voxel_to_flatmap = load_hdf5_sparse_array(mapper_file, 'voxel_to_flatmap')
  91. flatmap_mask = load_hdf5_array(mapper_file, 'flatmap_mask')
  92. badmask = np.array(voxel_to_flatmap.sum(1) > 0).ravel()
  93. img = (np.nan * np.ones(flatmap_mask.shape)).astype(voxels.dtype)
  94. mimg = (np.nan * np.ones(badmask.shape)).astype(voxels.dtype)
  95. mimg[badmask] = (voxel_to_flatmap * voxels.ravel())[badmask].astype(
  96. mimg.dtype)
  97. img[flatmap_mask] = mimg
  98. return img.T[::-1]
  99. def plot_flatmap_from_mapper(voxels, mapper_file, ax=None, alpha=0.7,
  100. cmap='inferno', vmin=None, vmax=None,
  101. with_curvature=True, with_rois=True,
  102. with_colorbar=True,
  103. colorbar_location=(.4, .9, .2, .05)):
  104. """Plot a flatmap from a mapper file.
  105. Note that this function does not have the full capability of pycortex,
  106. (like cortex.quickshow) since it is based on flatmap mappers and not on the
  107. original brain surface of the subject.
  108. Parameters
  109. ----------
  110. voxels : array of shape (n_voxels, )
  111. Data to be plotted.
  112. mapper_file : str
  113. File name of the mapper.
  114. ax : matplotlib Axes or None.
  115. Axes where the figure will be plotted.
  116. If None, a new figure is created.
  117. alpha : float in [0, 1], or array of shape (n_voxels, )
  118. Transparency of the flatmap.
  119. cmap : str
  120. Name of the matplotlib colormap.
  121. vmin : float or None
  122. Minimum value of the colormap. If None, use the 1st percentile of the
  123. `voxels` array.
  124. vmax : float or None
  125. Minimum value of the colormap. If None, use the 99th percentile of the
  126. `voxels` array.
  127. with_curvature : bool
  128. If True, show the curvature below the data layer.
  129. with_rois : bool
  130. If True, show the ROIs labels above the data layer.
  131. colorbar_location : [left, bottom, width, height]
  132. Location of the colorbar. All quantities are in fractions of figure
  133. width and height.
  134. Returns
  135. -------
  136. ax : matplotlib Axes
  137. Axes where the figure has been plotted.
  138. """
  139. # create a figure
  140. if ax is None:
  141. flatmap_mask = load_hdf5_array(mapper_file, key='flatmap_mask')
  142. figsize = np.array(flatmap_mask.shape) / 100.
  143. fig = plt.figure(figsize=figsize)
  144. ax = fig.add_axes((0, 0, 1, 1))
  145. ax.axis('off')
  146. # process plotting parameters
  147. if vmin is None:
  148. vmin = np.percentile(voxels, 1)
  149. if vmax is None:
  150. vmax = np.percentile(voxels, 99)
  151. if isinstance(alpha, np.ndarray):
  152. alpha = map_voxels_to_flatmap(alpha, mapper_file)
  153. # plot the data
  154. image = map_voxels_to_flatmap(voxels, mapper_file)
  155. cimg = ax.imshow(image, aspect='equal', zorder=1, alpha=alpha, cmap=cmap,
  156. vmin=vmin, vmax=vmax)
  157. if with_colorbar:
  158. try:
  159. cbar = ax.inset_axes(colorbar_location)
  160. except AttributeError: # for matplotlib < 3.0
  161. cbar = ax.figure.add_axes(colorbar_location)
  162. ax.figure.colorbar(cimg, cax=cbar, orientation='horizontal')
  163. # plot additional layers if present
  164. with h5py.File(mapper_file, mode='r') as hf:
  165. if with_curvature and "flatmap_curvature" in hf.keys():
  166. curvature = load_hdf5_array(mapper_file, key='flatmap_curvature')
  167. background = np.swapaxes(curvature, 0, 1)[::-1]
  168. else:
  169. background = map_voxels_to_flatmap(np.ones_like(voxels),
  170. mapper_file)
  171. ax.imshow(background, aspect='equal', cmap='gray', vmin=0, vmax=1,
  172. zorder=0)
  173. if with_rois and "flatmap_rois" in hf.keys():
  174. rois = load_hdf5_array(mapper_file, key='flatmap_rois')
  175. ax.imshow(
  176. np.swapaxes(rois, 0, 1)[::-1], aspect='equal',
  177. interpolation='bicubic', zorder=2)
  178. return ax
  179. if __name__ == "__main__":
  180. """
  181. Example for how to load the fMRI test data,
  182. and display the explainable variance on one subject's cortical surface.
  183. More examples at https://github.com/gallantlab/voxelwise_tutorials
  184. """
  185. import os
  186. directory = os.path.abspath('.')
  187. subject = "S01"
  188. # Load fMRI responses on the test set
  189. file_name = os.path.join(directory, 'responses',
  190. f'{subject}_responses.hdf')
  191. Y_test = load_hdf5_array(file_name, "Y_test")
  192. # compute the explainable variance per voxel, based on the test set repeats
  193. mean_var = np.mean(np.var(Y_test, axis=1), axis=0)
  194. var_mean = np.var(np.mean(Y_test, axis=0), axis=0)
  195. explainable_variance = var_mean / mean_var
  196. # Map to subject flatmap
  197. mapper_file = os.path.join(directory, 'mappers', f'{subject}_mappers.hdf')
  198. plot_flatmap_from_mapper(explainable_variance, mapper_file, vmin=0,
  199. vmax=0.7)
  200. plt.show()