# GetQuantitativeValues.py
# Extract per-mask mean values from flipped DWI maps and optionally save overlay figures.
import argparse
import glob
import os
import sys

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
from tqdm import tqdm
  9. def create_output_dir(output_dir):
  10. """Create the output directory if it does not exist."""
  11. if not os.path.exists(output_dir):
  12. os.makedirs(output_dir)
  13. def process_files(input_path, output_path, create_figures):
  14. """Process DWI files and generate figures."""
  15. # Define a dictionary to map Experiment values to the desired timepoints
  16. session_mapping = {
  17. 0: 0,
  18. 1: 1, 2: 3, 3: 3,
  19. 4: 3, 5: 3, 6: 7, 7: 7,
  20. 8: 7, 9: 7, 10: 7, 11: 14, 12: 14,
  21. 13: 14, 14: 14, 15: 14, 16: 14, 17: 14, 18: 14, 19: 21,
  22. 20: 21, 21: 21, 22: 21, 23: 21, 24: 21, 25: 21, 26: 28,
  23. 27: 28, 28: 28, 29: 28, 30: 28 , 42:42, 43:42, 56:56, 57:56
  24. }
  25. # Initialize a list to store extracted information
  26. csvdata = []
  27. # Iterate over files with progress bar
  28. file_paths = glob.glob(os.path.join(input_path, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"), recursive=True)
  29. for file_path in tqdm(file_paths, desc="Processing files"):
  30. # Extract information from the file path
  31. subject_id = file_path.split(os.sep)[-5]
  32. time_point = file_path.split(os.sep)[-4]
  33. print(file_path)
  34. int_time_point = 0 if time_point == "ses-Baseline" else int(time_point.split("-P")[1])
  35. merged_time_point = session_mapping[int_time_point]
  36. q_type = os.path.basename(file_path).split("_flipped")[0]
  37. # Create the temp path to search for dwi_masks
  38. temp_path = os.path.join(os.path.dirname(file_path), "RegisteredTractMasks_adjusted")
  39. dwi_masks = glob.glob(os.path.join(temp_path, "*dwi_flipped.nii.gz"))
  40. # Load DWI data using nibabel
  41. dwi_img = nib.load(file_path)
  42. dwi_data = dwi_img.get_fdata()
  43. # Loop through masks and calculate average values
  44. max_pixels = -1
  45. selected_slice = None
  46. for mask_path in tqdm(dwi_masks, desc="Processing masks", leave=False):
  47. mask_name = os.path.basename(mask_path).replace("registered", "").replace("flipped", "").replace(".nii.gz", "").replace("__dwi_","")
  48. mask_img = nib.load(mask_path)
  49. mask_data = mask_img.get_fdata()
  50. # Find the slice with the highest number of pixels
  51. for i in range(mask_data.shape[2]):
  52. num_pixels = np.count_nonzero(mask_data[..., i])
  53. if num_pixels > max_pixels:
  54. max_pixels = num_pixels
  55. selected_slice = i
  56. # Assuming you want to calculate the average value within the mask region
  57. masked_data = dwi_data[mask_data > 0]
  58. average_value = masked_data.mean()
  59. # Create and save the figure if the flag is True
  60. if create_figures:
  61. # Create a plot of the selected slice with overlay
  62. fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=300)
  63. ax.imshow(dwi_data[..., selected_slice], cmap='gray')
  64. ax.imshow(mask_data[..., selected_slice], cmap='jet', alpha=0.5) # Overlay mask_data
  65. ax.set_title(f"{subject_id}_tp_{merged_time_point}_{q_type}_{average_value:.2f}_{mask_name}")
  66. plt.axis('off')
  67. # Save the plot to the output path
  68. output_file = os.path.join(output_path,
  69. f"{subject_id}_tp_{merged_time_point}_{q_type}_{average_value:.2f}_{mask_name}.png")
  70. plt.savefig(output_file, bbox_inches='tight', pad_inches=0)
  71. # Close the plot to release memory
  72. plt.close(fig)
  73. # Append the extracted information to the list
  74. csvdata.append([file_path, subject_id, time_point, int_time_point, merged_time_point, q_type, mask_name, average_value])
  75. # Create a DataFrame
  76. df = pd.DataFrame(csvdata,
  77. columns=["fullpath", "subjectID", "timePoint", "int_timepoint", "merged_timepoint", "Qtype", "mask name",
  78. "Value"])
  79. # Return DataFrame
  80. return df
  81. def main():
  82. # Parse command-line arguments
  83. parser = argparse.ArgumentParser(description="Process DWI files and generate figures.")
  84. parser.add_argument("-i", "--input", type=str, required=True, help="Input directory containing DWI files. e.g: proc_data (flipped fa, md etc file must be available. registering of masks must happen before this step")
  85. parser.add_argument("-o", "--output", type=str, required=True, help="Output directory to save output results (csv and figures).")
  86. parser.add_argument("-f", "--figures", action="store_true", default=False, help="set if you want figures to be saved as pngs")
  87. args = parser.parse_args()
  88. # Create the output directory if it does not exist
  89. create_output_dir(args.output)
  90. # Process files
  91. df = process_files(args.input, args.output,args.figures)
  92. # Save the DataFrame as CSV
  93. csv_file = os.path.join(args.output, "data_summary.csv")
  94. df.to_csv(csv_file, index=False)
  95. print("CSV file created at:", csv_file)
  96. # Print the DataFrame
  97. print(df)
  98. if __name__ == "__main__":
  99. main()