# Validation_FOV_alignment.py
# Notebook export: validates field-of-view (FOV) alignment across sleep
# calcium-imaging recordings (similarity index, displacement, distance metrics).
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # ### Define paths and animals for the analysis
  4. # In[1]:
  5. path = '/media/andrey/My Passport/GIN/Anesthesia_CA1/meta_data/meta_recordings_sleep.xlsx'
  6. path4results = '/media/andrey/My Passport/GIN/Anesthesia_CA1/validation/sleep_data_calcium_imaging/' #To store transformation matrix
  7. save_plots_path = '/media/andrey/My Passport/GIN/Anesthesia_CA1/validation/sleep_data_calcium_imaging/'
  8. log_file_path = save_plots_path + 'registration_logs.txt'
  9. animals_for_analysis = [8235,8237,8238]
  10. # ### Align FOV's for all recordings
  11. # In[2]:
  12. repeat_calc = 1
  13. silent_mode = False
  14. #######################
  15. import pandas as pd
  16. import numpy as np
  17. import os
  18. import matplotlib.pyplot as plt
  19. import sys
  20. np.set_printoptions(threshold=sys.maxsize)
  21. from pystackreg import StackReg
  22. # Sobel filter (not used)
  23. #from scipy import ndimage #https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.sobel.html
  24. meta_data = pd.read_excel(path)
  25. #%% compute transformations matrices between recordings
  26. recordings = meta_data['Number']
  27. animals = animals_for_analysis
  28. #print("Recordings: ", recordings)
  29. #ALIGNMENT ALGORITHM
  30. # log file
  31. f = open(log_file_path, "a")
  32. print("n, i, j, rigid_mean_enh, rigid_mean, affine_mean_enh, affine_mean, best_method ", file=f)
  33. if (silent_mode!=True):
  34. print(" RMSD's: (rigid, mean_enh) | (rigid, mean) | (affine, mean_enh) | (affine, mean) | best method ")
