generate_diagram.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
# -*- coding: utf-8 -*-
"""
Generate .png and .svg class diagrams of the neo.core objects.

Author: sgarcia
"""
  6. from datetime import datetime
  7. import numpy as np
  8. import quantities as pq
  9. from matplotlib import pyplot
  10. from matplotlib.patches import Rectangle, ArrowStyle, FancyArrowPatch
  11. from matplotlib.font_manager import FontProperties
  12. from neo.test.generate_datasets import fake_neo
  13. line_heigth = .22
  14. fontsize = 10.5
  15. left_text_shift = .1
  16. dpi = 100
  17. def get_rect_height(name, obj):
  18. '''
  19. calculate rectangle height
  20. '''
  21. nlines = 1.5
  22. nlines += len(getattr(obj, '_all_attrs', []))
  23. nlines += len(getattr(obj, '_single_child_objects', []))
  24. nlines += len(getattr(obj, '_multi_child_objects', []))
  25. nlines += len(getattr(obj, '_multi_parent_objects', []))
  26. return nlines*line_heigth
  27. def annotate(ax, coord1, coord2, connectionstyle, color, alpha):
  28. arrowprops = dict(arrowstyle='fancy',
  29. #~ patchB=p,
  30. shrinkA=.3, shrinkB=.3,
  31. fc=color, ec=color,
  32. connectionstyle=connectionstyle,
  33. alpha=alpha)
  34. bbox = dict(boxstyle="square", fc="w")
  35. a = ax.annotate('', coord1, coord2,
  36. #xycoords="figure fraction",
  37. #textcoords="figure fraction",
  38. ha="right", va="center",
  39. size=fontsize,
  40. arrowprops=arrowprops,
  41. bbox=bbox)
  42. a.set_zorder(-4)
  43. def calc_coordinates(pos, height):
  44. x = pos[0]
  45. y = pos[1] + height - line_heigth*.5
  46. return pos[0], y
  47. def generate_diagram(filename, rect_pos, rect_width, figsize):
  48. rw = rect_width
  49. fig = pyplot.figure(figsize=figsize)
  50. ax = fig.add_axes([0, 0, 1, 1])
  51. all_h = {}
  52. objs = {}
  53. for name in rect_pos:
  54. objs[name] = fake_neo(name)
  55. all_h[name] = get_rect_height(name, objs[name])
  56. # draw connections
  57. color = ['c', 'm', 'y']
  58. alpha = [1., 1., 0.3]
  59. for name, pos in rect_pos.items():
  60. obj = objs[name]
  61. relationships = [getattr(obj, '_single_child_objects', []),
  62. getattr(obj, '_multi_child_objects', []),
  63. getattr(obj, '_child_properties', [])]
  64. for r in range(3):
  65. for ch_name in relationships[r]:
  66. x1, y1 = calc_coordinates(rect_pos[ch_name], all_h[ch_name])
  67. x2, y2 = calc_coordinates(pos, all_h[name])
  68. if r in [0, 2]:
  69. x2 += rect_width
  70. connectionstyle = "arc3,rad=-0.2"
  71. elif y2 >= y1:
  72. connectionstyle = "arc3,rad=0.7"
  73. else:
  74. connectionstyle = "arc3,rad=-0.7"
  75. annotate(ax=ax, coord1=(x1, y1), coord2=(x2, y2),
  76. connectionstyle=connectionstyle,
  77. color=color[r], alpha=alpha[r])
  78. # draw boxes
  79. for name, pos in rect_pos.items():
  80. htotal = all_h[name]
  81. obj = objs[name]
  82. allrelationship = (list(getattr(obj, '_child_containers', [])) +
  83. list(getattr(obj, '_multi_parent_containers', [])))
  84. rect = Rectangle(pos, rect_width, htotal,
  85. facecolor='w', edgecolor='k', linewidth=2.)
  86. ax.add_patch(rect)
  87. # title green
  88. pos2 = pos[0], pos[1]+htotal - line_heigth*1.5
  89. rect = Rectangle(pos2, rect_width, line_heigth*1.5,
  90. facecolor='g', edgecolor='k', alpha=.5, linewidth=2.)
  91. ax.add_patch(rect)
  92. # single relationship
  93. relationship = getattr(obj, '_single_child_objects', [])
  94. pos2 = pos[1] + htotal - line_heigth*(1.5+len(relationship))
  95. rect_height = len(relationship)*line_heigth
  96. rect = Rectangle((pos[0], pos2), rect_width, rect_height,
  97. facecolor='c', edgecolor='k', alpha=.5)
  98. ax.add_patch(rect)
  99. # multi relationship
  100. relationship = (list(getattr(obj, '_multi_child_objects', [])) +
  101. list(getattr(obj, '_multi_parent_containers', [])))
  102. pos2 = (pos[1]+htotal - line_heigth*(1.5+len(relationship)) -
  103. rect_height)
  104. rect_height = len(relationship)*line_heigth
  105. rect = Rectangle((pos[0], pos2), rect_width, rect_height,
  106. facecolor='m', edgecolor='k', alpha=.5)
  107. ax.add_patch(rect)
  108. # necessary attr
  109. pos2 = (pos[1]+htotal -
  110. line_heigth*(1.5+len(allrelationship) +
  111. len(obj._necessary_attrs)))
  112. rect = Rectangle((pos[0], pos2), rect_width,
  113. line_heigth*len(obj._necessary_attrs),
  114. facecolor='r', edgecolor='k', alpha=.5)
  115. ax.add_patch(rect)
  116. # name
  117. if hasattr(obj, '_quantity_attr'):
  118. post = '* '
  119. else:
  120. post = ''
  121. ax.text(pos[0]+rect_width/2., pos[1]+htotal - line_heigth*1.5/2.,
  122. name+post,
  123. horizontalalignment='center', verticalalignment='center',
  124. fontsize=fontsize+2,
  125. fontproperties=FontProperties(weight='bold'),
  126. )
  127. #relationship
  128. for i, relat in enumerate(allrelationship):
  129. ax.text(pos[0]+left_text_shift, pos[1]+htotal - line_heigth*(i+2),
  130. relat+': list',
  131. horizontalalignment='left', verticalalignment='center',
  132. fontsize=fontsize,
  133. )
  134. # attributes
  135. for i, attr in enumerate(obj._all_attrs):
  136. attrname, attrtype = attr[0], attr[1]
  137. t1 = attrname
  138. if (hasattr(obj, '_quantity_attr') and
  139. obj._quantity_attr == attrname):
  140. t1 = attrname+'(object itself)'
  141. else:
  142. t1 = attrname
  143. if attrtype == pq.Quantity:
  144. if attr[2] == 0:
  145. t2 = 'Quantity scalar'
  146. else:
  147. t2 = 'Quantity %dD' % attr[2]
  148. elif attrtype == np.ndarray:
  149. t2 = "np.ndarray %dD dt='%s'" % (attr[2], attr[3].kind)
  150. elif attrtype == datetime:
  151. t2 = 'datetime'
  152. else:
  153. t2 = attrtype.__name__
  154. t = t1+' : '+t2
  155. ax.text(pos[0]+left_text_shift,
  156. pos[1]+htotal - line_heigth*(i+len(allrelationship)+2),
  157. t,
  158. horizontalalignment='left', verticalalignment='center',
  159. fontsize=fontsize,
  160. )
  161. xlim, ylim = figsize
  162. ax.set_xlim(0, xlim)
  163. ax.set_ylim(0, ylim)
  164. ax.set_xticks([])
  165. ax.set_yticks([])
  166. fig.savefig(filename, dpi=dpi)
  167. def generate_diagram_simple():
  168. figsize = (18, 12)
  169. rw = rect_width = 3.
  170. bf = blank_fact = 1.2
  171. rect_pos = {'Block': (.5+rw*bf*0, 4),
  172. 'Segment': (.5+rw*bf*1, .5),
  173. 'Event': (.5+rw*bf*4, 3.0),
  174. 'Epoch': (.5+rw*bf*4, 1.0),
  175. 'ChannelIndex': (.5+rw*bf*1, 7.5),
  176. 'Unit': (.5+rw*bf*2., 9.9),
  177. 'SpikeTrain': (.5+rw*bf*3, 7.5),
  178. 'IrregularlySampledSignal': (.5+rw*bf*3, 0.5),
  179. 'AnalogSignal': (.5+rw*bf*3, 4.9),
  180. }
  181. generate_diagram('simple_generated_diagram.svg',
  182. rect_pos, rect_width, figsize)
  183. generate_diagram('simple_generated_diagram.png',
  184. rect_pos, rect_width, figsize)
if __name__ == '__main__':
    # Write simple_generated_diagram.svg/.png, then display the figures,
    # which generate_diagram does not close.
    generate_diagram_simple()
    pyplot.show()