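'''Rebuild a single denormalized dataframe from a portable dataset zip
archive and pickle it to disk.

Example invocation (illustrative archive name):
    python create_df_from_portable_dataset.py dataset.zip -outputfile reconstructed_mep.p
'''
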
import pandas as pd
import argparse
import io
import os
from zipfile import ZipFile
import numpy as np
import tqdm
import multiprocessing
import functools
import re


def create_single_table_meas(df_dict):
    '''Merge the patients, visits, tests, and measurements tables into a
    single denormalized dataframe, joining each table in turn on the uid
    columns of all the tables merged so far.'''
    table_names = ['patients', 'visits', 'tests', 'measurements']
    for i, table_name in enumerate(table_names):
        if i == 0:
            so_far = df_dict[table_name]
        else:
            new = df_dict[table_name]
            # The join keys accumulate: ['patient_uid'] first, then
            # ['patient_uid', 'visit_uid'], and so on
            so_far = so_far.merge(right=new,
                                  right_on=['%s_uid' % (name[:-1]) for name in table_names[:i]],
                                  left_on=['%s_uid' % (name[:-1]) for name in table_names[:i]],
                                  how='inner')
    return so_far
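
# Illustrative equivalent of the chained merge above (column names are
# implied by the uid naming scheme, not guaranteed by this script):
#   patients.merge(visits, on=['patient_uid'])
#           .merge(tests, on=['patient_uid', 'visit_uid'])
#           .merge(measurements, on=['patient_uid', 'visit_uid', 'test_uid'])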


def load_ts(uid_files, archive):
    '''Load one worker's share of the timeseries .txt files from the archive.

    Each worker reopens the zip itself, because an open ZipFile handle cannot
    be pickled and shared across processes. Only worker 0 (uid == 0) shows a
    progress bar.'''
    uid, files = uid_files
    tmp_list = []
    with ZipFile(archive, 'r') as z:
        for f in tqdm.tqdm(files, disable=(uid != 0)):
            # Retrieve the row index from the digits in the filename
            ix = int(re.findall('[0-9]+', os.path.basename(f))[0])
            tmp = np.loadtxt(io.BytesIO(z.read(f)))
            tmp_list.append({'ix': ix, 'timeseries': tmp})
    return tmp_list
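
# Note on the work split below: np.array_split keeps chunks contiguous, e.g.
# np.array_split(np.arange(7), 3) -> [array([0, 1, 2]), array([3, 4]),
# array([5, 6])], so each pool worker receives a contiguous slice of ts_files.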


def single_dataframe(archive):
    '''Create a single denormalized dataframe containing the patients, visits,
    tests, and measurements, with each measurement's timeseries attached.

    Returns a tuple of (denormalized measurements dataframe, EDSS dataframe).'''
    tables = ['patient', 'visit', 'test', 'measurement', 'edss']
    with ZipFile(archive, 'r') as z:
        # Load all the timeseries in parallel, as np.loadtxt is rather slow
        print('Linking timeseries')
        ts_files = [f for f in z.namelist() if f.endswith('.txt')]
        func = functools.partial(load_ts, archive=archive)
        n_cpu = multiprocessing.cpu_count()
        chunks = np.array_split(np.arange(len(ts_files)), min(n_cpu, len(ts_files)))
        splits = [(uid, ts_files[splt[0]:splt[-1] + 1])
                  for uid, splt in enumerate(chunks)]
        with multiprocessing.Pool() as p:
            ts = p.map(func, splits)
        # Flatten the per-worker lists into one list of records
        ts = sum(ts, [])
        ts_df = pd.DataFrame(ts)
        tables_dict = {}
        for tab in tables:
            # Take the first archive member whose name contains the table name
            tab_file = [k for k in z.namelist() if tab in k][0]
            df = pd.read_csv(io.BytesIO(z.read(tab_file)))
            # Convert the datetime fields with an explicit format, as
            # automatic date inference is dodgy
            date_fields = [k for k in df.keys() if 'date' in k]
            for d in date_fields:
                df[d] = pd.to_datetime(df[d], format='%Y-%m-%d')
            tables_dict[tab + 's'] = df
    df = create_single_table_meas(tables_dict)
    cdf = tables_dict['edsss']
    assert len(ts_df) == len(df)
    # 'timeseries' in the measurements table holds the integer index of the
    # matching .txt file, so join it against the 'ix' column of ts_df
    tot_df = df.merge(right=ts_df, left_on='timeseries', right_on='ix', how='inner')
    tot_df.drop(['timeseries_x', 'ix'], axis=1, inplace=True)
    tot_df.rename({'timeseries_y': 'timeseries'}, axis=1, inplace=True)
    return tot_df, cdf
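
# Usage sketch for the return value (illustrative archive name; the
# 'timeseries' column holds the numpy arrays loaded above):
#   meas_df, edss_df = single_dataframe('dataset.zip')
#   first_ts = meas_df['timeseries'].iloc[0]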


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('archive')
    parser.add_argument('-outputfile', default='reconstructed_mep.p')
    args = vars(parser.parse_args())
    # single_dataframe returns (measurements, EDSS); only the measurements
    # frame is pickled here
    df, edss_df = single_dataframe(args['archive'])
    df.to_pickle(args['outputfile'])