generate_diagram.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. """
  2. This generate diagram in .png and .svg from neo.core
  3. Author: sgarcia
  4. """
  5. from datetime import datetime
  6. import numpy as np
  7. import quantities as pq
  8. from matplotlib import pyplot
  9. from matplotlib.patches import Rectangle, ArrowStyle, FancyArrowPatch
  10. from matplotlib.font_manager import FontProperties
  11. from neo.test.generate_datasets import fake_neo
  12. line_heigth = .22
  13. fontsize = 10.5
  14. left_text_shift = .1
  15. dpi = 100
  16. def get_rect_height(name, obj):
  17. '''
  18. calculate rectangle height
  19. '''
  20. nlines = 1.5
  21. nlines += len(getattr(obj, '_all_attrs', []))
  22. nlines += len(getattr(obj, '_single_child_objects', []))
  23. nlines += len(getattr(obj, '_multi_child_objects', []))
  24. nlines += len(getattr(obj, '_multi_parent_objects', []))
  25. return nlines * line_heigth
  26. def annotate(ax, coord1, coord2, connectionstyle, color, alpha):
  27. arrowprops = dict(arrowstyle='fancy',
  28. # ~ patchB=p,
  29. shrinkA=.3, shrinkB=.3,
  30. fc=color, ec=color,
  31. connectionstyle=connectionstyle,
  32. alpha=alpha)
  33. bbox = dict(boxstyle="square", fc="w")
  34. a = ax.annotate('', coord1, coord2,
  35. # xycoords="figure fraction",
  36. # textcoords="figure fraction",
  37. ha="right", va="center",
  38. size=fontsize,
  39. arrowprops=arrowprops,
  40. bbox=bbox)
  41. a.set_zorder(-4)
  42. def calc_coordinates(pos, height):
  43. x = pos[0]
  44. y = pos[1] + height - line_heigth * .5
  45. return pos[0], y
  46. def generate_diagram(filename, rect_pos, rect_width, figsize):
  47. rw = rect_width
  48. fig = pyplot.figure(figsize=figsize)
  49. ax = fig.add_axes([0, 0, 1, 1])
  50. all_h = {}
  51. objs = {}
  52. for name in rect_pos:
  53. objs[name] = fake_neo(name)
  54. all_h[name] = get_rect_height(name, objs[name])
  55. # draw connections
  56. color = ['c', 'm', 'y']
  57. alpha = [1., 1., 0.3]
  58. for name, pos in rect_pos.items():
  59. obj = objs[name]
  60. relationships = [getattr(obj, '_single_child_objects', []),
  61. getattr(obj, '_multi_child_objects', []),
  62. getattr(obj, '_child_properties', [])]
  63. for r in range(3):
  64. for ch_name in relationships[r]:
  65. x1, y1 = calc_coordinates(rect_pos[ch_name], all_h[ch_name])
  66. x2, y2 = calc_coordinates(pos, all_h[name])
  67. if r in [0, 2]:
  68. x2 += rect_width
  69. connectionstyle = "arc3,rad=-0.2"
  70. elif y2 >= y1:
  71. connectionstyle = "arc3,rad=0.7"
  72. else:
  73. connectionstyle = "arc3,rad=-0.7"
  74. annotate(ax=ax, coord1=(x1, y1), coord2=(x2, y2),
  75. connectionstyle=connectionstyle,
  76. color=color[r], alpha=alpha[r])
  77. # draw boxes
  78. for name, pos in rect_pos.items():
  79. htotal = all_h[name]
  80. obj = objs[name]
  81. allrelationship = (list(getattr(obj, '_child_containers', []))
  82. + list(getattr(obj, '_multi_parent_containers', [])))
  83. rect = Rectangle(pos, rect_width, htotal,
  84. facecolor='w', edgecolor='k', linewidth=2.)
  85. ax.add_patch(rect)
  86. # title green
  87. pos2 = pos[0], pos[1] + htotal - line_heigth * 1.5
  88. rect = Rectangle(pos2, rect_width, line_heigth * 1.5,
  89. facecolor='g', edgecolor='k', alpha=.5, linewidth=2.)
  90. ax.add_patch(rect)
  91. # single relationship
  92. relationship = getattr(obj, '_single_child_objects', [])
  93. pos2 = pos[1] + htotal - line_heigth * (1.5 + len(relationship))
  94. rect_height = len(relationship) * line_heigth
  95. rect = Rectangle((pos[0], pos2), rect_width, rect_height,
  96. facecolor='c', edgecolor='k', alpha=.5)
  97. ax.add_patch(rect)
  98. # multi relationship
  99. relationship = (list(getattr(obj, '_multi_child_objects', []))
  100. + list(getattr(obj, '_multi_parent_containers', [])))
  101. pos2 = (pos[1] + htotal - line_heigth * (1.5 + len(relationship))
  102. - rect_height)
  103. rect_height = len(relationship) * line_heigth
  104. rect = Rectangle((pos[0], pos2), rect_width, rect_height,
  105. facecolor='m', edgecolor='k', alpha=.5)
  106. ax.add_patch(rect)
  107. # necessary attr
  108. pos2 = (pos[1] + htotal
  109. - line_heigth * (1.5 + len(allrelationship) + len(obj._necessary_attrs)))
  110. rect = Rectangle((pos[0], pos2), rect_width,
  111. line_heigth * len(obj._necessary_attrs),
  112. facecolor='r', edgecolor='k', alpha=.5)
  113. ax.add_patch(rect)
  114. # name
  115. if hasattr(obj, '_quantity_attr'):
  116. post = '* '
  117. else:
  118. post = ''
  119. ax.text(pos[0] + rect_width / 2., pos[1] + htotal - line_heigth * 1.5 / 2.,
  120. name + post,
  121. horizontalalignment='center', verticalalignment='center',
  122. fontsize=fontsize + 2,
  123. fontproperties=FontProperties(weight='bold'),
  124. )
  125. # relationship
  126. for i, relat in enumerate(allrelationship):
  127. ax.text(pos[0] + left_text_shift, pos[1] + htotal - line_heigth * (i + 2),
  128. relat + ': list',
  129. horizontalalignment='left', verticalalignment='center',
  130. fontsize=fontsize,
  131. )
  132. # attributes
  133. for i, attr in enumerate(obj._all_attrs):
  134. attrname, attrtype = attr[0], attr[1]
  135. t1 = attrname
  136. if (hasattr(obj, '_quantity_attr')
  137. and obj._quantity_attr == attrname):
  138. t1 = attrname + '(object itself)'
  139. else:
  140. t1 = attrname
  141. if attrtype == pq.Quantity:
  142. if attr[2] == 0:
  143. t2 = 'Quantity scalar'
  144. else:
  145. t2 = 'Quantity %dD' % attr[2]
  146. elif attrtype == np.ndarray:
  147. t2 = "np.ndarray %dD dt='%s'" % (attr[2], attr[3].kind)
  148. elif attrtype == datetime:
  149. t2 = 'datetime'
  150. else:
  151. t2 = attrtype.__name__
  152. t = t1 + ' : ' + t2
  153. ax.text(pos[0] + left_text_shift,
  154. pos[1] + htotal - line_heigth * (i + len(allrelationship) + 2),
  155. t,
  156. horizontalalignment='left', verticalalignment='center',
  157. fontsize=fontsize,
  158. )
  159. xlim, ylim = figsize
  160. ax.set_xlim(0, xlim)
  161. ax.set_ylim(0, ylim)
  162. ax.set_xticks([])
  163. ax.set_yticks([])
  164. fig.savefig(filename, dpi=dpi)
  165. def generate_diagram_simple():
  166. figsize = (18, 12)
  167. rw = rect_width = 3.
  168. bf = blank_fact = 1.2
  169. rect_pos = {'Block': (.5 + rw * bf * 0, 4),
  170. 'Segment': (.5 + rw * bf * 1, .5),
  171. 'Event': (.5 + rw * bf * 4, 3.0),
  172. 'Epoch': (.5 + rw * bf * 4, 1.0),
  173. 'Group': (.5 + rw * bf * 1, 7.5),
  174. 'ChannelView': (.5 + rw * bf * 2., 9.9),
  175. 'SpikeTrain': (.5 + rw * bf * 3, 7.5),
  176. 'IrregularlySampledSignal': (.5 + rw * bf * 3, 0.5),
  177. 'AnalogSignal': (.5 + rw * bf * 3, 4.9),
  178. }
  179. # todo: add ImageSequence, RegionOfInterest
  180. generate_diagram('simple_generated_diagram.svg',
  181. rect_pos, rect_width, figsize)
  182. generate_diagram('simple_generated_diagram.png',
  183. rect_pos, rect_width, figsize)
  184. if __name__ == '__main__':
  185. generate_diagram_simple()
  186. pyplot.show()