# ---------------------------------------------------------------------------
# DISABLED CODE (module-level string literal — never executed).
# For every animal it registers each pair of recordings with pystackreg using
# four strategies (rigid/affine x meanImg/meanImgE), scores each by RMSD on a
# border-cropped difference image, picks the best method (falling back when a
# transform pushes >10% of pixels out of frame), and saves the best transform/
# score/method matrices under <path4results>/StackReg/.
# NOTE(review): indentation below was reconstructed from a flattened notebook
# export — verify against the original notebook before re-enabling. It also
# expects an open log handle `f` (see the setup cell above).
# ---------------------------------------------------------------------------
'''
for animal in animals:
    if (silent_mode!=True):
        print("Animal #", str(animal))
    # Skip animals with cached results unless repeat_calc forces a re-run.
    if not os.path.exists(path4results + 'StackReg/' +
                          str(animal) + '.npy') or repeat_calc == 1:
        # if not os.path.exists('Q:/Personal/Mattia/Calcium Imaging/results/StackRegEnhImage/' +
        #                       str(animal) + '.npy') or repeat_calc == 1:
        # if not os.path.exists('Q:/Personal/Mattia/Calcium Imaging/results/StackReg/' +
        #                       str(animal) + '.npy') or repeat_calc == 1:
        meta_animal = meta_data[meta_data['Mouse'] == animal]
        recordings = meta_animal['Number']
        # Per-recording 512x512 image stacks and pairwise 3x3 transform tables.
        images_mean = np.zeros((512, 512, np.shape(recordings)[0]))
        images_mean_enh = np.zeros((512, 512, np.shape(recordings)[0]))
        images_quality_check = np.zeros((512, 512, np.shape(recordings)[0]))
        best_tmats = np.zeros((np.shape(recordings)[0], np.shape(recordings)[0], 3, 3))
        best_score = np.zeros((np.shape(recordings)[0], np.shape(recordings)[0]))
        best_methods = np.zeros((np.shape(recordings)[0], np.shape(recordings)[0]))
        tmats_affine = np.zeros((np.shape(recordings)[0], np.shape(recordings)[0], 3, 3))
        tmats_rigid = np.zeros((np.shape(recordings)[0], np.shape(recordings)[0], 3, 3))
        tmats_affine_enh = np.zeros((np.shape(recordings)[0], np.shape(recordings)[0], 3, 3))
        tmats_rigid_enh = np.zeros((np.shape(recordings)[0], np.shape(recordings)[0], 3, 3))
        # load all (enhanced) images
        for idx, recording in enumerate(recordings):
            print(recording)
            options = np.load(str(meta_data['Folder'][recording]) +
                              str(meta_data['Subfolder'][recording]) +
                              # str(int(meta_data['Recording idx'][recording])) +
                              '/suite2p/plane0/ops.npy',
                              allow_pickle=True)
            # mean image or mean enhanced image
            images_mean[:, :, idx] = options.item(0)['meanImg']
            images_mean_enh[:, :, idx] = options.item(0)['meanImgE']
            #cut_boarders=50
            #quality check
            images_quality_check[:, :, idx] = options.item(0)['meanImg']
        # loop through every pair and compute the transformation matrix
        #conditions = [meta_data['Condition'][recording] for recording in recordings]
        for idx0 in range(np.shape(images_mean)[2]):
            #if (idx0!=14):
            #    continue
            for idx1 in range(idx0, np.shape(images_mean)[2]):
                #if (idx1!=16):
                #    continue
                # Slots: 0=rigid_mean_enh, 1=rigid_mean, 2=affine_mean_enh, 3=affine_mean.
                fraction_of_non_zero_pixels = [0.0,0.0,0.0,0.0]
                ### MEAN RIGID and AFFINE
                reference_image = images_mean[:, :, idx0]
                initial_image = images_mean[:, :, idx1]
                #sx = ndimage.sobel(reference_image, axis=0, mode='constant')
                #sy = ndimage.sobel(reference_image, axis=1, mode='constant')
                #reference_image = np.hypot(sx, sy)
                #sx = ndimage.sobel(initial_image, axis=0, mode='constant')
                #sy = ndimage.sobel(initial_image, axis=1, mode='constant')
                #initial_image = np.hypot(sx, sy)
                boarder_cut = 100
                sr = StackReg(StackReg.AFFINE)
                tmats_affine[idx0, idx1, :, :] = sr.register(reference_image, initial_image)
                image_transformed = sr.transform(images_quality_check[:, :, idx1], tmats_affine[idx0, idx1, :, :])
                image_difference = images_quality_check[:, :, idx0] - image_transformed
                # Fraction of near-zero pixels after warping (512*512 = 262144):
                # high values mean the transform moved most of the image out of frame.
                fraction_of_non_zero_pixels[3] = np.count_nonzero(image_transformed[:,:]<0.001)/262144
                #plt.imshow(image_transformed)
                #plt.show()
                image_difference = image_difference[boarder_cut:-boarder_cut, boarder_cut:-boarder_cut]
                image_difference = np.square(image_difference)
                rmsd_affine = np.sqrt(image_difference.sum()/(512 - 2 * boarder_cut)**2)
                if (silent_mode!=True):
                    print("Fraction of non-zero pixels in 3 (mean affine): ", fraction_of_non_zero_pixels[3]," Score:",rmsd_affine)
                sr = StackReg(StackReg.RIGID_BODY)
                tmats_rigid[idx0, idx1, :, :] = sr.register(reference_image, initial_image)
                image_transformed = sr.transform(images_quality_check[:, :, idx1], tmats_rigid[idx0, idx1, :, :])
                image_difference = images_quality_check[:, :, idx0] - image_transformed
                fraction_of_non_zero_pixels[1] = np.count_nonzero(image_transformed[:,:]<0.001)/262144
                #plt.imshow(image_transformed)
                #plt.show()
                image_difference = image_difference[boarder_cut:-boarder_cut, boarder_cut:-boarder_cut]
                image_difference = np.square(image_difference)
                rmsd_rigid = np.sqrt(image_difference.sum()/(512 - 2 * boarder_cut)**2)
                if (silent_mode!=True):
                    print("Fraction of non-zero pixels in 1 (mean rigid): ", fraction_of_non_zero_pixels[1], "Score", rmsd_rigid)
                #plt.imshow(image_difference)
                ### MEAN_ENH RIGID and AFFINE
                reference_image = images_mean_enh[:, :, idx0]
                initial_image = images_mean_enh[:, :, idx1]
                # sx = ndimage.sobel(reference_image, axis=0, mode='constant')
                # sy = ndimage.sobel(reference_image, axis=1, mode='constant')
                # reference_image = np.hypot(sx, sy)
                # sx = ndimage.sobel(initial_image, axis=0, mode='constant')
                # sy = ndimage.sobel(initial_image, axis=1, mode='constant')
                # initial_image = np.hypot(sx, sy)
                boarder_cut = 100
                sr = StackReg(StackReg.AFFINE)
                tmats_affine_enh[idx0, idx1, :, :] = sr.register(reference_image, initial_image)
                image_transformed = sr.transform(images_quality_check[:, :, idx1], tmats_affine_enh[idx0, idx1, :, :])
                image_difference = images_quality_check[:, :, idx0] - image_transformed #TODO: delete image quality check! replace it with meanimage
                fraction_of_non_zero_pixels[2] = np.count_nonzero(image_transformed[:,:]<0.001)/262144
                #plt.imshow(image_transformed)
                #plt.show()
                image_difference = image_difference[boarder_cut:-boarder_cut, boarder_cut:-boarder_cut]
                image_difference = np.square(image_difference)
                rmsd_affine_enh = np.sqrt(image_difference.sum()/(512 - 2 * boarder_cut)**2)
                if (silent_mode!=True):
                    print("Fraction of non-zero pixels in 2 (mean enh affine): ", fraction_of_non_zero_pixels[2],"Score:", rmsd_affine_enh)
                sr = StackReg(StackReg.RIGID_BODY)
                tmats_rigid_enh[idx0, idx1, :, :] = sr.register(reference_image, initial_image)
                image_transformed = sr.transform(images_quality_check[:, :, idx1], tmats_rigid_enh[idx0, idx1, :, :])
                image_difference = images_quality_check[:, :, idx0] - image_transformed
                fraction_of_non_zero_pixels[0] = np.count_nonzero(image_transformed[:,:]<0.001)/262144
                #plt.imshow(image_transformed)
                #plt.show()
                image_difference = image_difference[boarder_cut:-boarder_cut, boarder_cut:-boarder_cut]
                image_difference = np.square(image_difference)
                rmsd_rigid_enh = np.sqrt(image_difference.sum()/(512 - 2 * boarder_cut)**2)
                if (silent_mode!=True):
                    print("Fraction of non-zero pixels in 0 (mean enh rigid): ", fraction_of_non_zero_pixels[0],"Score", rmsd_rigid_enh)
                # Candidate methods, ordered to match the fraction_of_non_zero_pixels slots.
                rmsds=[rmsd_rigid_enh,rmsd_rigid,rmsd_affine_enh,rmsd_affine]
                tmatss=[tmats_rigid_enh[idx0, idx1, :, :],tmats_rigid[idx0, idx1, :, :],tmats_affine_enh[idx0, idx1, :, :],tmats_affine[idx0, idx1, :, :]]
                methods=["rigid_mean_enh", "rigid_mean" ,"affine_mean_enh","affine_mean"]
                #print(tmats_rigid_enh,tmats_rigid,tmats_affine_enh,tmats_affine)
                #print(" ")
                #best_method_idx = rmsds.index(min(rmsds))
                #smaller_fraction_idx = fraction_of_non_zero_pixels.index(min(fraction_of_non_zero_pixels))
                #smaller_fraction_idx = 1
                #print(best_method_idx)
                #print(smaller_fraction_idx)
                # Rank methods by RMSD; fall back to the next method when the winner
                # left more than 10% of pixels out of frame.
                list_of_methods=np.argsort(rmsds)
                best_score[idx1, idx0] = np.sort(rmsds)[0]
                best_score[idx0, idx1] = np.sort(rmsds)[0]
                the_best_idx = list_of_methods[0]
                if (fraction_of_non_zero_pixels[list_of_methods[0]] > 0.1):
                    print("Warning: alignment with the best method failed. The second best method is applied")
                    the_best_idx = list_of_methods[1]
                    if (fraction_of_non_zero_pixels[list_of_methods[1]] > 0.1):
                        print("Warning: alignment with the second best method failed. The 3rd best method is applied")
                        the_best_idx = list_of_methods[2]
                best_method = methods[the_best_idx]
                best_tmats[idx0, idx1, :, :]=tmatss[the_best_idx]
                best_methods[idx1, idx0]=the_best_idx
                best_methods[idx0, idx1]=the_best_idx
                # The lower triangle stores the inverse transform (idx1 -> idx0).
                best_tmats[idx1, idx0, :, :]=np.linalg.inv(best_tmats[idx0, idx1, :, :])
                if(idx0==idx1):
                    best_method="-,-"
                if (silent_mode!=True):
                    print("{0:d}, {1:d}, {2:4.4f}, {3:4.4f}, {4:4.4f}, {5:4.4f}, {6:s}".format(idx0, idx1, rmsd_rigid_enh, rmsd_rigid, rmsd_affine_enh, rmsd_affine, best_method))
                print("{0:d}, {1:d}, {2:4.4f}, {3:4.4f}, {4:4.4f}, {5:4.4f}, {6:s}".format(idx0, idx1, rmsd_rigid_enh, rmsd_rigid,
                                                                                           rmsd_affine_enh, rmsd_affine, best_method), file=f)
                #print(" " + str(idx0) + "-" + str(idx1) + " " + str(rmsd_rigid_enh) + " " + str(rmsd_rigid) + " " + str(rmsd_affine_enh) + " " + str(rmsd_affine))
                # plt.imshow(image_difference)
                #plt.savefig(save_plots_path + "StackRegVisualInspection/" + file_title + "_d_reference_m_corrected.png")
                #print(str(idx0) + '-' + str(idx1))
        # save all the transformation matrices
        if not os.path.exists(path4results+'StackReg'):
            os.makedirs(path4results+'StackReg')
        #print(best_tmats)
        np.save(path4results+'StackReg/' + str(animal) + "_best_tmats", best_tmats)
        np.save(path4results+'StackReg/' + str(animal) + "_best_methods", best_methods)
        np.save(path4results+'StackReg/' + str(animal) + "_best_score", best_score)
        # if not os.path.exists('Q:/Personal/Mattia/Calcium Imaging/results/StackRegEnhImage'):
        #     os.makedirs('Q:/Personal/Mattia/Calcium Imaging/results/StackRegEnhImage')
        # np.save('Q:/Personal/Mattia/Calcium Imaging/results/StackRegEnhImage/' + str(animal), tmats)
        # if not os.path.exists(save_plots_path+ 'StackRegAffine'):
        #     os.makedirs(save_plots_path + 'StackRegAffine')
        # np.save(save_plots_path+ 'StackRegAffine/' + str(animal), tmats)
f.close()
'''
  199. # ### Install package for image comparison (similarity index)
  200. # In[79]:
  201. #get_ipython().system('conda install -c conda-forge imagehash --yes')
  202. #!pip install ImageHash # as alternative if anaconda is not installed
  203. # In[33]:
  204. from PIL import Image
  205. import imagehash
  206. for animal in animals:
  207. print("Animal #", str(animal))
  208. meta_animal = meta_data[meta_data['Mouse'] == animal]
  209. recordings = meta_animal['Number']
  210. index_similarity = np.zeros((np.shape(recordings)[0], np.shape(recordings)[0]))
  211. cut = 100
  212. images_mean = np.zeros((512, 512, np.shape(recordings)[0]))
  213. images_mean_enh = np.zeros((512, 512, np.shape(recordings)[0]))
  214. # load all (enhanced) images
  215. for idx, recording in enumerate(recordings):
  216. options = np.load(str(meta_data['Folder'][recording]) +
  217. str(meta_data['Subfolder'][recording]) +
  218. #str(int(meta_data['Recording idx'][recording])) +
  219. '/suite2p/plane0/ops.npy',
  220. allow_pickle=True)
  221. # mean image or mean enhanced image
  222. images_mean[:, :, idx] = options.item(0)['meanImg']
  223. images_mean_enh[:, :, idx] = options.item(0)['meanImgE']
  224. for idx0 in range(np.shape(images_mean)[2]):
  225. for idx1 in range(idx0, np.shape(images_mean)[2]):
  226. hash1=imagehash.average_hash(Image.fromarray(images_mean[cut:-cut, cut:-cut, idx0]))
  227. otherhash=imagehash.average_hash(Image.fromarray(images_mean[cut:-cut, cut:-cut, idx1]))
  228. index_similarity[idx0,idx1] = (hash1 - otherhash)
  229. index_similarity[idx1,idx0] = (hash1 - otherhash)
  230. #index_similarity = (np.max(index_similarity)-index_similarity)/np.max(index_similarity)
  231. best_tmats = np.load(path4results+'StackReg/' + str(animal) + "_best_tmats.npy")
  232. best_score = np.load(path4results+'StackReg/' + str(animal) + "_best_score.npy")
  233. metric_best_tmats = np.abs(best_tmats)
  234. metric_best_tmats = metric_best_tmats.sum(axis=(2,3))
  235. metric_best_tmats = np.max(metric_best_tmats) - metric_best_tmats
  236. metric_best_score = (np.max(best_score)-best_score)/np.max(best_score)
  237. #plt.xticks(np.arange(0, np.shape(images_mean)[2], 1));
  238. #plt.yticks(np.arange(0, np.shape(images_mean)[2], 1));
  239. fig, ax = plt.subplots(1,3, figsize=(21, 7))
  240. index_similarity[index_similarity > 30] = 30
  241. image = ax[0].imshow(index_similarity,cmap='viridis_r') #,cmap=cmap, norm=norm
  242. cbar = fig.colorbar(image, ax=ax[0], orientation='horizontal', fraction=.1)
  243. ax[0].set_title("Similarity index", fontsize = 15)
  244. ax[1].imshow(metric_best_tmats)
  245. ax[1].set_title("Displacement index", fontsize = 15)
  246. ax[2].imshow(metric_best_score)
  247. ax[2].set_title("Distance function", fontsize = 15)
  248. fig.suptitle('Animal #' + str(animal), fontsize = 20)
  249. plt.savefig("./FOV_Validation_" + str(animal) + ".png")
