# -*- coding: utf-8 -*-
"""
This generates diagrams in .png and .svg from neo.core

Author: sgarcia
"""
- from datetime import datetime
- import numpy as np
- import quantities as pq
- from matplotlib import pyplot
- from matplotlib.patches import Rectangle, ArrowStyle, FancyArrowPatch
- from matplotlib.font_manager import FontProperties
- from neo.test.generate_datasets import fake_neo
# Layout constants shared by the drawing functions in this module.

# Height of one text row inside a box, expressed in the same units as the
# box positions.  (The misspelt name is kept: other functions reference it.)
line_heigth = 0.22

# Base font size handed to matplotlib for box labels and arrow annotations.
fontsize = 10.5

# Horizontal offset of left-aligned labels from a box's left edge.
left_text_shift = 0.1

# Resolution passed to ``Figure.savefig`` when the diagram is written out.
dpi = 100
def get_rect_height(name, obj):
    '''
    Return the height of the box drawn for ``obj``.

    The box gets a fixed 1.5 rows (the same height as the title band drawn
    by ``generate_diagram``) plus one row per entry found in the object's
    ``_all_attrs``, ``_single_child_objects``, ``_multi_child_objects`` and
    ``_multi_parent_objects``.  An attribute that does not exist on ``obj``
    contributes no rows.  The row count is scaled by the module-level
    ``line_heigth``.

    ``name`` is accepted for the caller's convenience but is not used.
    '''
    counted_attrs = (
        '_all_attrs',
        '_single_child_objects',
        '_multi_child_objects',
        '_multi_parent_objects',
    )
    row_count = 1.5
    for attr_name in counted_attrs:
        row_count += len(getattr(obj, attr_name, []))
    return row_count*line_heigth
def annotate(ax, coord1, coord2, connectionstyle, color, alpha):
    '''
    Draw a text-less 'fancy' arrow on ``ax`` between two points.

    ``coord1`` and ``coord2`` are forwarded positionally to ``ax.annotate``
    right after the empty label, i.e. ``coord1`` is the annotated point and
    ``coord2`` the text position the arrow starts from.  ``connectionstyle``
    selects the curvature; ``color`` is used for both the face and the edge
    of the arrow and ``alpha`` for its transparency.

    Nothing is returned; the annotation is added to ``ax`` as a side effect.
    '''
    arrow_style = {
        'arrowstyle': 'fancy',
        'shrinkA': .3,
        'shrinkB': .3,
        'fc': color,
        'ec': color,
        'connectionstyle': connectionstyle,
        'alpha': alpha,
    }
    label_box = {'boxstyle': "square", 'fc': "w"}
    arrow = ax.annotate('', coord1, coord2,
                        ha="right", va="center",
                        size=fontsize,
                        arrowprops=arrow_style,
                        bbox=label_box)
    # A negative zorder keeps the arrow behind artists that are left at a
    # higher zorder, such as the boxes added later.
    arrow.set_zorder(-4)
def calc_coordinates(pos, height):
    '''
    Return the ``(x, y)`` point at which arrows attach to a box.

    ``x`` is the box's ``pos[0]``; ``y`` lies half a ``line_heigth`` below
    the box's upper edge (``pos[1] + height``).
    '''
    upper_edge = pos[1] + height
    return pos[0], upper_edge - line_heigth*.5
