# process_result_dataframe.py (2.7 KB)
  1. import argparse
  2. import pandas as pd
  3. import numpy as np
  4. from elephant.statistics import mean_firing_rate
  5. import neo
  6. from pathlib import Path
  7. import os
  8. import sys
  9. sys.path.append(str(Path.cwd().parents[0] / 'scripts'))
  10. from utils import load_df, compact_column
  11. def add_logp_column(df, col_name_str='pvalue', log_prefix='log_'):
  12. col_names = [col for col in df.columns if col_name_str in col]
  13. for col_name in col_names:
  14. pvalue = df[col_name].to_numpy().astype(float)
  15. pvalue[~np.isfinite(pvalue)] = np.nan
  16. df[col_name] = pvalue
  17. log_pvalue = np.array([np.log10(p) if p else np.nan for p in pvalue])
  18. log_pvalue[~np.isfinite(log_pvalue)] = np.nan
  19. df[f'{log_prefix}{col_name}'] = log_pvalue
  20. return df
  21. def calc_ratio(row, numerator, denominator, inf_value=np.nan):
  22. if row[denominator] == 0:
  23. return inf_value
  24. else:
  25. return row[numerator] / row[denominator]
  26. def add_pvalue_ratio(df, ratio_keys=['correlations', 'weights'],
  27. ratio_columns=['pvalue', 'score']):
  28. for col in ratio_columns:
  29. df[f'{col}_ratio'] = df.apply(
  30. lambda row: calc_ratio(row,
  31. f'{col}_{ratio_keys[0]}',
  32. f'{col}_{ratio_keys[1]}'), axis=1)
  33. return df
  34. def add_rate_correlation_column(df, ratecorr_df, on=['protocol', 'seeds']):
  35. return df.merge(ratecorr_df, how='outer', on=None)
  36. def sort_protocol_values(protocol):
  37. if 'ffle' in protocol:
  38. _ , fraction, populations = protocol.split('_')
  39. pop_order = np.array(['E-E', 'E-I', 'I-E', 'I-I'])
  40. value = np.where(populations == pop_order)[0][0]
  41. return float(fraction) + value/100
  42. elif 'add' in protocol:
  43. source, target = protocol.split('_')[-1].split('-')
  44. source, target = source.strip('E'), target.strip('E')
  45. return float(source)*100 + float(target)
  46. else:
  47. return 0
  48. def order_df(df):
  49. key_function = lambda protocols: [sort_protocol_values(p) for p in protocols]
  50. df = df.sort_values('protocol', key=key_function)
  51. return df
  52. if __name__ == '__main__':
  53. CLI = argparse.ArgumentParser()
  54. CLI.add_argument("--output", nargs='?', type=Path)
  55. CLI.add_argument("--comparison_df", nargs='?', type=Path)
  56. CLI.add_argument("--ratecorr_df", nargs='?', type=Path, default=None)
  57. args, unknown = CLI.parse_known_args()
  58. df = load_df(args.comparison_df)
  59. df = compact_column(df)
  60. df = add_pvalue_ratio(df)
  61. df = add_logp_column(df)
  62. if args.ratecorr_df is not None:
  63. df = add_rate_correlation_column(df, load_df(args.ratecorr_df))
  64. df = order_df(df)
  65. df.to_csv(args.output)