plotting.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. # !/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from __future__ import print_function, division
  4. import numpy as np
  5. import scipy.signal as sp
  6. import random
  7. import nixio as nix
  8. import matplotlib.pyplot as plt
  9. COLORS_BLUE_AND_RED = (
  10. 'dodgerblue', 'red'
  11. )
  12. COLORS_BLUE_GRADIENT = (
  13. "#034980", "#055DA1", "#1B70E0", "#3786ED", "#4A95F7",
  14. "#0C3663", "#1B4775", "#205082", "#33608F", "#51779E",
  15. "#23B0DB", "#29CDFF", "#57D8FF", "#8FE5FF"
  16. )
  17. class Plotter(object):
  18. """
  19. Plotter class for nix data arrays.
  20. """
  21. def __init__(self, width=800, height=600, dpi=90, lines=1, cols=1, facecolor="white",
  22. defaultcolors=COLORS_BLUE_GRADIENT):
  23. """
  24. :param width: Width of the image in pixels
  25. :param height: Height of the image in pixels
  26. :param dpi: DPI of the image (default 90)
  27. :param lines: Number of vertical subplots
  28. :param cols: Number of horizontal subplots
  29. :param facecolor: The background color of the plot
  30. :param defaultcolors: Defaultcolors that are assigned to lines in each subplot.
  31. """
  32. self.__width = width
  33. self.__height = height
  34. self.__dpi = dpi
  35. self.__lines = lines
  36. self.__cols = cols
  37. self.__facecolor = facecolor
  38. self.__defaultcolors = defaultcolors
  39. self.__subplot_data = tuple()
  40. for i in range(self.subplot_count):
  41. self.__subplot_data += ([], )
  42. self.__last_figure = None
  43. # properties
  44. @property
  45. def subplot_count(self):
  46. return self.__cols * self.__lines
  47. @property
  48. def subplot_data(self):
  49. return self.__subplot_data
  50. @property
  51. def defaultcolors(self):
  52. return self.__defaultcolors
  53. @property
  54. def last_figure(self):
  55. assert self.__last_figure is not None, "No figure available (method plot has to be called at least once)"
  56. return self.__last_figure
  57. # methods
  58. def save(self, name):
  59. """
  60. Saves the last figure to the specified location.
  61. :param name: The name of the figure file
  62. """
  63. self.last_figure.savefig(name)
  64. def add(self, array, subplot=0, color=None, xlim=None, downsample=None, labels=None):
  65. """
  66. Add a new data array to the plot
  67. :param array: The data array to plot
  68. :param subplot: The index of the subplot where the array should be added (starting with 0)
  69. :param color: The color of the array to plot (if None the next default colors will be assigned)
  70. :param xlim: Start and end of the x-axis limits.
  71. :param downsample: True if the array should be sampled down
  72. :param labels: Data array with labels that should be added to each data point of the array to plot
  73. """
  74. color = self.__mk_color(color, subplot)
  75. pdata = PlottingData(array, color, subplot, xlim, downsample, labels)
  76. self.subplot_data[subplot].append(pdata)
  77. def plot(self, width=None, height=None, dpi=None, lines=None, cols=None, facecolor=None):
  78. """
  79. Plots all data arrays added to the plotter.
  80. :param width: Width of the image in pixels
  81. :param height: Height of the image in pixels
  82. :param dpi: DPI of the image (default 90)
  83. :param lines: Number of vertical subplots
  84. :param cols: Number of horizontal subplots
  85. :param facecolor: The background color of the plot
  86. """
  87. # defaults
  88. width = width or self.__width
  89. height = height or self.__height
  90. dpi = dpi or self.__dpi
  91. lines = lines or self.__lines
  92. cols = cols or self.__cols
  93. facecolor = facecolor or self.__facecolor
  94. # plot
  95. figure, axis_all = plot_make_figure(width, height, dpi, cols, lines, facecolor)
  96. for subplot, pdata_list in enumerate(self.subplot_data):
  97. axis = axis_all[subplot]
  98. pdata_list.sort()
  99. event_like = Plotter.__count_event_like(pdata_list)
  100. signal_like = Plotter.__count_signal_like(pdata_list)
  101. for i, pdata in enumerate(pdata_list):
  102. d1type = pdata.array.dimensions[0].dimension_type
  103. shape = pdata.array.shape
  104. nd = len(shape)
  105. if nd == 1:
  106. if d1type == nix.DimensionType.Set:
  107. second_y = signal_like > 0
  108. hint = (i + 1.0) / (event_like + 1.0) if event_like > 0 else None
  109. plot_array_1d_set(pdata.array, axis, color=pdata.color, xlim=pdata.xlim, labels=pdata.labels,
  110. second_y=second_y, hint=hint)
  111. else:
  112. plot_array_1d(pdata.array, axis, color=pdata.color, xlim=pdata.xlim,
  113. downsample=pdata.downsample)
  114. elif nd == 2:
  115. if d1type == nix.DimensionType.Set:
  116. plot_array_2d_set(pdata.array, axis, color=pdata.color, xlim=pdata.xlim,
  117. downsample=pdata.downsample)
  118. else:
  119. plot_array_2d(pdata.array, axis, color=pdata.color, xlim=pdata.xlim,
  120. downsample=pdata.downsample)
  121. else:
  122. raise Exception('Unsupported data')
  123. axis.legend()
  124. self.__last_figure = figure
  125. # private methods
  126. def __mk_color(self, color, subplot):
  127. """
  128. If color is None, select one from the defaults or create a random color.
  129. """
  130. if color is None:
  131. color_count = len(self.defaultcolors)
  132. count = len(self.subplot_data[subplot])
  133. color = self.defaultcolors[count if count < color_count else color_count - 1]
  134. if color == "random":
  135. color = "#%02x%02x%02x" % (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255))
  136. return color
  137. @staticmethod
  138. def __count_signal_like(pdata_list):
  139. sig_types = (nix.DimensionType.Range, nix.DimensionType.Sample)
  140. count = 0
  141. for pdata in pdata_list:
  142. dims = pdata.array.dimensions
  143. nd = len(dims)
  144. if nd == 1 and dims[0].dimension_type in sig_types:
  145. count += 1
  146. elif nd == 2 and dims[0].dimension_type == nix.DimensionType.Set and dims[1].dimension_type in sig_types:
  147. count += 1
  148. return count
  149. @staticmethod
  150. def __count_image_like(pdata_list):
  151. sig_types = (nix.DimensionType.Range, nix.DimensionType.Sample)
  152. count = 0
  153. for pdata in pdata_list:
  154. dims = pdata.array.dimensions
  155. nd = len(dims)
  156. if nd == 2 and dims[0].dimension_type in sig_types and dims[1].dimension_type in sig_types:
  157. count += 1
  158. return count
  159. @staticmethod
  160. def __count_event_like(pdata_list):
  161. count = 0
  162. for pdata in pdata_list:
  163. dims = pdata.array.dimensions
  164. nd = len(dims)
  165. if dims[0].dimension_type == nix.DimensionType.Set:
  166. count += 1
  167. return count
  168. class PlottingData(object):
  169. def __init__(self, array, color, subplot=0, xlim=None, downsample=False, labels=None):
  170. self.array = array
  171. self.dimensions = array.dimensions
  172. self.shape = array.shape
  173. self.rank = len(array.shape)
  174. self.color = color
  175. self.subplot = subplot
  176. self.xlim = xlim
  177. self.downsample = downsample
  178. self.labels = labels
  179. def __cmp__(self, other):
  180. weights = lambda dims: [(1 if d.dimension_type == nix.DimensionType.Sample else 0) for d in dims]
  181. return cmp(weights(self.array.dimensions), weights(other.array.dimensions))
  182. def __lt__(self, other):
  183. return self.__cmp__(other) < 0
  184. def plot_make_figure(width, height, dpi, cols, lines, facecolor):
  185. axis_all = []
  186. figure = plt.figure(facecolor=facecolor, figsize=(width / dpi, height / dpi), dpi=90)
  187. figure.subplots_adjust(wspace=0.3, hspace=0.3, left=0.1, right=0.9, bottom=0.05, top=0.95)
  188. for subplot in range(cols * lines):
  189. axis = figure.add_subplot(lines, cols, subplot+1)
  190. axis.tick_params(direction='out')
  191. axis.spines['top'].set_color('none')
  192. axis.spines['right'].set_color('none')
  193. axis.xaxis.set_ticks_position('bottom')
  194. axis.yaxis.set_ticks_position('left')
  195. axis_all.append(axis)
  196. return figure, axis_all
  197. def plot_array_1d(array, axis, color=None, xlim=None, downsample=None, hint=None, labels=None):
  198. dim = array.dimensions[0]
  199. assert dim.dimension_type in (nix.DimensionType.Sample, nix.DimensionType.Range), "Unsupported data"
  200. y = array[:]
  201. if dim.dimension_type == nix.DimensionType.Sample:
  202. x_start = dim.offset or 0
  203. x = np.arange(0, array.shape[0]) * dim.sampling_interval + x_start
  204. else:
  205. x = np.array(dim.ticks)
  206. if downsample is not None:
  207. x = sp.decimate(x, downsample)
  208. y = sp.decimate(y, downsample)
  209. if xlim is not None:
  210. y = y[(x >= xlim[0]) & (x <= xlim[1])]
  211. x = x[(x >= xlim[0]) & (x <= xlim[1])]
  212. axis.plot(x, y, color, label=array.name)
  213. axis.set_xlabel('%s [%s]' % (dim.label, dim.unit))
  214. axis.set_ylabel('%s [%s]' % (array.label, array.unit))
  215. axis.set_xlim([np.min(x), np.max(x)])
  216. def plot_array_1d_set(array, axis, color=None, xlim=None, hint=None, labels=None, second_y=False):
  217. dim = array.dimensions[0]
  218. assert dim.dimension_type == nix.DimensionType.Set, "Unsupported data"
  219. x = array[:]
  220. z = np.ones_like(x) * 0.8 * (hint or 0.5) + 0.1
  221. if second_y:
  222. ax2 = axis.twinx()
  223. ax2.set_ylim([0, 1])
  224. ax2.scatter(x, z, 50, color, linewidths=2, label=array.name, marker="|")
  225. ax2.set_yticks([])
  226. if labels is not None:
  227. for i, v in enumerate(labels[:]):
  228. ax2.annotate(str(v), (x[i], z[i]))
  229. else:
  230. #x = array[xlim or Ellipsis]
  231. axis.set_ylim([0, 1])
  232. axis.scatter(x, z, 50, color, linewidths=2, label=array.name, marker="|")
  233. axis.set_xlabel('%s [%s]' % (array.label, array.unit))
  234. axis.set_ylabel(array.name)
  235. axis.set_yticks([])
  236. if labels is not None:
  237. for i, v in enumerate(labels[:]):
  238. axis.annotate(str(v), (x[i], z[i]))
  239. def plot_array_2d(array, axis, color=None, xlim=None, downsample=None, hint=None, labels=None):
  240. d1 = array.dimensions[0]
  241. d2 = array.dimensions[1]
  242. d1_type = d1.dimension_type
  243. d2_type = d2.dimension_type
  244. assert d1_type == nix.DimensionType.Sample, "Unsupported data"
  245. assert d2_type == nix.DimensionType.Sample, "Unsupported data"
  246. z = array[:]
  247. x_start = d1.offset or 0
  248. y_start = d2.offset or 0
  249. x_end = x_start + array.shape[0] * d1.sampling_interval
  250. y_end = y_start + array.shape[1] * d2.sampling_interval
  251. axis.imshow(z, origin='lower', extent=[x_start, x_end, y_start, y_end])
  252. axis.set_xlabel('%s [%s]' % (d1.label, d1.unit))
  253. axis.set_ylabel('%s [%s]' % (d2.label, d2.unit))
  254. axis.set_title(array.name)
  255. bar = plt.colorbar()
  256. bar.label('%s [%s]' % (array.label, array.unit))
  257. def plot_array_2d_set(array, axis, color=None, xlim=None, downsample=None, hint=None, labels=None):
  258. d1 = array.dimensions[0]
  259. d2 = array.dimensions[1]
  260. d1_type = d1.dimension_type
  261. d2_type = d2.dimension_type
  262. assert d1_type == nix.DimensionType.Set, "Unsupported data"
  263. assert d2_type == nix.DimensionType.Sample, "Unsupported data"
  264. x_start = d2.offset or 0
  265. x_one = x_start + np.arange(0, array.shape[1]) * d2.sampling_interval
  266. x = np.tile(x_one.reshape(array.shape[1], 1), array.shape[0])
  267. y = array[:]
  268. axis.plot(x, y.T, color=color)
  269. axis.set_title(array.name)
  270. axis.set_xlabel('%s [%s]' % (d2.label, d2.unit))
  271. axis.set_ylabel('%s [%s]' % (array.label, array.unit))
  272. if d1.labels is not None:
  273. axis.legend(d1.labels)