03_storage_to_csv.py

  1. """Convert the SQLite feauture storage into .tsv files."""
import argparse
import sys
from pathlib import Path

from junifer.storage import SQLiteFeatureStorage
from ptpython.ipython import embed

sys.path.append("helper_scripts")
from utils import get_marker_names


def parse_args():
    """Parse arguments."""
    parser = argparse.ArgumentParser(
        description="Convert the SQLite feature storage into .tsv files."
    )
    parser.add_argument(
        "dataset",
        type=str,
        help="Which dataset to convert. One of {'PIOP1', 'PIOP2', 'ID1000'}.",
    )
    return parser.parse_args()


def validate_args(args):
    """Validate arguments."""
    datasets = ["ID1000", "PIOP1", "PIOP2"]
    assert args.dataset in datasets, (
        f"{args.dataset} is not a valid dataset! Valid datasets are "
        f"{datasets}."
    )
    return args


def main():
    """Convert the SQLite feature storage into .tsv files."""
    # Tasks/sessions available in each AOMIC dataset.
    datasets = {
        "ID1000": ["moviewatching"],
        "PIOP2": ["restingstate", "emomatching", "workingmemory"],
        "PIOP1": ["restingstate", "faces", "emomatching", "workingmemory"],
    }
    args = validate_args(parse_args())
    dataset = args.dataset
    markers = get_marker_names(args.dataset)
    # Path to the junifer storage for this dataset:
    # ../junifer_storage/<dataset>/<dataset>
    storage_path = Path("..") / "junifer_storage" / dataset / dataset
    storage = SQLiteFeatureStorage(storage_path, single_output=True)
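    # read_df presumably returns each marker in long format: one row per
    # (subject[, task], ROI pair), with the connectivity value in column "0"
    # and an "idx" index level that gets dropped below.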
    for marker in markers:
        print("loading dataframe...")
        connectomes = storage.read_df(feature_name=marker)
        print("...done")
        for session in datasets[dataset]:
            print(dataset, marker, session)
            print("reshaping dataframe!")
            outfile = (
                Path("..")
                / "junifer_storage"
                / "JUNIFER_AOMIC_TSV_CONNECTOMES"
                / f"{dataset}"
                / f"{dataset}_{marker}_{session}.tsv.gz"
            )
            session_connectomes = connectomes.reset_index().drop(columns="idx")
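            # ID1000 only has the moviewatching session, so there is presumably
            # no task column to filter on; for PIOP1/PIOP2, keep only the rows
            # belonging to this task.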
            # embed()
            if dataset != "ID1000":
                session_connectomes = session_connectomes.query(
                    f"task == '{session}'"
                ).drop(columns="task")
            subject = session_connectomes["subject"].unique()[0]
            # get one correct ordering of all final columns
            columns_in_order = session_connectomes.query(
                f"subject == '{subject}'"
            )["pair"]
            print(f"{len(columns_in_order)} columns!")
            # pivot and reindex
            # pivot changes order, so the reindex makes sure the df
            # is in the correct order
            # embed()
            session_connectomes = session_connectomes.pivot(
                index="subject", columns="pair", values="0"
            ).reindex(columns_in_order, axis=1)
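            # Pair labels appear to be "<roi_a>~<roi_b>"; keep only the
            # off-diagonal pairs (drop self-connections where both ROIs match).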
            columns = session_connectomes.columns
            non_diags = []
            for col in columns:
                a, b = col.split("~")
                if a != b:
                    non_diags.append(col)
            session_connectomes = session_connectomes[non_diags]
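            # Relabel subjects as zero-padded, BIDS-style IDs (e.g.
            # "sub-0001"); this assumes the stored subject IDs are integers.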
            new_index = [f"sub-{ind:04}" for ind in session_connectomes.index]
            session_connectomes.index = new_index
            session_connectomes.to_csv(outfile, sep="\t", compression="gzip")
            print("saved to tsv, continue!")


if __name__ == "__main__":
    main()