Snakefile

import sys
from pathlib import Path
from types import SimpleNamespace
import numpy as np

sys.path.append(str(Path.cwd().parents[0] / 'project_utils'))
from project_paths import project_paths

configfile: 'config.yml'
# Expose the config.yml entries as attributes (config.MEASURE_TYPE, config.VARIANT, ...).
config = SimpleNamespace(**config)
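
# Read the list of pipeline profiles, one non-empty line per profile.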
profile_definition_file = project_paths.working_dir \
                          / 'project_utils' / 'profiles.txt'
with open(profile_definition_file) as file:
    PROFILES = [line.strip() for line in file if len(line.strip())]
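
# A profile is selected only if none of the exclude patterns occur in its name.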
def select_profile(profile, excludes=[]):
    if not isinstance(excludes, list):
        excludes = [excludes]
    return np.array([str(exc) not in profile for exc in excludes]).all()
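
# Return all profiles that pass the exclude filter; if a variant suffix is
# given, it is appended, and only LENS profiles are kept in that case.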
def get_profiles(variant='', excludes=[]):
    profiles = [p for p in PROFILES if select_profile(p, excludes)]
    if variant:
        return [profile + variant for profile in profiles if 'LENS' in profile]
    else:
        return profiles
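
# Restrict the values each wildcard can take.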
wildcard_constraints:
    profile = r'[\w]+',
    measure_type = r'[a-z\-]+',
    event_name = r'[a-z]+',
    variant = r'\|(macrodim[0-9]+|minimatrigger)|\d?'
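
# Input function: the stage05 characterization output of every selected profile.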
def pipeline_output(w):
    stage_type = f'{"channel_" if w.measure_type == "channel-wise" else ""}wave'
    stage = f'stage05_{stage_type}_characterization'
    stage_output = f'{w.event_name}_{w.measure_type}_measures.csv'
    path = lambda profile: project_paths.pipeline_output \
                           / profile / stage / stage_output
    return [path(profile) for profile in get_profiles(variant=w.variant,
                                                      excludes=config.EXCLUDES)]
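
# Input function for rule all: the aggregated dataframes requested via
# MEASURE_TYPE, EVENT_NAME, and VARIANT in config.yml.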
def all_input(w):
    dfs = expand(project_paths.dataframes
                 / '{event_name}_{measure_type}{variant}_measures.csv',
                 measure_type = config.MEASURE_TYPE,
                 event_name = 'wavefronts',
                 variant = config.VARIANT)
    if 'wavemodes' in config.EVENT_NAME:
        dfs += expand(project_paths.dataframes
                      / 'wavemodes_{measure_type}_avg_measures.csv',
                      measure_type = config.MEASURE_TYPE)
    if len(config.VARIANT) \
            and any('macrodim' in variant for variant in config.VARIANT):
        dfs += expand(project_paths.dataframes
                      / 'wavefronts_{measure_type}_trend_measures.csv',
                      measure_type = config.MEASURE_TYPE)
    return dfs
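
# Default target.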
rule all:
    input:
        all_input
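
# Aggregate the per-profile pipeline outputs into a single measures dataframe.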
rule aggregate_pipeline_output:
    input:
        script = 'scripts/aggregate_pipeline_output.py',
        data = pipeline_output
    params:
        excludes = config.EXCLUDES
    output:
        dataframe = project_paths.dataframes \
                    / '{event_name}_{measure_type}{variant}_measures.csv'
    shell:
        """
        python {input.script} --data {input.data:q} \
                              --output {output.dataframe:q} \
                              --excludes {params.excludes:q}
        """
rule average_wavemode_measures:
    input:
        script = 'scripts/average_wavemode_measures.py',
        dataframe = '{dir}/wavefronts_{measure_type}{variant}_measures.csv'
    output:
        dataframe = '{dir}/wavemodes_{measure_type}{variant}_avg_measures.csv'
    shell:
        """
        python {input.script} --dataframe {input.dataframe:q} \
                              --output {output.dataframe:q}
        """
rule macrodim_trends:
    input:
        script = 'scripts/trend_wavefront_measures.py',
        dataframes = expand(str(project_paths.dataframes
                                / 'wavefronts_{{measure_type}}{variant}_measures.csv'),
                            variant = [v for v in config.VARIANT if 'minima' not in v])
    params:
        groupby = ['anesthetic', 'technique', 'macro_pixel_dim']
    output:
        dataframe = project_paths.dataframes \
                    / 'wavefronts_{measure_type}_trend_measures.csv'
    shell:
        """
        python {input.script} --data {input.dataframes:q} \
                              --groupby {params.groupby:q} \
                              --output {output.dataframe:q}
        """