123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185 |
- # -*- coding: utf-8 -*-
- """
- Created on Tue Nov 12 15:23:26 2024
- @author: arefks
- """
- import os
- import pandas as pd
- import numpy as np
- import matplotlib.pyplot as plt
- from tqdm import tqdm
# Conversion factor from centimetres to inches (matplotlib figure sizes are in inches).
cm = 1 / 2.54

# Project root = the directory one level above the folder holding this script.
code_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(code_dir)

# Input CSVs (contents described by their file names): the raw quantitative
# values and the timepoint 0 vs 28 significance table.
original_file_path = os.path.join(
    parent_dir, 'output', 'Quantitative_outputs',
    'Quantitative_results_from_dwi_processing_only_in_stroke_affected_slices.csv',
)
results_file_path = os.path.join(
    parent_dir, 'output', 'Quantitative_outputs',
    'Significance_timepoint_0_vs_28_only_in_stroke_slices.csv',
)

# Figure output folders, created up front so later saves do not fail.
# NOTE(review): fa_over_time_plots_dir is not referenced elsewhere in the visible script.
plots_output_dir = os.path.join(parent_dir, 'output', 'Figures', 'pythonFigs')
fa_over_time_plots_dir = os.path.join(parent_dir, 'output', 'Figures', 'fa_over_time_plots')
os.makedirs(fa_over_time_plots_dir, exist_ok=True)
os.makedirs(plots_output_dir, exist_ok=True)
# Load the raw values used to compute mean differences, excluding the Sham group.
df = pd.read_csv(original_file_path, low_memory=False)
df = df.loc[df['Group'].ne('Sham')]

# Load the significance table (source of the p-values), excluding the Sham
# group and every mask whose name contains the literal text "AMBA".
results_df = pd.read_csv(results_file_path)
results_df = results_df.loc[results_df['Group'].ne('Sham')]
results_df = results_df.loc[~results_df['mask_name'].str.contains("AMBA", regex=False)]
# Helpers that classify a mask name by tract abbreviation and by hemisphere.
def map_abbreviation(mask_name):
    """Return the tract abbreviation that *mask_name* starts with.

    Prefixes are tried in the order CC, CRuT, CReT, CST, TC, OT and the first
    match wins; "Unknown" is returned when no prefix matches.
    """
    known_prefixes = ("CC", "CRuT", "CReT", "CST", "TC", "OT")
    return next(
        (prefix for prefix in known_prefixes if mask_name.startswith(prefix)),
        "Unknown",
    )
def map_location(mask_name):
    """Return the hemisphere label encoded in *mask_name*.

    "Ips" if the name contains "ipsi"; otherwise "Con" if it contains
    "contra"; otherwise the string "None" (not the ``None`` object).
    """
    for token, label in (("ipsi", "Ips"), ("contra", "Con")):
        if token in mask_name:
            return label
    return "None"
# Derive the tract abbreviation and hemisphere label of every mask name.
results_df = results_df.assign(
    abbreviation=results_df['mask_name'].apply(map_abbreviation),
    location=results_df['mask_name'].apply(map_location),
)

# Distinct values (in order of first appearance) that drive the plotting loops.
abbreviations = results_df['abbreviation'].unique()
locations = results_df['location'].unique()
qtypes = results_df['Qtype'].unique()

# One marker shape per Qtype, reusing shapes cyclically when Qtypes outnumber them.
markers = ['o', 's', '^', 'D']
marker_mapping = {qt: markers[pos % len(markers)] for pos, qt in enumerate(qtypes)}

