# MutualInfoPlot.py

"""Plot mutual information against voxel shift for two motion conditions.

Reads ``../input/MUplot.csv`` (relative to this script's directory), draws
the "Severe motion" and "No motion" columns against "Shift (Voxel)", saves
the figure as PNG and SVG into ``../figures`` and then shows it on screen.
"""
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Resolve input/output paths relative to this script's location.
script_dir = os.path.dirname(__file__)
file_path = os.path.join(script_dir, '..', 'input', 'MUplot.csv')
out_path = os.path.join(script_dir, '..', 'figures')

# Read data from the CSV file. The first line is skipped (presumably the
# file's own header row -- confirm) and replaced by these column names.
df = pd.read_csv(
    file_path,
    header=None,
    names=['Shift (Voxel)', 'Severe motion', 'No motion'],
    skiprows=1,
)

# Seaborn style and palette: 'Set1' supplies the colour of the first curve;
# the second curve is coloured explicitly below.
sns.set_theme(style='ticks', palette='Set1')

# Font properties: 8 pt for axis labels, 6 pt for the legend.
font_properties = {'family': 'Times New Roman', 'size': 8}
font_properties2 = {'family': 'Times New Roman', 'size': 6}
cm = 1 / 2.54  # centimeters in inches (matplotlib figsize is in inches)

# Create the plot at a fixed physical size of 7.01 cm x 3.21 cm, 300 dpi.
fig, ax = plt.subplots(figsize=(7.01 * cm, 3.21 * cm), dpi=300)
ax.plot(
    df['Shift (Voxel)'], df['Severe motion'],
    label='Severe motion', linewidth=1,
)
ax.plot(
    df['Shift (Voxel)'], df['No motion'],
    label='No motion', linewidth=1, color='blue',
)

# Axis labels.
ax.set_xlabel('Shift (Voxel)', **font_properties)
ax.set_ylabel('Mutual information (a.u)', **font_properties)

# Thin grey tick marks, up to 8 tick intervals per axis, and tick labels
# rendered in the same 8 pt Times New Roman as the axis labels.
ax.tick_params(axis='both', which='both', width=0.5, color='gray', length=2)
ax.locator_params(axis='x', nbins=8)
ax.locator_params(axis='y', nbins=8)
for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:
    tick_label.set_fontname('Times New Roman')
    tick_label.set_fontsize(8)

# Legend in the smaller font, without a border.
ax.legend(prop=font_properties2, frameon=False)

# Thin (0.5 pt) border on every side of the axes.
for spine in ax.spines.values():
    spine.set_linewidth(0.5)

# Margins are fractions of the figure size, so each must lie within [0, 1].
# A ``top`` above 1 would place the upper edge of the axes outside the
# canvas and clip it in the on-screen view.
fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.25)

# Save the figure as PNG and SVG. No dpi is passed, so savefig uses
# rcParams['savefig.dpi'], which by default is the figure's own 300 dpi.
# savefig does not create missing directories, so ensure the target exists.
os.makedirs(out_path, exist_ok=True)
fig_path_png = os.path.join(out_path, 'MutualInformation.png')
fig_path_svg = os.path.join(out_path, 'MutualInformation.svg')
fig.savefig(fig_path_png, format='png', bbox_inches='tight')
fig.savefig(fig_path_svg, format='svg', bbox_inches='tight')

# Show the plot.
plt.show()