# ### Plot all corrected FOV's for comparison (optional)
#
# **The running takes considerable amount of time!**
# In[23]:
# ---------------------------------------------------------------------------
# DISABLED CODE (module-level string literal — never executed).
# Dumps reference / initial / corrected FOV images for every recording pair so
# the computed alignments can be inspected by eye.
# NOTE(review): indentation below was reconstructed from a flattened notebook
# export — verify before re-enabling. It also references `sr` (a StackReg
# instance) that is only defined inside the disabled registration cell above,
# so that cell must be re-enabled and run first.
# ---------------------------------------------------------------------------
'''
for animal in animals:
    tmats_loaded = np.load(path4results + 'StackReg/' + str(animal) + "_best_tmats" + '.npy')
    meta_animal = meta_data[meta_data['Mouse'] == animal]
    recordings = meta_animal['Number']
    images = np.zeros((512, 512, np.shape(meta_animal)[0]))
    images_mean = np.zeros((512, 512, np.shape(recordings)[0]))
    images_mean_enh = np.zeros((512, 512, np.shape(recordings)[0]))
    # load all (enhanced) images
    for idx, recording in enumerate(recordings):
        options = np.load(meta_data['Folder'][recording] +
                          meta_data['Subfolder'][recording] +
                          str(int(meta_data['Recording idx'][recording])) +
                          '/suite2p/plane0/ops.npy',
                          allow_pickle=True)
        # mean image or mean enhanced image
        images_mean[:, :, idx] = options.item(0)['meanImg']
        images_mean_enh[:, :, idx] = options.item(0)['meanImgE']
        #cut_boarders=50
        #quality check
        #images_quality_check[:, :, idx] = options.item(0)['meanImg']
    # loop through every pair and compute the transformation matrix
    conditions = [meta_data['Condition'][recording] for recording in recordings]
    recording_idx = [meta_data['Recording idx'][recording] for recording in recordings]
    for idx0 in range(np.shape(images_mean)[2]):
        #if (idx0!=14):
        #    continue
        for idx1 in range(idx0, np.shape(images_mean)[2]):
            #if (idx1!=16):
            #    continue
            reference_image = images_mean_enh[:, :, idx0]
            initial_image = images_mean_enh[:, :, idx1]
            if not os.path.exists(save_plots_path + 'StackRegVisualInspection/'):
                os.makedirs(save_plots_path + 'StackRegVisualInspection/')
            if not os.path.exists(save_plots_path + 'StackRegVisualInspection/' + str(animal) + '/'):
                os.makedirs(save_plots_path + 'StackRegVisualInspection/' + str(animal) + '/')
            plt.imshow(reference_image)
            # image_title = meta_data['Subfolder'][recording][:-1] + str(meta_data['Recording idx'][recording]) + "\n" + "_condition_" + \
            #               meta_data['Condition'][recording]
            # plt.title(image_title)
            #file_title = meta_data['Subfolder'][recording][:-1] + str(
            #    meta_data['Recording idx'][recording]) + "_condition_" + \
            #    meta_data['Condition'][recording] + "_" + str(idx0) + "_" + str(idx1)
            file_title = str(str(idx0) + '_' + str(idx1) + '_' + conditions[idx0]) + '_' + str(recording_idx[idx0]) + '_' + str(conditions[idx1]) + "_" + str(recording_idx[idx1])
            print(file_title)
            plt.savefig(save_plots_path + "StackRegVisualInspection/" + str(animal) + '/' + file_title + "_b_reference.png")
            #sx = ndimage.sobel(images[:, :, idx1], axis=0, mode='constant')
            #sy = ndimage.sobel(images[:, :, idx1], axis=1, mode='constant')
            #sob = np.hypot(sx, sy)
            #plt.imshow(sob)
            plt.imshow(initial_image)
            plt.savefig(save_plots_path + "StackRegVisualInspection/" + str(animal) + '/' + file_title + "_a_initial.png")
            #grad = np.gradient(images[:, :, idx1])
            #print(images[50:55, 50:55, idx1].shape)
            #grad = np.gradient(images[50:55, 50:55, idx1])
            #print(images[50:55, 50:55, idx1])
            #print(" ")
            #print(grad)
            #image_inversed = sr.transform(reference_image, best_tmats[idx1, idx0, :, :])
            #plt.imshow(image_inversed)
            #plt.savefig(save_plots_path + "StackRegVisualInspection/" + str(animal) + '/' + file_title + "_1_inversed.png")
            #sx = ndimage.sobel(images[:, :, idx1], axis=0, mode='constant')
            #sy = ndimage.sobel(images[:, :, idx1], axis=1, mode='constant')
            #sob = np.hypot(sx, sy)
            #plt.imshow(images[:, :, idx1])
            #plt.savefig(save_plots_path + "StackRegVisualInspection/" + file_title + "_sobel.png")
            # `sr` comes from the registration cell above (disabled) — see NOTE.
            image_corrected = sr.transform(initial_image, tmats_loaded[idx0, idx1, :, :])
            plt.imshow(image_corrected)
            plt.savefig(save_plots_path + "StackRegVisualInspection/" + str(animal) + '/' + file_title + "_c_corrected.png")
            #image_difference = images_quality_check[:, :, idx0] - sr.transform(images_quality_check[:, :, idx1], best_tmats[idx0, idx1, :, :])
            #image_difference = reference_image - image_corrected
            #plt.imshow(image_difference)
            #plt.savefig(save_plots_path + "StackRegVisualInspection/" + str(animal) + '/' + file_title + "_d_reference_m_corrected.png")
# In[9]:
###TODO Seaborn Scatterplot heatmap could be a nice option for plotting
# In[ ]:
'''