extract_atlas_fa_regions.py

import os
import glob
import pandas as pd
import nibabel as nib
import numpy as np
from tqdm import tqdm

# Set up paths
code_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(code_dir)
input_address = r"E:\CRC_data\SND\proc_data"
mask_names_file = os.path.join(parent_dir, "input", "acronym_atlas_abriviation.csv")

# Load mask names into a DataFrame
df_maskName = pd.read_csv(mask_names_file)

# Define session mapping
session_mapping = {
    0: 0, 1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 7, 7: 7, 8: 7, 9: 7,
    10: 7, 11: 14, 12: 14, 13: 14, 14: 14, 15: 14, 16: 14, 17: 14,
    18: 14, 19: 21, 20: 21, 21: 21, 22: 21, 23: 21, 24: 21, 25: 21,
    26: 28, 27: 28, 28: 28, 29: 28, 30: 28, 42: 42, 43: 42, 56: 56, 57: 56
}
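# NOTE: the mapping collapses acquisition days onto nominal session days
# (0, 3, 7, 14, 21, 28, 42, 56), so scans acquired a day or two off-schedule
# are grouped with their intended timepoint.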

# Initialize list to store extracted information
data = []

# Find all flipped quantitative maps under the DSI_studio folders
file_paths = glob.glob(os.path.join(input_address, "**", "dwi", "DSI_studio", "*_flipped.nii.gz"), recursive=True)
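# NOTE: the subject/session parsing below assumes a BIDS-like layout such as
# proc_data/<subjectID>/<ses-*>/dwi/DSI_studio/<Qtype>_flipped.nii.gz, with
# session folders named "ses-Baseline" or "ses-P<day>"; path components -5
# and -4 are then the subject and session folders, respectively.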
for file_path in tqdm(file_paths, desc="Processing files"):
    print(file_path)
    # Extract information from the file path
    subject_id = file_path.split(os.sep)[-5]
    time_point = file_path.split(os.sep)[-4]
    int_time_point = 0 if (time_point == "ses-Baseline") else int(time_point.split("-P")[1])
    merged_time_point = session_mapping.get(int_time_point, 'Unknown')
    q_type = os.path.basename(file_path).split("_flipped")[0]
    try:
        search_stroke = os.path.join(os.path.dirname(file_path), "*StrokeMask_scaled.nii")
        stroke_path = glob.glob(search_stroke)[0]
        stroke_flag = True
    except IndexError:
        stroke_path = None
        stroke_flag = False
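    # NOTE: glob.glob returns a list, so indexing [0] raises IndexError when
    # no stroke mask is present; such scans are provisionally labeled "Sham"
    # (and corrected per subject after the loop).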
    temp_path = os.path.dirname(file_path)
    mask_path = os.path.join(temp_path, "*dwiDNSmoothMicoBetAnnoSplit_parental_scaled.nii")
    # Load DWI data
    dwi_img = nib.load(file_path)
    dwi_data = dwi_img.get_fdata()
    # Load atlas mask data; skip scans without a parental atlas file
    mask_files = glob.glob(mask_path)
    if not mask_files:
        continue
    mask_img = nib.load(mask_files[0])
    mask_data = mask_img.get_fdata()
    unique_masks = np.unique(mask_data)
    # Load the stroke mask once per scan rather than once per region
    stroke_data = nib.load(stroke_path).get_fdata() if stroke_path else None
    for rr in unique_masks:
        unique_region_mask = mask_data == rr
        if stroke_data is not None:
            # Exclude voxels flagged by the stroke mask from the region mask
            unique_region_mask = unique_region_mask & ~(stroke_data > 1)
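            # NOTE: the `> 1` threshold above assumes the scaled stroke mask
            # encodes lesion voxels with values above 1; for a binary (0/1)
            # mask this exclusion would be a no-op and `> 0` would be needed.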
        ROI = dwi_data * unique_region_mask
        mean_roi_value = ROI[ROI > 0].mean() if np.any(ROI > 0) else np.nan
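        # NOTE: region IDs appear to be duplicated across hemispheres, with
        # right-hemisphere labels offset by +2000 relative to the atlas CSV;
        # names are prefixed L_/R_ accordingly in the lookup below.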
        # Look up the region name; fall back to "Unknown" for unmapped IDs
        if rr in df_maskName["RegionID"].values:
            mask_name_row = df_maskName[df_maskName["RegionID"] == rr]
            mask_name = "L_" + mask_name_row["RegionAbbreviation"].values[0]
        elif (rr - 2000) in df_maskName["RegionID"].values:
            mask_name_row = df_maskName[df_maskName["RegionID"] == (rr - 2000)]
            mask_name = "R_" + mask_name_row["RegionAbbreviation"].values[0]
        else:
            mask_name = "Unknown"
        # Append data to list
        data.append([
            file_path,
            subject_id,
            time_point,
            int_time_point,
            merged_time_point,
            q_type,
            rr,
            mask_name,
            mean_roi_value,
            "Stroke" if stroke_flag else "Sham"
        ])

# Create DataFrame from the collected data
columns = ["fullpath", "subjectID", "timePoint", "int_timepoint", "merged_timepoint", "Qtype", "mask_id", "mask_name", "Value", "Group"]
df_results = pd.DataFrame(data, columns=columns)

# Update the Group column per subject: if any session of a subject is labeled
# "Stroke", label all of that subject's sessions "Stroke"
stroke_subjects = df_results[df_results["Group"] == "Stroke"]["subjectID"].unique()
df_results.loc[df_results["subjectID"].isin(stroke_subjects), "Group"] = "Stroke"
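# NOTE: this propagation assumes group membership is a per-subject property;
# presumably baseline (pre-stroke) sessions carry no stroke mask, so without
# it those sessions of stroke animals would stay mislabeled as "Sham".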

# Define the path for the output CSV file
output_csv_path = os.path.join(parent_dir, "output", "Final_Quantitative_output_for_atalas_regions", "Quantitative_results_from_dwi_processing_atlas.csv")

# Create the output directory if it does not exist
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)

# Save the DataFrame to a CSV file
df_results.to_csv(output_csv_path, index=False)
print(f"Processing complete. Results saved to {output_csv_path}")
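
# A minimal downstream sketch (the Qtype value "fa" is hypothetical; actual
# values depend on the *_flipped.nii.gz filename prefixes):
#
#   df = pd.read_csv(output_csv_path)
#   fa = df[df["Qtype"] == "fa"]
#   summary = fa.groupby(["Group", "merged_timepoint", "mask_name"])["Value"].agg(["mean", "sem"])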