plot_corr-of-glm-and-srm.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. #!/usr/bin/env python3
  2. '''
  3. created on Fri May 21 2021
  4. author: Christian Olaf Haeusler
  5. '''
  6. from glob import glob
  7. import argparse
  8. import matplotlib
  9. import matplotlib.pyplot as plt
  10. import numpy as np
  11. import os
  12. import pandas as pd
  13. import re
  14. import seaborn as sns
  15. import brainiak.funcalign.srm
  16. matplotlib.use('Agg')
  17. AO_USED = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18]
  18. AO_NAMES = {1: 'body',
  19. 2: 'bpart',
  20. 3: 'fahead',
  21. 4: 'furn',
  22. 5: 'geo',
  23. 6: 'groom',
  24. 7: 'object',
  25. 8: 'se_new',
  26. 9: 'se_old',
  27. 10: 'sex_f',
  28. 11: 'sex_m',
  29. 12: 'vse_new',
  30. 13: 'vse_old',
  31. 14: 'vlo_ch',
  32. 15: 'vpe_new',
  33. 16: 'vpe_old',
  34. 17: 'fg_ad_lrdiff',
  35. 18: 'fg_ad_rms'
  36. }
  37. AV_USED = [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14]
  38. AV_NAMES = {1: 'vse_new',
  39. 2: 'vse_old',
  40. 3: 'vlo_ch',
  41. 4: 'vpe_new',
  42. 5: 'vpe_old',
  43. 6: 'vno_cut',
  44. 7: 'se_new (ao)',
  45. 8: 'se_old (ao)',
  46. 9: 'fg_av_ger_lr',
  47. 10: 'fg_av_ger_lr_diff',
  48. 11: 'fg_av_ger_ml',
  49. 12: 'fg_av_ger_pd',
  50. 13: 'fg_av_ger_rms',
  51. 14: 'fg_av_ger_ud'
  52. }
  53. VIS_USED = [1, 2, 3, 4, 5, 6]
  54. VIS_NAMES = {1: 'body',
  55. 2: 'face',
  56. 3: 'house',
  57. 4: 'object',
  58. 5: 'scene',
  59. 6: 'scramble'
  60. }
  61. def parse_arguments():
  62. '''
  63. '''
  64. parser = argparse.ArgumentParser(
  65. description="creates the correlation of convoluted regressors from \
  66. a subject's 1st lvl results directories (= all single run dirs ")
  67. parser.add_argument('-ao',
  68. default='inputs/studyforrest-ppa-analysis/'\
  69. 'sub-01/run-1_audio-ppa-grp.feat/design.mat',
  70. help='pattern of path/file for 1st lvl (AO) design files')
  71. parser.add_argument('-vis',
  72. default='inputs/studyforrest-data-visualrois/'\
  73. 'sub-01/run-1.feat/design.mat',
  74. help='pattern of path/file for 1st lvl (vis) design files')
  75. parser.add_argument('-av',
  76. default='inputs/studyforrest-ppa-analysis/'\
  77. 'sub-01/run-1_movie-ppa-grp.feat/design.mat',
  78. help='pattern of path/file for 1st lvl (AV) design files')
  79. parser.add_argument('-model',
  80. default='sub-01/srm-ao-av-vis_feat10-iter30.npz',
  81. help='the model file')
  82. parser.add_argument('-o',
  83. default='test',
  84. help='the output directory for the PDF and SVG file')
  85. args = parser.parse_args()
  86. aoExample = args.ao
  87. avExample = args.av
  88. visExample = args.vis
  89. modelFile = args.model
  90. outDir = args.o
  91. return aoExample, avExample, visExample, modelFile, outDir
  92. def find_files(pattern):
  93. '''
  94. '''
  95. def sort_nicely(l):
  96. '''Sorts a given list in the way that humans expect
  97. '''
  98. convert = lambda text: int(text) if text.isdigit() else text
  99. alphanum_key = lambda key: [convert(c) for c in re.split("([0-9]+)", key)]
  100. l.sort(key=alphanum_key)
  101. return l
  102. found_files = glob(pattern)
  103. found_files = sort_nicely(found_files)
  104. return found_files
  105. def find_design_files(example):
  106. '''
  107. '''
  108. # from example, create the pattern to find design files for all runs
  109. run = re.search('run-\d', example)
  110. run = run.group()
  111. designPattern = example.replace(run, 'run-*')
  112. # just in case, create substitute random subject for sub-01
  113. subj = re.search('sub-\d{2}', example)
  114. subj = subj.group()
  115. designPattern = designPattern.replace(subj, 'sub-01')
  116. designFpathes = sorted(glob(designPattern))
  117. return designFpathes
  118. def load_srm(in_fpath):
  119. # make np.load work with allow_pickle=True
  120. # save np.load
  121. np_load_old = np.load
  122. # modify the default parameters of np.load
  123. np.load = lambda *a, **k: np_load_old(*a, allow_pickle=True, **k)
  124. # np.load = lambda *a: np_load_old(*a, allow_pickle=True)
  125. # load the pickle file
  126. srm = brainiak.funcalign.srm.load(in_fpath)
  127. # change np.load() back to normal
  128. np.load = np_load_old
  129. return srm
  130. def plot_heatmap(title, matrix, outFpath, usedRegressors=[]):
  131. '''
  132. '''
  133. # generate a mask for the upper triangle
  134. mask = np.zeros_like(matrix, dtype=bool)
  135. mask[np.triu_indices_from(mask)] = True
  136. # set up the matplotlib figure
  137. f, ax = plt.subplots(figsize=(11, 9))
  138. # custom diverging colormap
  139. cmap = sns.diverging_palette(220, 10, sep=1, as_cmap=True)
  140. # draw the heatmap with the mask and correct aspect ratio
  141. sns_plot = sns.heatmap(matrix, mask=mask,
  142. cmap=cmap,
  143. square=True,
  144. center=0,
  145. vmin=-1.0, vmax=1,
  146. annot=True, annot_kws={"size": 8, "color": "k"}, fmt='.1f',
  147. # linewidths=.5,
  148. cbar_kws={"shrink": .6}
  149. )
  150. plt.xticks(rotation=90, fontsize=12)
  151. plt.yticks(rotation=0, fontsize=12)
  152. # plt.title(title)
  153. # coloring of ticklabels
  154. # x-axis
  155. if usedRegressors == AO_USED:
  156. # for x in range(len(srm.s_), len(srm.s_) + len(AO_USED)):
  157. for x in range(len(AO_USED)):
  158. plt.gca().get_xticklabels()[x].set_color('blue') # black = default
  159. # handle the sum of regressors
  160. for x in range(len(AO_USED), len(AO_USED) + 1):
  161. plt.gca().get_xticklabels()[x].set_color('cornflowerblue') # black = default
  162. elif usedRegressors == AV_USED:
  163. for x in range(len(AV_USED)):
  164. plt.gca().get_xticklabels()[x].set_color('red') # black = default
  165. elif usedRegressors == VIS_USED:
  166. for x in range(len(VIS_USED)):
  167. plt.gca().get_xticklabels()[x].set_color('y') # black = default
  168. # y-axis
  169. if usedRegressors == AO_USED:
  170. for y in range(len(AO_USED)):
  171. plt.gca().get_yticklabels()[y].set_color('blue') # black = default
  172. for y in range(len(AO_USED), len(AO_USED) + 1):
  173. plt.gca().get_yticklabels()[y].set_color('cornflowerblue') # black = default
  174. elif usedRegressors == AV_USED:
  175. for y in range(len(AV_USED)):
  176. plt.gca().get_yticklabels()[y].set_color('red') # black = default
  177. elif usedRegressors == VIS_USED:
  178. for y in range(len(VIS_USED)):
  179. plt.gca().get_yticklabels()[y].set_color('y') # black = default
  180. # create the output path
  181. os.makedirs(os.path.dirname(outFpath), exist_ok=True)
  182. extensions = ['pdf'] # , 'png', 'svg']
  183. for extension in extensions:
  184. fpath = os.path.join(f'{out_fpath}.{extension}')
  185. plt.savefig(fpath, bbox_inches='tight')
  186. plt.close()
  187. def create_aoDf(aofPathes):
  188. '''
  189. factorize / merge this function and the two functions below into one
  190. '''
  191. # specify which columns of the design file to use
  192. # correct for python index starting at 0
  193. # use every 2nd column because odd numbered columns
  194. # in the design file are temporal derivatives
  195. ao_columns = [(x-1) * 2 for x in AO_USED]
  196. ao_reg_names = [AO_NAMES[x] for x in AO_USED]
  197. # read the 8 design files and concatenate
  198. aoDf = pd.concat([pd.read_csv(run,
  199. usecols=ao_columns,
  200. names=ao_reg_names,
  201. skiprows=5, sep='\t')
  202. for run in aofPathes], ignore_index=True)
  203. # add a combination of regressors
  204. aoDf['geo&groom'] = aoDf['geo'] + aoDf['groom']
  205. # aoDf['geo&groom&furn'] = aoDf['geo'] + aoDf['groom'] + aoDf['furn']
  206. return aoDf
  207. def create_avDf(avfPathes):
  208. '''
  209. '''
  210. # specify which columns of the design file to use
  211. # correct for python index starting at 0
  212. # use every 2nd column because odd numbered columns
  213. # in the design file are temporal derivatives
  214. av_columns = [(x-1) * 2 for x in AV_USED]
  215. av_reg_names = [AV_NAMES[x] for x in AV_USED]
  216. # read the 8 design files and concatenate
  217. avDf = pd.concat([pd.read_csv(run,
  218. usecols=av_columns,
  219. names=av_reg_names,
  220. skiprows=5, sep='\t')
  221. for run in avfPathes], ignore_index=True)
  222. return avDf
  223. def create_visDf(visfPathes):
  224. '''
  225. '''
  226. # specify which columns of the design file to use
  227. # correct for python index starting at 0
  228. # use every 2nd column because odd numbered columns
  229. # in the design file are temporal derivatives
  230. vis_columns = [(x-1) * 2 for x in VIS_USED]
  231. vis_reg_names = [VIS_NAMES[x] for x in VIS_USED]
  232. # read the 4 design files and concatenate
  233. visDf = pd.concat([pd.read_csv(run,
  234. usecols=vis_columns,
  235. names=vis_reg_names,
  236. skiprows=5, sep='\t')
  237. for run in visfPathes], ignore_index=True)
  238. return visDf
  239. def create_srmDf(modelFile):
  240. '''
  241. '''
  242. # load the SRM from file
  243. srm = load_srm(modelFile)
  244. # slice SRM model for the TRs of the audio-description
  245. srm_array = srm.s_.T
  246. # create pandas dataframe from array and name the columns
  247. columns = ['shared feature %s' % str(int(x)+1) for x in range(srm_array.shape[1])]
  248. srmDf = pd.DataFrame(data=srm_array,
  249. columns=columns)
  250. return srmDf
  251. def create_corr_matrix(df1, df2, arctanh=False):
  252. '''
  253. '''
  254. # concat regressors and shared responses
  255. # slice the dataframe cause the last 75 TRs are not in the model space
  256. regressorsAndModelDf = pd.concat([df1, df2], axis=1)
  257. # create the correlation matrix for all columns
  258. if arctanh is True:
  259. regCorrMat = regressorsAndModelDf.corr()
  260. regCorrMat = np.arctanh(regCorrMat)
  261. else:
  262. regCorrMat = regressorsAndModelDf.corr()
  263. return regCorrMat
  264. def handle_one_or_list_of_models(regressorsDf, modelFile, start, end):
  265. '''it's the thought that counts (and that the function does what it does)
  266. '''
  267. if 'shuffled' not in modelFile:
  268. # read the model file
  269. srmDf = create_srmDf(modelFile)
  270. # reset index of SRM's df
  271. srm_ao_TRs = srmDf[start:end]
  272. srm_ao_TRs.reset_index(inplace=True, drop=True)
  273. # concat regressors and shared responses
  274. # slice the dataframe cause the last 75 TRs are not in the model space
  275. regCorrMat = create_corr_matrix(regressorsDf, srm_ao_TRs,
  276. arctanh=False)
  277. else:
  278. # find all files according to pattern
  279. directory = os.path.dirname(modelFile)
  280. modelPattern = model[:-4] + '*.npz'
  281. modelPattern = os.path.join(directory, modelPattern)
  282. shuffledModelFpathes = find_files(modelPattern)
  283. # read in the files and sum up all cells
  284. for idx, shuffledModelFile in enumerate(shuffledModelFpathes, 1):
  285. if idx == 1:
  286. print(f'reading model no. {idx}:', shuffledModelFile)
  287. # read the model file
  288. srmDf = create_srmDf(shuffledModelFile)
  289. # reset index of SRM's df
  290. srm_ao_TRs = srmDf[start:end]
  291. srm_ao_TRs.reset_index(inplace=True, drop=True)
  292. # concat regressors and shared responses
  293. regCorrMat = create_corr_matrix(regressorsDf, srm_ao_TRs,
  294. arctanh=True)
  295. else:
  296. # read the model file
  297. print(f'reading model no. {idx}:', shuffledModelFile)
  298. # reset index of SRM's df
  299. srm_ao_TRs = srmDf[start:end]
  300. srm_ao_TRs.reset_index(inplace=True, drop=True)
  301. # concat regressors and shared responses
  302. regCorrMat += create_corr_matrix(regressorsDf, srm_ao_TRs,
  303. arctanh=True)
  304. # take the mean
  305. regCorrMat = regCorrMat / idx
  306. regCorrMat = np.tanh(regCorrMat)
  307. return regCorrMat
  308. if __name__ == "__main__":
  309. # get the command line inputs
  310. aoExample, avExample, visExample, modelFile, outDir = parse_arguments()
  311. # infere subject number form file name
  312. sub = re.search('sub-\d{2}', modelFile)
  313. sub = sub.group()
  314. model = os.path.basename(modelFile).split('.npz')[0]
  315. # a) plot the correlation of AO regressors and features (only AO TRs)
  316. # get design.mat files for the 8 runs
  317. aofPathes = find_design_files(aoExample)
  318. # create the dataframe
  319. aoDf = create_aoDf(aofPathes)
  320. # indices of audio-decription TRs within the model
  321. start = 0
  322. end = 3524
  323. # get the correlation matrix
  324. regCorrMat = handle_one_or_list_of_models(aoDf, modelFile, start, end)
  325. # plot it
  326. title = f'{sub}: AO Regressors vs. Shared Features'\
  327. ' ({model}; TRs {start}-{end})'
  328. # create name of path and file (must not include ".{extension}"
  329. out_fpath = os.path.join(
  330. outDir, f'corr_ao-regressors-vs-cfs_{sub}_{model}_{start}-{end}')
  331. plot_heatmap(title, regCorrMat, out_fpath, AO_USED)
  332. # b) plot the correlation of AV regressors and features (only AV TRs)
  333. # get design.mat files for the 8 runs
  334. avfPathes = find_design_files(avExample)
  335. # create the dataframe
  336. avDf = create_avDf(avfPathes)
  337. # indices of movie TRs within the model
  338. start = 3524
  339. end = 7123
  340. # get the correlation matrix
  341. regCorrMat = handle_one_or_list_of_models(avDf, modelFile, start, end)
  342. # plot it
  343. title = f'{sub}: AV Regressors vs. Shared Responses' \
  344. ' ({model}; TRs {start}-{end})'
  345. # create name of path and file (must not include file extension)"
  346. out_fpath = os.path.join(
  347. outDir, f'corr_av-regressors-vs-cfs_{sub}_{model}_{start}-{end}'
  348. )
  349. plot_heatmap(title, regCorrMat, out_fpath, AV_USED)
  350. # c) plot the correlation of VIS regressors and features (only VIS TRs)
  351. # get design.mat files for the 4 runs
  352. visfPathes = find_design_files(visExample)
  353. # create the dataframe
  354. visDf = create_visDf(visfPathes)
  355. # indices of localizer TRs within the model
  356. start = 7123
  357. end = 7747
  358. # get the correlation matrix
  359. regCorrMat = handle_one_or_list_of_models(visDf, modelFile, start, end)
  360. # plot it
  361. title = f'{sub}: VIS Regressors vs. Shared Responses' \
  362. ' ({model}; TRs {start}-{end})'
  363. # create name of path and file (must not include file extension)
  364. out_fpath = os.path.join(
  365. outDir, f'corr_vis-regressors-vs-cfs_{sub}_{model}_{start}-{end}'
  366. )
  367. plot_heatmap(title, regCorrMat, out_fpath, VIS_USED)