# predict_age_sing.py
  1. #from read_data_mask_resampled import *
  2. import sys
  3. print("sys.path:", sys.path)
  4. from brainage.calculate_features import calculate_voxelwise_features
  5. from pathlib import Path
  6. import pandas as pd
  7. import argparse
  8. import pickle
  9. import os
  10. import re
  11. def model_pred(test_df, model_file, feature_space_str):
  12. """This functions predicts age
  13. Args:
  14. test_df (dataframe): test data
  15. model_file (pickle file): trained model file
  16. feature_space_str (string): feature space name
  17. Returns:
  18. dataframe: predictions from the model
  19. """
  20. model = pickle.load(open(model_file, 'rb')) # load model
  21. pred = pd.DataFrame()
  22. for key, model_value in model.items():
  23. X = data_df.columns.tolist()
  24. pre_X, pre_X2 = model_value.preprocess(test_df[X], test_df[X]) # preprocessed data
  25. y_pred = model_value.predict(test_df).ravel()
  26. print(y_pred.shape)
  27. pred[feature_space_str + '+' + key] = y_pred
  28. return pred
  29. if __name__ == "__main__":
  30. parser = argparse.ArgumentParser()
  31. parser.add_argument("--features_path", type=str, help="path to features dir") # eg '../data/ADNI'
  32. parser.add_argument("--data_dir", type=str, help="path to data dir") #
  33. parser.add_argument("--subject_filepaths", type=str, help="path to csv or txt file with subject filepaths") # eg: '../data/ADNI/ADNI.paths_cat12.8.csv'
  34. parser.add_argument("--output_path", type=str, help="path to output_dir") # eg'../results/ADNI'
  35. parser.add_argument("--output_prefix", type=str, help="prefix added to features filename ans results (predictions) file name") # eg: 'ADNI'
  36. parser.add_argument("--mask_file", type=str, help="path to GM mask nii file",
  37. default='../masks/brainmask_12.8.nii')
  38. parser.add_argument("--smooth_fwhm", type=int, help="smoothing FWHM", default=4)
  39. parser.add_argument("--resample_size", type=int, help="resampling kernel size", default=4)
  40. parser.add_argument("--model_file", type=str, help="Trained model to be used to predict",
  41. default='../trained_models/4sites.S4_R4_pca.gauss.models')
  42. # For testing
  43. # python3 predict_age.py --features_path ../data/ADNI --subject_filepaths ../data/ADNI/ADNI.paths_cat12.8.csv --output_path ../results/ADNI --output_prefix ADNI --mask_file ../masks/brainmask_12.8.nii --smooth_fwhm 4 --resample_size 4 --model_file ../trained_models/4sites.S4_R4_pca.gauss.models
  44. args = parser.parse_args()
  45. features_path = args.features_path
  46. data_dir = args.data_dir
  47. subject_filepaths = args.subject_filepaths
  48. output_path = args.output_path
  49. output_prefix = args.output_prefix
  50. smooth_fwhm = args.smooth_fwhm
  51. resample_size = args.resample_size
  52. mask_file = args.mask_file
  53. model_file = args.model_file
  54. print('\nBrain-age trained model used: ', model_file)
  55. print('Data directory (test data): ', data_dir)
  56. print('Subjects filepaths (test data): ', subject_filepaths)
  57. print('Directory to features path: ', features_path)
  58. print('Results directory: ', output_path)
  59. print('Results filename prefix: ', output_prefix)
  60. print('GM mask used: ', mask_file)
  61. # create full filename for the nii files of the subjects and save as csv in features_path
  62. subject_filepaths_nii = pd.read_csv(subject_filepaths, header=None)
  63. subject_filepaths_nii = data_dir + '/' + subject_filepaths_nii
  64. print(subject_filepaths_nii)
  65. subject_full_filepaths = os.path.join(features_path, 'subject_full_filepaths.csv')
  66. print(subject_full_filepaths)
  67. subject_filepaths_nii.to_csv(subject_full_filepaths, header=False, index=False)
  68. # get feature space name from the model file entered and
  69. # create feature space name using the input values (smoothing, resampling)
  70. # match them: they should be same
  71. # get feature space name from the model file entered in argument
  72. pipeline_name1 = model_file.split('/')[-1]
  73. feature_space = pipeline_name1.split('.')[1]
  74. model_name = pipeline_name1.split('.')[2]
  75. pipeline_name = feature_space + '.' + model_name
  76. # create feature space name using the input values (smoothing, resampling)
  77. pca_string = re.findall(r"pca", feature_space)
  78. if len(pca_string) == 1:
  79. feature_space_str = 'S' + str(smooth_fwhm) + '_R' + str(resample_size) + '_pca'
  80. else:
  81. feature_space_str = 'S' + str(smooth_fwhm) + '_R' + str(resample_size)
  82. # match them: they should be same
  83. assert(feature_space_str == feature_space), f"Mismatch in feature parameters entered ({feature_space_str}) & features used for model training ({feature_space})"
  84. print('Feature space: ', feature_space)
  85. print('Model name: ', model_name)
  86. # Create directories, create features if they don't exists
  87. Path(output_path).mkdir(exist_ok=True, parents=True)
  88. Path(features_path).mkdir(exist_ok=True, parents=True)
  89. features_filename = str(output_prefix) + '.S' + str(smooth_fwhm) + '_R' + str(resample_size)
  90. features_fullfile = os.path.join(features_path, features_filename)
  91. print('\nfilename for features created: ', features_fullfile)
  92. if os.path.isfile(features_fullfile): # check if features file exists
  93. print('\n----File exists')
  94. data_df = pickle.load(open(features_fullfile, 'rb'))
  95. print('Features loaded')
  96. else:
  97. print('\n-----Extracting features')
  98. # create features
  99. data_df = calculate_voxelwise_features(subject_full_filepaths, mask_file, smooth_fwhm=smooth_fwhm, resample_size=resample_size)
  100. # save features
  101. pickle.dump(data_df, open(features_fullfile, "wb"), protocol=4)
  102. data_df.to_csv(features_fullfile + '.csv', index=False)
  103. print('Feature extraction done and saved')
  104. # get predictions and save
  105. try:
  106. predictions_df = model_pred(data_df, model_file, feature_space_str)
  107. # save predictions
  108. predictions_filename = str(output_prefix) + '.' + pipeline_name + '.prediction.csv'
  109. predictions_fullfile = os.path.join(output_path, predictions_filename)
  110. print('\nfilename for predictions created: ', predictions_fullfile)
  111. predictions_df.to_csv(predictions_fullfile, index=False)
  112. print(predictions_df)
  113. except FileNotFoundError:
  114. print(f'{model_file} is not present')