preprocessing.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
import os

import numpy as np
import pandas as pd
  3. def swc_to_df(filepath):
  4. df_swc = pd.read_csv(filepath, delim_whitespace=True, comment='#',
  5. names=['n', 'type', 'x', 'y', 'z', 'radius', 'parent'], index_col=False)
  6. df_swc.index = df_swc.n.values
  7. return df_swc
  8. def read_swc(filepath):
  9. swc = swc_to_df(filepath)
  10. # raw data
  11. n = np.array(swc['n'].tolist())
  12. pos = np.array([swc['x'].tolist(), swc['y'].tolist(), swc['z'].tolist()]).T
  13. radius = np.array(swc['radius'].tolist())
  14. t = np.array(swc['type'].tolist())
  15. pid = np.array(swc['parent'].tolist())
  16. t[pid == -1] = 1 # if soma is missing, the first point is soma
  17. e = np.vstack(list(zip(pid[1:], n[1:])))
  18. e = remove_duplicate(pos, e)
  19. soma_loc = 1
  20. return {'n': n,
  21. 'pos': pos,
  22. 'radius': radius,
  23. 't': t,
  24. 'e': e,
  25. 'soma_loc': soma_loc
  26. }
  27. def read_imx(filepath):
  28. import numpy as np
  29. from bs4 import BeautifulSoup
  30. with open(filepath, 'rb') as f:
  31. xml_binary = f.read().decode('utf-8')
  32. soup = BeautifulSoup(xml_binary, 'lxml')
  33. vertex = soup.find('mfilamentgraphvertex').text
  34. vertex = vertex.replace(')', '')
  35. vertex = vertex.split('(')
  36. vertex = [i.strip() for i in vertex]
  37. vertex = [i for i in vertex if i != '']
  38. edges = soup.find('mfilamentgraphedge').text
  39. edges = edges.replace(')', '')
  40. edges = edges.split('(')
  41. edges = [i.strip() for i in edges]
  42. edges = [i for i in edges if i != '']
  43. vortex = np.array([list(map(float, sub.split(','))) for sub in vertex])
  44. # node position/coordinates(x,y,z)
  45. pos = vortex[:, :3] # n->pos
  46. e = np.array([list(map(int, sub.split(','))) for sub in edges]) + 1
  47. e = remove_duplicate(pos, e)
  48. # node index
  49. n = np.arange(1, len(pos) + 1) # n is the node index
  50. # radius
  51. radius = vortex[:, 3]
  52. # type
  53. t = np.ones(len(radius)) * 3
  54. soma_loc = int(soup.find('mfilamentgraph').attrs['mrootvertex']) + 1
  55. t[soma_loc - 1] = 1
  56. return {'n': n,
  57. 'pos': pos,
  58. 'radius': radius,
  59. 't': t,
  60. 'e': e,
  61. 'soma_loc': soma_loc
  62. }
  63. def remove_duplicate(pos, e):
  64. pos_uniq, pos_count = np.unique(pos, return_counts=True, axis=0)
  65. for duplicate_point in pos_uniq[pos_count > 1]:
  66. a, *rest = np.where((duplicate_point == pos).all(1))[0] + 1
  67. for dupl in rest:
  68. for i in [i for i in np.where(e[:, 0] == dupl)[0]]:
  69. e[i][0] = a
  70. for i in [i for i in np.where(e[:, 1] == dupl)[0]]:
  71. e[i][1] = a
  72. e[(e[:, 1] - e[:, 0]) < 0] = e[(e[:, 1] - e[:, 0]) < 0][:, ::-1]
  73. e = np.delete(e, np.where(e[:, 0] == e[:, 1])[0], axis=0)
  74. e, index = np.unique(e, axis=0, return_index=True)
  75. return e[index.argsort()]
  76. def get_edge_dict(n, e, soma_loc):
  77. def _point_already_in_dict(point, edge_dict):
  78. for path_id, points_list in edge_dict.items():
  79. if point in points_list:
  80. return True
  81. else:
  82. return False
  83. edge_dict = {}
  84. path_id = 1
  85. branch_loc = [i for i in n if len(np.where(e[:, 0] == i)[0]) >= 2]
  86. if 1 not in branch_loc:
  87. branch_loc = [1] + branch_loc
  88. for bpt in branch_loc:
  89. bpt_locs_on_e = np.where(e[:, 0] == bpt)[0]
  90. for edge in e[bpt_locs_on_e]:
  91. current_point = edge[0]
  92. next_point = edge[1]
  93. a_list = [current_point]
  94. if _point_already_in_dict(next_point, edge_dict):
  95. pass
  96. else:
  97. a_list.append(next_point) # a list for holding the index of point of one paths
  98. next_point_locs_on_e = np.where(e[:, 0] == next_point)[0]
  99. while len(next_point_locs_on_e) == 1:
  100. next_point = e[next_point_locs_on_e[0]][1]
  101. if _point_already_in_dict(next_point, edge_dict):
  102. pass
  103. else:
  104. a_list.append(next_point) # a list for holding the index of point of one paths
  105. next_point_locs_on_e = np.where(e[:, 0] == next_point)[0]
  106. if len(a_list) < 2:
  107. pass
  108. continue
  109. edge_dict[path_id] = np.array(a_list)
  110. path_id += 1
  111. if soma_loc not in branch_loc:
  112. paths_soma_on = [key for key, value in edge_dict.items() if soma_loc in value]
  113. for path_id in paths_soma_on:
  114. path = edge_dict[path_id]
  115. breakup_point = np.where(edge_dict[path_id] == soma_loc)[0]
  116. path_0 = path[:breakup_point[0] + 1][::-1]
  117. path_1 = path[breakup_point[0]:]
  118. edge_dict[path_id] = path_0
  119. edge_dict[len(edge_dict)] = path_1
  120. return edge_dict
  121. def get_path_dict(pos, radius, t, edge_dict, soma_loc):
  122. path_dict = {}
  123. radius_dict = {}
  124. type_dict = {}
  125. all_keys = edge_dict.keys()
  126. for key in all_keys:
  127. path_dict[key] = pos[edge_dict[key] - 1]
  128. radius_dict[key] = radius[edge_dict[key] - 1]
  129. type_dict[key] = np.unique(t[edge_dict[key] - 1][1:])[0]
  130. path_dict.update({0: pos[soma_loc - 1].reshape(1, 3)})
  131. radius_dict.update({0: [radius[soma_loc - 1]]})
  132. type_dict.update({0: 1})
  133. df_paths = pd.DataFrame()
  134. df_paths['type'] = pd.Series(type_dict)
  135. df_paths['path'] = pd.Series(path_dict)
  136. df_paths['radius'] = pd.Series(radius_dict)
  137. return df_paths
  138. def sort_path_direction(df_paths):
  139. df_paths = df_paths.copy()
  140. soma = df_paths.loc[0].path.flatten()
  141. df_paths['connect_to'] = np.nan
  142. df_paths['connect_to_at'] = ''
  143. df_paths['connect_to_at'] = df_paths['connect_to_at'].apply(np.array)
  144. path_ids_head = df_paths[df_paths.path.apply(lambda x: (x[0] == soma).all())].index
  145. if len(path_ids_head) > 0:
  146. df_paths.loc[path_ids_head, 'connect_to'] = -1
  147. df_paths.loc[path_ids_head, 'connect_to_at'] = pd.Series({path_id: soma for path_id in path_ids_head})
  148. path_ids_tail = df_paths[df_paths.path.apply(lambda x: (x[-1] == soma).all())].index
  149. if len(path_ids_tail) > 0:
  150. df_paths.loc[path_ids_tail, 'path'] = df_paths.loc[path_ids_tail].path.apply(lambda x: x[::-1])
  151. df_paths.loc[path_ids_tail, 'connect_to'] = -1
  152. df_paths.loc[path_ids_tail, 'connect_to_at'] = pd.Series({path_id: soma for path_id in path_ids_tail})
  153. new_target_paths = list(df_paths[~np.isnan(df_paths.connect_to)].index) # seed the first round of paths to check
  154. while np.count_nonzero(~np.isnan(df_paths.connect_to)) != len(df_paths):
  155. all_checked_paths = list(df_paths[~np.isnan(df_paths.connect_to)].index)
  156. num_check_paths_before = len(all_checked_paths)
  157. target_paths = new_target_paths
  158. new_target_paths = [] # empty the list to hold new target paths for next round
  159. for target_path_id in target_paths:
  160. if target_path_id == 0: continue
  161. target_path = df_paths.loc[target_path_id].path
  162. path_ids_head = df_paths[df_paths.path.apply(lambda x: (x[0] == target_path[-1]).all())].index.tolist()
  163. path_ids_head = [i for i in path_ids_head if i not in all_checked_paths]
  164. if len(path_ids_head) > 0:
  165. df_paths.loc[path_ids_head, 'connect_to'] = target_path_id
  166. df_paths.loc[path_ids_head, 'connect_to_at'] = pd.Series(
  167. {path_id: target_path[-1] for path_id in path_ids_head})
  168. new_target_paths = new_target_paths + path_ids_head
  169. path_ids_tail = df_paths[df_paths.path.apply(lambda x: (x[-1] == target_path[-1]).all())].index.tolist()
  170. path_ids_tail = [i for i in path_ids_tail if i not in all_checked_paths]
  171. if len(path_ids_tail) > 0:
  172. df_paths.loc[path_ids_tail, 'path'] = df_paths.loc[path_ids_tail].path.apply(lambda x: x[::-1])
  173. df_paths.loc[path_ids_tail, 'connect_to'] = target_path_id
  174. df_paths.loc[path_ids_tail, 'connect_to_at'] = pd.Series(
  175. {path_id: target_path[-1] for path_id in path_ids_tail})
  176. new_target_paths = new_target_paths + path_ids_tail
  177. num_check_paths_after = len(list(df_paths[~np.isnan(df_paths.connect_to)].index))
  178. if num_check_paths_before == num_check_paths_after:
  179. num_disconneted = len(df_paths) - num_check_paths_after
  180. break
  181. df_paths_drop = df_paths[np.isnan(df_paths.connect_to)]
  182. df_paths = df_paths.drop(df_paths[np.isnan(df_paths.connect_to)].index)
  183. return df_paths, df_paths_drop
  184. def get_paths_nearest_to_tree(df_paths, df_drops, num_all_paths):
  185. """
  186. Get paths in df_drops which stay nearest to the connected paths.
  187. Paremeters
  188. ----------
  189. df_paths:
  190. DataFrame holding all connected paths
  191. df_drops:
  192. DataFrame holding all disconnected paths
  193. Return
  194. ------
  195. res: dict
  196. {'p0': {'path': path,
  197. 'path_id': path_id_drop,
  198. 'radius': radius,
  199. 'type': t},
  200. 'target': {'path': path_tree,
  201. 'path_id': path_id_tree,
  202. 'radius': radius_tree,
  203. 'type': type_tree
  204. } }
  205. """
  206. res = []
  207. for row in df_drops.iterrows():
  208. path_id_drop = row[0]
  209. path = row[1].path
  210. path_id_tree = df_paths.path.apply(lambda x: np.sqrt(((x[-1] - path) ** 2).sum(1).min())).argmin()
  211. distance_arr = np.sqrt(np.sum((path - df_paths.loc[path_id_tree].path[-1]) ** 2, 1))
  212. loc_nearest = distance_arr.argmin()
  213. dist_nearest = distance_arr.min()
  214. res.append((path_id_drop, path_id_tree, loc_nearest, dist_nearest))
  215. res = np.array(res)
  216. nearest_paths_data = res[np.where(res[:, 3] == res[:, 3].min())[0]]
  217. ### get all paths data
  218. num_paths_to_tree = len(nearest_paths_data)
  219. path_id_tree = nearest_paths_data[0][1]
  220. path_tree = df_paths.loc[path_id_tree].path
  221. radius_tree = df_paths.loc[path_id_tree].radius
  222. type_tree = df_paths.loc[path_id_tree].type
  223. loc_nearest = int(nearest_paths_data[0][2])
  224. distance_between = nearest_paths_data[0][3]
  225. res = {}
  226. for i, datum in enumerate(nearest_paths_data):
  227. path_id_drop = datum[0]
  228. path = df_drops.loc[path_id_drop].path
  229. radius = df_drops.loc[path_id_drop].radius
  230. t = df_drops.loc[path_id_drop].type
  231. if loc_nearest == 0:
  232. path = path
  233. point_connect = path[0]
  234. radius_connect = radius[0]
  235. res['p{}'.format(i)] = {'path': path,
  236. 'path_id': path_id_drop,
  237. 'radius': radius,
  238. 'type': t}
  239. elif loc_nearest == len(path) - 1:
  240. path = path[::-1]
  241. point_connect = path[0]
  242. radius_connect = radius[0]
  243. res['p{}'.format(i)] = {'path': path,
  244. 'path_id': path_id_drop,
  245. 'radius': radius,
  246. 'type': t}
  247. else:
  248. p0 = path[:loc_nearest + 1][::-1]
  249. p1 = path[loc_nearest:]
  250. r0 = radius[:loc_nearest + 1]
  251. r1 = radius[loc_nearest:]
  252. point_connect = p0[0]
  253. radius_connect = r0[0]
  254. if len(p0) > 1:
  255. res['p0'.format(i)] = {'path': p0,
  256. 'path_id': path_id_drop,
  257. 'radius': r0,
  258. 'type': t}
  259. if len(p1) > 1:
  260. res['p1'.format(i)] = {'path': p1,
  261. # 'path_id': len(df_paths) + len(df_drops),
  262. 'path_id': num_all_paths,
  263. 'radius': r1,
  264. 'type': t}
  265. if distance_between > 0:
  266. path_tree = np.vstack([path_tree, point_connect])
  267. radius_tree = np.hstack([radius_tree, radius_connect])
  268. res['target'] = {'path': path_tree,
  269. 'path_id': path_id_tree,
  270. 'radius': radius_tree,
  271. 'type': type_tree
  272. }
  273. return res
  274. def reconnect_dropped_paths(df_paths, df_drops):
  275. df_paths = df_paths.copy()
  276. df_drops = df_drops.copy()
  277. num_dropped_paths = len(df_drops)
  278. while len(df_drops) > 0:
  279. num_all_paths = num_dropped_paths + len(df_paths)
  280. paths_data = get_paths_nearest_to_tree(df_paths, df_drops, num_all_paths)
  281. target = paths_data.pop('target')
  282. path_id_tree = target['path_id']
  283. path_tree = target['path']
  284. radius_tree = target['radius']
  285. tail_path_tree = path_tree[-1]
  286. df_paths.at[int(path_id_tree), 'path'] = path_tree
  287. df_paths.at[int(path_id_tree), 'radius'] = radius_tree
  288. if len(paths_data) == 1:
  289. p = paths_data.pop('p0')
  290. path_id_drop = p['path_id']
  291. path_drop = p['path']
  292. radius_drop = p['radius']
  293. t = p['type']
  294. df_paths.at[int(path_id_tree), 'path'] = np.vstack([path_tree[:-1], path_drop])
  295. df_paths.at[int(path_id_tree), 'radius'] = np.hstack([radius_tree[:-1], radius_drop])
  296. df_drops.drop(path_id_drop, inplace=True)
  297. else:
  298. for key, values in paths_data.items():
  299. p = paths_data[key]
  300. path_id_drop = p['path_id']
  301. path_drop = p['path']
  302. radius_drop = p['radius']
  303. t = p['type']
  304. df_paths.loc[int(path_id_drop)] = [t, path_drop, radius_drop, path_id_tree, tail_path_tree]
  305. try:
  306. df_drops.drop(path_id_drop, inplace=True)
  307. except:
  308. pass
  309. return df_paths.sort_index()
  310. def find_connection(df_paths):
  311. # find all paths connect to current path.
  312. connected_by_dict = {}
  313. connected_by_at_dict = {}
  314. for path_id in df_paths.index:
  315. connected_by_dict[path_id] = df_paths[df_paths.connect_to == path_id].index.tolist()
  316. connected_by_at_dict[path_id] = df_paths[df_paths.connect_to == path_id].connect_to_at.tolist()
  317. df_paths['connected_by'] = pd.Series(connected_by_dict)
  318. df_paths['connected_by_at'] = pd.Series(connected_by_at_dict)
  319. back_to_soma_dict = {}
  320. for path_id in df_paths.index:
  321. list_to_soma = [path_id]
  322. next_path_id = df_paths.loc[path_id].connect_to
  323. while next_path_id != -1:
  324. list_to_soma.append(next_path_id)
  325. next_path_id = df_paths.loc[next_path_id].connect_to
  326. back_to_soma_dict[path_id] = list_to_soma
  327. df_paths['back_to_soma'] = pd.Series(back_to_soma_dict)
  328. return df_paths
  329. def write_swc(df_paths):
  330. path_checked = []
  331. swc_arr = []
  332. list_back_to_soma = (df_paths.sort_values(['connect_to']).back_to_soma).tolist()
  333. for i, back_to_soma in enumerate(list_back_to_soma):
  334. for path_id in back_to_soma[::-1]:
  335. if path_id in path_checked:
  336. continue
  337. path_s = df_paths.loc[path_id]
  338. path = path_s['path']
  339. path_radius = path_s['radius']
  340. if len(path) > 1:
  341. path = path[1:]
  342. path_radius = path_s['radius'][1:]
  343. path_type = path_s['type']
  344. #
  345. connect_to = path_s['connect_to']
  346. connect_to_at = path_s['connect_to_at']
  347. swc_path = np.column_stack([np.ones(len(path)) * path_type, path]) # type
  348. swc_path = np.column_stack([np.arange(len(swc_arr) + 1, len(path) + len(swc_arr) + 1), swc_path]) # ID
  349. swc_path = np.column_stack([swc_path, path_radius * np.ones(len(path))]) # radius
  350. swc_path = np.column_stack([swc_path, swc_path[:, 0] - 1]) # placeholder for PID
  351. if len(swc_arr) == 0:
  352. swc_arr = swc_path
  353. swc_arr[0][-1] = -1
  354. else:
  355. pid = np.where((swc_arr[:, 2:5] == connect_to_at).all(1))[0] + 1
  356. # logging.info(pid)
  357. if len(pid) > 1:
  358. swc_path[0][-1] = pid[0]
  359. # logging.info(swc_arr[pid[0]][2:5], swc_arr[pid[1]][2:5])
  360. elif len(pid) == 1:
  361. swc_path[0][-1] = pid
  362. swc_arr = np.vstack([swc_arr, swc_path])
  363. path_checked.append(path_id)
  364. df_swc = pd.DataFrame(swc_arr)
  365. df_swc.index = np.arange(1, len(df_swc) + 1)
  366. df_swc.columns = ['n', 'type', 'x', 'y', 'z', 'radius', 'parent']
  367. df_swc[['n', 'type', 'parent']] = df_swc[['n', 'type', 'parent']].astype(int)
  368. return df_swc
  369. def data_preprocessing(filepath):
  370. filetype = filepath.split('/')[-1].split('.')[-1].lower()
  371. filename = filepath.split('/')[-1].split('.')[0].lower()
  372. if filetype == 'swc':
  373. data = read_swc(filepath)
  374. elif filetype == 'imx':
  375. data = read_imx(filepath)
  376. else:
  377. raise NotImplementedError
  378. e = data['e']
  379. n = data['n']
  380. t = data['t']
  381. pos = data['pos']
  382. radius = data['radius']
  383. soma_loc = data['soma_loc']
  384. edge_dict = get_edge_dict(n, e, soma_loc)
  385. df_paths = get_path_dict(pos, radius, t, edge_dict, soma_loc)
  386. df_paths, df_paths_drop = sort_path_direction(df_paths)
  387. df_paths = reconnect_dropped_paths(df_paths, df_paths_drop)
  388. df_paths = find_connection(df_paths)
  389. df_swc = write_swc(df_paths)
  390. if (df_paths.iloc[1:].type == 0).all():
  391. type_list = np.ones(len(df_paths)).astype(int) * 3
  392. type_list[0] = 1
  393. df_paths = df_paths.assign(type=type_list)
  394. return df_swc, df_paths