# Set to False to skip annotating the most significant point of each Qtype.
MaxPlotter = True
# Text summaries of the annotated points, printed after all plots are made.
highest_points_details = []
# Build one volcano-style plot per (tract abbreviation, hemisphere) pair:
#   x = mean Value at timepoint 28 minus mean Value at timepoint 0 (from df),
#   y = -log10 of the matching p-value (from results_df).
for abbr in abbreviations:
    for location in locations:
        subset_df = results_df[(results_df['abbreviation'] == abbr) & (results_df['location'] == location)]
        # Skip abbreviation/hemisphere combinations that have no rows.
        if subset_df.empty:
            continue

        # 6 cm x 6 cm figure (figsize is in inches), high DPI for better quality.
        fig = plt.figure(figsize=(6 * cm, 6 * cm), dpi=300)

        # Parallel lists: entry k describes one (mask, Qtype, dilation) combination.
        mean_diff_list = []
        neg_log_pvalue_list = []
        mask_list = []
        qtype_list = []
        dialation_list = []

        for mask in subset_df['mask_name'].unique():
            # Raw rows for this mask at baseline (timepoint 0) and at timepoint 28.
            timepoint_0_df = df[(df['merged_timepoint'] == 0) & (df['mask_name'] == mask)]
            timepoint_28_df = df[(df['merged_timepoint'] == 28) & (df['mask_name'] == mask)]
            for qtype in qtypes:
                for dialation_amount in subset_df['dialation_amount'].unique():
                    tp0_values = timepoint_0_df[
                        (timepoint_0_df['Qtype'] == qtype)
                        & (timepoint_0_df['dialation_amount'] == dialation_amount)
                    ]['Value'].dropna()
                    tp28_values = timepoint_28_df[
                        (timepoint_28_df['Qtype'] == qtype)
                        & (timepoint_28_df['dialation_amount'] == dialation_amount)
                    ]['Value'].dropna()

                    # Mean change from baseline; NaN when either timepoint has no data.
                    if len(tp0_values) > 0 and len(tp28_values) > 0:
                        mean_diff = tp28_values.mean() - tp0_values.mean()
                    else:
                        mean_diff = np.nan

                    # First matching p-value in results_df; NaN when there is none.
                    pvalue = results_df[
                        (results_df['mask_name'] == mask)
                        & (results_df['Qtype'] == qtype)
                        & (results_df['dialation_amount'] == dialation_amount)
                    ]['Pvalue'].values
                    if len(pvalue) > 0:
                        neg_log_pvalue = -np.log10(pvalue[0])
                    else:
                        neg_log_pvalue = np.nan

                    mean_diff_list.append(mean_diff)
                    neg_log_pvalue_list.append(neg_log_pvalue)
                    mask_list.append(mask)
                    qtype_list.append(qtype)
                    dialation_list.append(dialation_amount)

        # One scatter series per Qtype (all masks and dilations pooled), each with its own marker.
        for qtype in qtypes:
            qtype_mean_diff = [x for x, q in zip(mean_diff_list, qtype_list) if q == qtype]
            qtype_neg_log_pvalue = [y for y, q in zip(neg_log_pvalue_list, qtype_list) if q == qtype]
            plt.scatter(qtype_mean_diff, qtype_neg_log_pvalue, alpha=0.7, s=10, marker=marker_mapping[qtype], label=qtype)

        # Reference lines: no change (x = 0) and the p = 0.05 significance threshold.
        plt.axvline(x=0, color='red', linestyle='--')
        plt.axhline(y=-np.log10(0.05), color='blue', linestyle='--')

        # Axis labels, title and legend.
        plt.xlabel('Mean Difference (28 - BL)', fontsize=12, fontname='Calibri')
        plt.ylabel('-log10(Pvalue)', fontsize=12, fontname='Calibri')
        plt.title(f'{abbr},{location}', fontsize=12, fontname='Calibri')
        plt.grid(False)
        plt.legend(loc='best', fontsize=6, frameon=False)

        # Label the most significant point of each Qtype (if enabled).
        if MaxPlotter and mean_diff_list:
            for qtype in qtypes:
                # Consider only points with finite coordinates: NaN/inf points do not
                # appear on the plot, NaN keys make max() order-dependent (a leading
                # NaN would "win"), and plt.text needs finite positions.
                qtype_indices = [
                    i for i, q in enumerate(qtype_list)
                    if q == qtype
                    and np.isfinite(mean_diff_list[i])
                    and np.isfinite(neg_log_pvalue_list[i])
                ]
                if not qtype_indices:
                    continue
                max_qtype_index = max(qtype_indices, key=lambda i: neg_log_pvalue_list[i])
                plt.text(mean_diff_list[max_qtype_index], neg_log_pvalue_list[max_qtype_index],
                         f"{mask_list[max_qtype_index]}, d={dialation_list[max_qtype_index]}",
                         fontsize=6, fontname='Calibri', ha='right', va='bottom')
                highest_points_details.append(
                    f"Highest point for Qtype {qtype} in {abbr}, {location}: {mask_list[max_qtype_index]}, d={dialation_list[max_qtype_index]}, Mean Diff={mean_diff_list[max_qtype_index]}, -log10(P)={neg_log_pvalue_list[max_qtype_index]}"
                )

        # Save the plot as a PNG file, display it, then release the figure so
        # figures do not accumulate across iterations.
        plot_file_name = f'volcano_plot_{abbr}_{location}.png'
        plot_file_path = os.path.join(plots_output_dir, plot_file_name)
        plt.savefig(plot_file_path, format='png', bbox_inches='tight')
        plt.show()
        plt.close(fig)
# Report every annotated highest point, one per line, after the header line.
print("\nDetails of the highest points:", *highest_points_details, sep="\n")
|