plotting_utils.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import numpy as np
  2. import pandas as pd
  3. from pathlib import Path
  4. import matplotlib as mpl
  5. int_types = ['wavefronts_id', 'wavemodes_id', 'wavemode', 'channel_id', 'dim_x',
  6. 'dim_y', 'number_of_triggers', 'x_coords', 'y_coords']
  7. float_types = ['spatial_scale', 'sampling_rate', 'recording_length', 'duration',
  8. 'inter_wave_interval', 'inter_wave_interval_std',
  9. 'inter_wave_interval_local', 'planarity', 'velocity_planar',
  10. 'velocity_planar_std', 'velocity_local', 'direction_x',
  11. 'direction_y', 'direction_x_std', 'direction_y_std',
  12. 'direction_local_x', 'direction_local_y', 'time_stamp',
  13. 'flow_direction_local_x', 'flow_direction_local_y']
  14. df_dtypes = dict.fromkeys(int_types, 'Int64') | dict.fromkeys(float_types, float) \
  15. | {'is_planar': bool}
  16. colormap = {'WT': '#35618f',
  17. 'KO': '#72be3e',
  18. 'WBS': '#35618f',
  19. 'FXS': '#72be3e',
  20. 'mix': '#f56808',
  21. 'ketamine': '#dd4601',
  22. 'isoflurane': '#ebd111',
  23. 'light': '#f8765c',
  24. 'medium': '#d3436e',
  25. 'deep': '#982d80'}
  26. def filter_velocity_local(df, vmin=0.0, vmax=120):
  27. # minimum velocity
  28. df = df[df.velocity_local > vmin]
  29. # maximum velocity
  30. df = df[df.velocity_local < vmax]
  31. # nan values
  32. df = df.loc[pd.notnull(df.velocity_local)]
  33. return df
  34. def load_df(filename, dtype=None):
  35. filename = Path(filename)
  36. if dtype is None:
  37. if 'avg' in str(filename.name):
  38. dtype = df_dtypes | dict(wavefronts_id=float,
  39. number_of_triggers=float,
  40. is_planar=float)
  41. else:
  42. dtype = df_dtypes
  43. keep = ['wavefronts_id', 'wavemode', 'wavemodes_id',
  44. 'anesthetic', 'profile', 'technique', 'disease_model', 'model_type']
  45. local_measures = ['velocity_local', 'x_coords', 'y_coords',
  46. 'direction_local_x', 'direction_local_y', 'inter_wave_interval_local',
  47. 'flow_direction_local_x', 'flow_direction_local_y',
  48. 'angle_local', 'flow_angle_local']
  49. global_measures = ['velocity_planar', 'direction_x', 'direction_y',
  50. 'planarity', 'inter_wave_interval', 'number_of_triggers',
  51. 'angle']
  52. usecols = keep + local_measures if 'channel-wise' in str(filename.name) \
  53. else keep + global_measures
  54. df = pd.read_csv(filename, usecols=lambda c: c in usecols,
  55. dtype=dtype, low_memory=False)
  56. for measure in ['direction_local', 'flow_direction_local', 'direction']:
  57. if f'{measure}_x' in df.columns and f'{measure}_y' in df.columns:
  58. df = add_angle_column(df, measure)
  59. for measure in ['wavemode', 'wavemodes_id', 'wavefronts_id']:
  60. if measure in df.columns:
  61. df.drop(df[df[measure] == -1].index, inplace=True)
  62. df = simplify_anesthetic_names(df)
  63. return df
  64. def simplify_anesthetic_names(df):
  65. df['anesthetic'] = ['ketamine' if 'ketamine' in name else name
  66. for name in df.anesthetic]
  67. return df
  68. def direction_to_angle(dx, dy):
  69. return np.angle(dx + 1j*dy)
  70. def add_angle_column(df, dir_key='direction', angle_key=None):
  71. if angle_key is None:
  72. angle_key = dir_key.replace('direction', 'angle')
  73. df[angle_key] = df.apply(lambda row: direction_to_angle(row[f'{dir_key}_x'],
  74. row[f'{dir_key}_y']),
  75. axis=1)
  76. return df
  77. def color_to_map(hex_color, N):
  78. h = hex_color.lstrip('#')
  79. r, g, b = tuple(int(h[i:i+2], 16) for i in (0, 2, 4))
  80. M = N+2
  81. vals = np.ones((M, 4))
  82. vals[:, 0] = np.linspace(r/256, 1, M)
  83. vals[:, 1] = np.linspace(g/256, 1, M)
  84. vals[:, 2] = np.linspace(b/256, 1, M)
  85. return mpl.colors.ListedColormap(vals).colors[:N]