highspeed-masks-plot.py 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # ======================================================================
  4. # SCRIPT INFORMATION:
  5. # ======================================================================
  6. # SCRIPT: PLOT MASKS FOR ILLUSTRATION
  7. # PROJECT: HIGHSPEED
  8. # WRITTEN BY LENNART WITTKUHN, 2020
  9. # CONTACT: WITTKUHN AT MPIB HYPHEN BERLIN DOT MPG DOT DE
  10. # MAX PLANCK RESEARCH GROUP NEUROCODE
  11. # MAX PLANCK INSTITUTE FOR HUMAN DEVELOPMENT
  12. # MAX PLANCK UCL CENTRE FOR COMPUTATIONAL PSYCHIATRY AND AGEING RESEARCH
  13. # LENTZEALLEE 94, 14195 BERLIN, GERMANY
  14. # ======================================================================
  15. # IMPORT RELEVANT PACKAGES
  16. # ======================================================================
  17. from nilearn import plotting
  18. from nilearn import image
  19. import os
  20. from os.path import join as opj
  21. import glob
  22. import re
  23. import matplotlib.pyplot as plt
  24. import statistics
  25. sub = 'sub-01'
  26. path_root = os.getcwd()
  27. path_anat = opj(path_root, 'data', 'bids', sub, '*', 'anat', '*.nii.gz')
  28. anat = sorted(glob.glob(path_anat))[0]
  29. path_patterns = opj(path_root, 'data', 'decoding', sub, 'data', '*.nii.gz')
  30. path_patterns = sorted(glob.glob(path_patterns))
  31. path_patterns = [x for x in path_patterns if 'hpc' not in x]
  32. path_union = opj(path_root, 'data', 'decoding', sub, 'masks', '*union.nii.gz')
  33. path_union = glob.glob(path_union)
  34. # path_fmriprep = opj(path_root, 'derivatives', 'fmriprep')
  35. # path_masks = opj(path_root, 'derivatives', 'decoding', sub)
  36. # path_anat = opj(path_fmriprep, sub, 'anat', sub + '_desc-preproc_T1w.nii.gz')
  37. # path_visual = opj(path_masks, 'masks', '*', '*.nii.gz')
  38. # vis_mask = glob.glob(path_visual)[0]
  39. # vis_mask_smooth = image.smooth_img(vis_mask, 4)
  40. plotting.plot_roi(path_union[0], bg_img=anat,
  41. cut_coords=[30, 10, 15],
  42. title="Region-of-interest", black_bg=True,
  43. display_mode='ortho', cmap='red_transparent',
  44. draw_cross=False)
  45. # calculate average patterns across all trials:
  46. mean_patterns = [image.mean_img(i) for i in path_patterns]
  47. # check the shape of the mean patterns (should be 3D):
  48. [print(image.get_data(i).shape) for i in mean_patterns]
  49. # extract labels of patterns
  50. labels = [re.search('union_(.+?).nii.gz', i).group(1) for i in path_patterns]
  51. # function used to plot individual patterns:
  52. def plot_patterns(pattern, name):
  53. display = plotting.plot_stat_map(
  54. pattern,
  55. bg_img=anat,
  56. #cut_coords=[30, 29, -6],
  57. title=name,
  58. black_bg=True,
  59. colorbar=True,
  60. display_mode='ortho',
  61. draw_cross=False
  62. )
  63. path_save = opj(path_root, 'figures', 'pattern_{}.pdf').format(name)
  64. display.savefig(filename=path_save)
  65. display.close()
  66. return(display)
  67. # plot individual patterns and save coordinates:
  68. coords = []
  69. for pattern, name in zip(mean_patterns, labels):
  70. display = plot_patterns(pattern, name)
  71. coords.append(display.cut_coords)
  72. # mean_coords = [sum(x)/len(x) for x in zip(*coords)]
  73. mean_coords = [statistics.median(x) for x in zip(*coords)]
  74. # create subplot with all patterns using mean coordinates:
  75. fig, axes = plt.subplots(nrows=len(path_patterns), ncols=1, figsize=(14, 20))
  76. for pattern, name, ax in zip(mean_patterns, labels, axes):
  77. display = plotting.plot_stat_map(
  78. pattern, bg_img=anat, cut_coords=mean_coords, title=name,
  79. black_bg=True, colorbar=True, display_mode='ortho',
  80. draw_cross=False, axes=ax, symmetric_cbar=True, vmax=1)
  81. fig.subplots_adjust(wspace=0, hspace=0)
  82. fig.savefig(opj(path_root, 'figures', 'pattern_all.pdf'), bbox_inches='tight')