vizNrn2DDensity.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. from matplotlib import pyplot as plt
  2. import numpy as np
  3. import seaborn as sns
  4. from matplotlib import patches as mpatches
  5. import colorsys
  6. plt.ion()
  7. import os
  8. from regmaxsn.core.matplotlibRCParams import mplPars
  9. homeFolder = "/home/aj/"
  10. sns.set(rc=mplPars, style='whitegrid')
  11. def getLighterColor(col, saturation):
  12. '''
  13. Returns a color with the same hue and value as the color `col', but with the given saturation
  14. :param col: 3 member iterable with values in [0, 1]
  15. :param saturation: float in [0, 1]
  16. :return:
  17. '''
  18. assert len(col) == 3, 'col must be a 3 member iterable'
  19. assert all([0 <= x <= 1 for x in col]), 'col can only contain values in [0, 1]'
  20. assert 0 <= saturation <= 1, 'saturation must be in [0, 1]'
  21. hsv = colorsys.rgb_to_hsv(*col)
  22. return colorsys.hsv_to_rgb(hsv[0], saturation, hsv[2])
  23. swcFiles = [
  24. # os.path.join(homeFolder,
  25. # 'DataAndResults/morphology/OriginalData/Tests/HSN-fluoro01.CNG.swc'),
  26. # os.path.join(homeFolder,
  27. # 'DataAndResults/morphology/OriginalData/Tests/HSN-fluoro01.CNGRandRotY0.swc'),
  28. # os.path.join(homeFolder,
  29. # 'DataAndResults/morphology/OriginalData/Tests/HSN-fluoro01.CNGRandRotY1.swc'),
  30. os.path.join(homeFolder, 'DataAndResults/morphology/OriginalData/chiangAA1', "VGlut-F-300181.CNG.swc"),
  31. os.path.join(homeFolder, 'DataAndResults/morphology/OriginalData/chiangAA1', "VGlut-F-400665.CNG.swc"),
  32. ]
  33. # gridSize = 80.0
  34. # gridSize = 20.0
  35. # gridSize = 40.0
  36. gridSize = 10.0
  37. minMarkerSize = 5
  38. maxMarkerSize = 10
  39. minRad = 1
  40. maxRad = 5
  41. cols = plt.cm.rainbow(np.linspace(1, 0, len(swcFiles)))
  42. xDis = []
  43. yDis = []
  44. fig, ax = plt.subplots(figsize=(14, 11.2))
  45. for swcInd, swcFile in enumerate(swcFiles):
  46. data = np.loadtxt(swcFile)
  47. slope = (maxMarkerSize - minMarkerSize) / (maxRad - minRad)
  48. rad2MarkerSize = lambda rad: minMarkerSize + slope * (rad - minRad)
  49. xs = []
  50. ys = []
  51. rads = data[:, 5]
  52. for ind in range(1, data.shape[0]):
  53. xs.append([data[ind, 3], data[int(data[ind, 6]) - 1, 3]])
  54. ys.append([data[ind, 2], data[int(data[ind, 6]) - 1, 2]])
  55. xs = np.array(xs).T
  56. ys = np.array(ys).T
  57. xyDis = gridSize * np.array(np.around(data[:, 3:1:-1] / gridSize), np.intp)
  58. xDis += xyDis[:, 0].tolist()
  59. yDis += xyDis[:, 1].tolist()
  60. xySet = set(map(tuple, xyDis))
  61. col = cols[swcInd]
  62. lightCol = getLighterColor(col[:3], 0.5)
  63. for xy in xySet:
  64. ax.add_patch(
  65. mpatches.Rectangle((xy[0] - 0.5 * gridSize, xy[1] - 0.5 * gridSize), width=gridSize, height=gridSize,
  66. fc=lightCol))
  67. ax.plot(xs, ys, color=col, ls='-', ms=3)
  68. for x, y, r in zip(xs[0, :], ys[0, :], rads):
  69. ax.plot(x, y, color=col, marker='o', ms=rad2MarkerSize(r))
  70. ax.axis('square')
  71. xmax = max(xDis) + 0.5 * gridSize
  72. xmin = min(xDis) - 0.5 * gridSize
  73. width = xmax - xmin
  74. ymax = max(yDis) + 0.5 * gridSize
  75. ymin = min(yDis) - 0.5 * gridSize
  76. height = ymax - ymin
  77. ax.set_xlim(xmin - gridSize, xmax + gridSize)
  78. ax.set_ylim(ymin - gridSize, ymax + gridSize)
  79. xticks = np.arange(xmin, xmax + gridSize, gridSize)
  80. yticks = np.arange(ymin, ymax + gridSize, gridSize)
  81. ax.set_xticks(xticks)
  82. ax.set_yticks(yticks)
  83. # xticklabels = ['' if x % 5 else str(y) for x, y in enumerate(xticks)]
  84. # yticklabels = ['' if x % 5 else str(y) for x, y in enumerate(yticks)]
  85. xticklabels = []
  86. yticklabels = []
  87. ax.set_xticklabels(xticklabels, rotation=90)
  88. ax.set_yticklabels(yticklabels)
  89. # ax.set_xlabel(r'X in $\mu$m')
  90. # ax.set_ylabel(r'Y in $\mu$m')
  91. ax.grid(True)
  92. fig.tight_layout()
  93. fig.canvas.draw()