def generate_diagram(filename, rect_pos, rect_width, figsize):
    '''
    Draw one box per object plus the arrows linking them, then save the
    figure to ``filename``.

    filename: path handed to ``Figure.savefig`` (written with the module
        level ``dpi``).
    rect_pos: dict mapping an object name (as accepted by ``fake_neo``) to
        the ``(x, y)`` anchor of its box.  Every child name listed by a
        drawn object must itself be a key of this dict.
    rect_width: width shared by all boxes.
    figsize: ``(width, height)`` used both as the figure size and as the
        upper axis limits.

    A new pyplot figure is created on every call and is left open.
    '''
    figure = pyplot.figure(figsize=figsize)
    axes = figure.add_axes([0, 0, 1, 1])

    # One fake instance per requested name, and the box height for each.
    instances = {}
    heights = {}
    for name in rect_pos:
        instances[name] = fake_neo(name)
        heights[name] = get_rect_height(name, instances[name])

    # --- connections ------------------------------------------------------
    # The three relationship kinds share an index with their colour/alpha.
    link_attrs = ('_single_child_objects',
                  '_multi_child_objects',
                  '_child_properties')
    link_colors = ['c', 'm', 'y']
    link_alphas = [1., 1., 0.3]
    for name, pos in rect_pos.items():
        parent = instances[name]
        child_groups = [getattr(parent, attr_name, [])
                        for attr_name in link_attrs]

        for kind, children in enumerate(child_groups):
            # Single children (0) and child properties (2) leave from the
            # right edge of the parent box; multi children (1) from the left.
            from_right_edge = kind != 1
            for child_name in children:
                child_x, child_y = calc_coordinates(rect_pos[child_name],
                                                    heights[child_name])
                parent_x, parent_y = calc_coordinates(pos, heights[name])

                if from_right_edge:
                    parent_x += rect_width
                    curve = "arc3,rad=-0.2"
                elif parent_y >= child_y:
                    curve = "arc3,rad=0.7"
                else:
                    curve = "arc3,rad=-0.7"

                annotate(ax=axes,
                         coord1=(child_x, child_y),
                         coord2=(parent_x, parent_y),
                         connectionstyle=curve,
                         color=link_colors[kind], alpha=link_alphas[kind])

    # --- boxes ------------------------------------------------------------
    def add_band(xy, band_height, **style):
        # Add one full-width rectangle anchored at ``xy`` to the axes.
        axes.add_patch(Rectangle(xy, rect_width, band_height, **style))

    for name, pos in rect_pos.items():
        box_height = heights[name]
        obj = instances[name]
        # All vertical offsets below are measured down from the upper edge.
        top = pos[1]+box_height
        container_names = (
            list(getattr(obj, '_child_containers', [])) +
            list(getattr(obj, '_multi_parent_containers', [])))

        # white outline of the whole box
        add_band(pos, box_height,
                 facecolor='w', edgecolor='k', linewidth=2.)

        # green title band (1.5 rows high)
        add_band((pos[0], top - line_heigth*1.5), line_heigth*1.5,
                 facecolor='g', edgecolor='k', alpha=.5, linewidth=2.)

        # cyan band: single relationships
        single_children = getattr(obj, '_single_child_objects', [])
        single_y = top - line_heigth*(1.5+len(single_children))
        single_band = len(single_children)*line_heigth
        add_band((pos[0], single_y), single_band,
                 facecolor='c', edgecolor='k', alpha=.5)

        # magenta band: multi relationships, stacked under the cyan band
        multi_children = (
            list(getattr(obj, '_multi_child_objects', [])) +
            list(getattr(obj, '_multi_parent_containers', [])))
        multi_y = (top - line_heigth*(1.5+len(multi_children)) -
                   single_band)
        multi_band = len(multi_children)*line_heigth
        add_band((pos[0], multi_y), multi_band,
                 facecolor='m', edgecolor='k', alpha=.5)

        # red band: necessary attributes, placed under the container rows
        necessary = obj._necessary_attrs
        necessary_y = (top -
                       line_heigth*(1.5+len(container_names) +
                                    len(necessary)))
        add_band((pos[0], necessary_y), line_heigth*len(necessary),
                 facecolor='r', edgecolor='k', alpha=.5)

        # title text; objects exposing ``_quantity_attr`` get a star marker
        has_quantity = hasattr(obj, '_quantity_attr')
        marker = '* ' if has_quantity else ''
        axes.text(pos[0]+rect_width/2., top - line_heigth*1.5/2.,
                  name+marker,
                  horizontalalignment='center', verticalalignment='center',
                  fontsize=fontsize+2,
                  fontproperties=FontProperties(weight='bold'),
                  )

        # one row per container, starting on row 2 below the upper edge
        for row, container in enumerate(container_names, start=2):
            axes.text(pos[0]+left_text_shift, top - line_heigth*row,
                      container+': list',
                      horizontalalignment='left', verticalalignment='center',
                      fontsize=fontsize,
                      )

        # one row per attribute, continuing below the container rows
        first_attr_row = len(container_names)+2
        for row, attr in enumerate(obj._all_attrs, start=first_attr_row):
            attrname, attrtype = attr[0], attr[1]

            if has_quantity and obj._quantity_attr == attrname:
                attr_label = attrname+'(object itself)'
            else:
                attr_label = attrname

            if attrtype == pq.Quantity:
                if attr[2] == 0:
                    type_label = 'Quantity scalar'
                else:
                    type_label = 'Quantity %dD' % attr[2]
            elif attrtype == np.ndarray:
                type_label = "np.ndarray %dD dt='%s'" % (attr[2],
                                                         attr[3].kind)
            elif attrtype == datetime:
                type_label = 'datetime'
            else:
                type_label = attrtype.__name__

            axes.text(pos[0]+left_text_shift, top - line_heigth*row,
                      attr_label+' : '+type_label,
                      horizontalalignment='left', verticalalignment='center',
                      fontsize=fontsize,
                      )

    # Axis limits equal the figure size; ticks are hidden.
    x_max, y_max = figsize
    axes.set_xlim(0, x_max)
    axes.set_ylim(0, y_max)
    axes.set_xticks([])
    axes.set_yticks([])

    figure.savefig(filename, dpi=dpi)
def generate_diagram_simple():
    '''
    Render the simple diagram with a fixed layout and save it twice, as
    ``simple_generated_diagram.svg`` and ``simple_generated_diagram.png``.

    ``generate_diagram`` is called once per output file with the same
    layout, box width and figure size.
    '''
    figsize = (18, 12)
    rect_width = 3.
    # Column pitch is ``rect_width*blank_fact``, which leaves 20% of a box
    # width blank between neighbouring columns.
    blank_fact = 1.2

    def column_x(index):
        # x position of the given column; the first one starts at 0.5
        return .5+rect_width*blank_fact*index

    rect_pos = {
        'Block': (column_x(0), 4),
        'Segment': (column_x(1), .5),
        'Event': (column_x(4), 3.0),
        'Epoch': (column_x(4), 1.0),
        'ChannelIndex': (column_x(1), 7.5),
        'Unit': (column_x(2.), 9.9),
        'SpikeTrain': (column_x(3), 7.5),
        'IrregularlySampledSignal': (column_x(3), 0.5),
        'AnalogSignal': (column_x(3), 4.9),
    }

    targets = ('simple_generated_diagram.svg',
               'simple_generated_diagram.png')
    for target in targets:
        generate_diagram(target, rect_pos, rect_width, figsize)
if __name__ == '__main__':
    # When run as a script: write the diagram files, then display the
    # figures that were created while doing so.
    generate_diagram_simple()
    pyplot.show()
|