# create_df_from_portable_dataset.py
import argparse
import functools
import io
import multiprocessing
import os
import re
from zipfile import ZipFile

import numpy as np
import pandas as pd
import tqdm
  11. def create_single_table_meas(df_dict):
  12. table_names = ['patients', 'visits', 'tests', 'measurements']
  13. for i, table_name in enumerate(table_names):
  14. if i == 0:
  15. so_far = df_dict[table_name]
  16. else:
  17. new = df_dict[table_name]
  18. so_far = so_far.merge(right=new,
  19. right_on=['%s_uid' % (name[:-1]) for name in table_names[:i]],
  20. left_on=['%s_uid' % (name[:-1]) for name in table_names[:i]],
  21. how='inner')
  22. return so_far
  23. def load_ts(files, archive):
  24. with ZipFile(archive, 'r') as z:
  25. tmp_list = []
  26. for f in files:
  27. # Retrieve index from filename
  28. ix = int(re.findall('[0-9]+', os.path.basename(f))[0])
  29. tmp = np.loadtxt(io.BytesIO(z.read(f)))
  30. tmp_list.append({'ix': ix, 'timeseries': tmp})
  31. return tmp_list
  32. def single_dataframe(archive):
  33. '''Create a single denormalized dataframe containing the patients, visits, tests, and measurements
  34. '''
  35. tables = ['patient', 'visit', 'test', 'measurement', 'edss']
  36. with ZipFile(archive, 'r') as z:
  37. ts = []
  38. # Get all the timeseries
  39. # We'll do this in parallel as loadtxt is rather slow
  40. p = multiprocessing.Pool()
  41. ts_files = [f for f in z.namelist() if '.txt' in f]
  42. func = functools.partial(load_ts, archive=archive)
  43. n_cpu = multiprocessing.cpu_count()
  44. splits = [ts_files[splt[0]:splt[-1]+1] for splt in np.array_split(np.arange(len(ts_files)), min(n_cpu, len(ts_files)))]
  45. ts = p.map(func, splits)
  46. p.close()
  47. # Flatten list
  48. ts = sum(ts, [])
  49. ts_df = pd.DataFrame(ts)
  50. tables_dict = {}
  51. for tab in tables:
  52. tab_file = [k for k in z.namelist() if tab in k][0]
  53. # Explicitly define the date format, as automatic date inference is dodgy
  54. dateparse = lambda x: pd.datetime.strptime(x, '%Y-%m-%d')
  55. df = pd.read_csv(io.BytesIO(z.read(tab_file)))
  56. # Convert the datetime fields
  57. date_fields = [k for k in df.keys() if 'date' in k]
  58. for d in date_fields:
  59. df[d] = pd.to_datetime(df[d], format='%Y-%m-%d')
  60. tables_dict[tab + 's'] = df
  61. df = create_single_table_meas(tables_dict)
  62. cdf = tables_dict['edsss']
  63. assert len(ts_df) == len(df)
  64. tot_df = df.merge(right=ts_df, left_on='timeseries', right_on='ix', how='inner')
  65. tot_df.drop(['timeseries_x', 'ix'], axis=1, inplace=True)
  66. tot_df.rename({'timeseries_y': 'timeseries'}, axis=1, inplace=True)
  67. return tot_df, cdf
  68. if __name__ == "__main__":
  69. parser = argparse.ArgumentParser()
  70. parser.add_argument('archive')
  71. parser.add_argument('-outputfile', default='reconstructed_mep.p')
  72. args = vars(parser.parse_args())
  73. df = single_dataframe(args['archive'])
  74. df.to_pickle(args['outputfile'])