Browse Source

Upload files to ''

Reema Gupta 2 years ago
parent
commit
fe0a537060

BIN
utils/__init__.py


BIN
utils/__pycache__/__init__.cpython-39.pyc


BIN
utils/__pycache__/notebook.cpython-39.pyc


BIN
utils/__pycache__/plotting.cpython-39.pyc


+ 99 - 0
utils/lif.py

@@ -0,0 +1,99 @@
+# -*- coding: utf-8 -*-
+"""
+ Copyright © 2014 German Neuroinformatics Node (G-Node)
+
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted under the terms of the BSD License. See
+ LICENSE file in the root of the Project.
+
+ Author: Jan Grewe <jan.grewe@g-node.org>
+"""
+import numpy as np
+
+class lif:
+    
+    def __init__(self, stepsize=0.0001, offset=1.6, tau_m=0.025, tau_a=0.02, da=0.0, D=3.5):
+        self.stepsize = stepsize # simulation stepsize [s]
+        self.offset = offset # offset curent [nA]
+        self.tau_m = tau_m # membrane time_constant [s]
+        self.tau_a = tau_a # adaptation time_constant [s]
+        self.da = da # increment in adaptation current [nA]
+        self.D = D # noise intensity
+        self.v_threshold = 1.0 # spiking threshold
+        self.v_reset = 0.0 # reset voltage after spiking
+        self.i_a = 0.0 # current adaptation current 
+        self.v = self.v_reset # current membrane voltage
+        self.t = 0.0 # current time [s] 
+        self.membrane_voltage = []
+        self.spike_times = []
+
+
+    def _reset(self):
+        self.i_a = 0.0
+        self.v = self.v_reset
+        self.t = 0.0
+        self.membrane_voltage = []
+        self.spike_times = []
+    
+
+    def _lif(self, stimulus, noise):
+        """
+        euler solution of the membrane equation with adaptation current and noise
+        """
+        self.i_a -= self.i_a - self.stepsize/self.tau_a * (self.i_a)
+        self.v += self.stepsize * ( -self.v + stimulus + noise + self.offset - self.i_a)/self.tau_m; 
+        self.membrane_voltage.append(self.v)
+
+
+    def _next(self, stimulus):
+        """
+        working horse which delegates to the euler and gets the spike times
+        """
+        noise = self.D * (float(np.random.randn() % 10000) - 5000.0)/10000
+        self._lif(stimulus, noise)
+        self.t += self.stepsize
+        if self.v > self.v_threshold and len(self.membrane_voltage) > 1:
+            self.v = self.v_reset
+            self.membrane_voltage[len(self.membrane_voltage)-1] = 2.0
+            self.spike_times.append(self.t)
+            self.i_a += self.da;
+  
+    
+    def run_const_stim(self, steps, stimulus):
+        """
+        lif simulation with constant stimulus.
+        """
+        self._reset()
+        for i in range(steps):
+            self._next(stimulus);
+        time = np.arange(len(self.membrane_voltage))*self.stepsize
+        return time, np.array(self.membrane_voltage), np.array(self.spike_times)
+
+
+    def run_stimulus(self, stimulus):
+        """
+        lif simulation with a predefined stimulus trace.
+        """
+        self._reset()
+        for s in stimulus:
+            self._next(s);
+        time = np.arange(len(self.membrane_voltage))*self.stepsize
+        return time, np.array(self.membrane_voltage), np.array(self.spike_times)
+
+
+    def __str__(self):
+        out = '\n'.join(["stepsize: \t" + str(self.stepsize),
+                         "offset:\t\t" + str(self.offset),
+                         "tau_m:\t\t" + str(self.tau_m),
+                         "tau_a:\t\t" + str(self.tau_a),
+                         "da:\t\t" + str(self.da),
+                         "D:\t\t" + str(self.D),
+                         "v_threshold:\t" + str(self.v_threshold),
+                         "v_reset:\t" + str(self.v_reset)])
+        return out
+
+
+    def __repr__(self):
+        return self.__str__()

+ 44 - 0
utils/notebook.py

@@ -0,0 +1,44 @@
+__author__ = 'andrey'
+
+
+def print_stats(items):
+    if items is None or len(items) < 1:
+        return
+
+    print("\n%-50s (%02d)" % (items[0].__class__.__name__ + "s", len(items)))
+    for t in set(i.type for i in items):
+        print("\ttype: %-35s  (%02d)" % (t, len([1 for i in items if i.type == t])))
+
+
+def print_metadata_table(section):
+    import matplotlib.pyplot as plt
+    columns = ['Name', 'Value', 'Unit']
+    cell_text = []
+    for p in [(i.name, i) for i in section.props]:
+        
+        for i, v in enumerate(p[1].values):
+            value = str(v.value)
+            if len(value) > 30:
+                value = value[:30] + '...'
+            if i == 0:
+                row_data = [p[0], value, p[1].unit if p[1].unit else '-']
+            else:
+                row_data = [p[0], value, p[1].unit if p[1].unit else '-']
+
+            cell_text.append(row_data)
+    if len(cell_text) > 0:
+        nrows, ncols = len(cell_text)+1, len(columns)
+        hcell, wcell = 1., 5.
+        hpad, wpad = 0.5, 0    
+        fig = plt.figure(figsize=(ncols*wcell+wpad, nrows*hcell+hpad))
+        ax = fig.add_subplot(111)
+        ax.axis('off')
+        the_table = ax.table(cellText=cell_text,
+                               colLabels=columns, 
+                               loc='center')
+        for cell in the_table.get_children():
+            cell.set_height(.075)
+            cell.set_fontsize(12)
+                        
+    #ax.set_title(section.name, fontsize=12)
+    return fig

+ 357 - 0
utils/plotting.py

@@ -0,0 +1,357 @@
+# !/usr/bin/env python
+#  -*- coding: utf-8 -*-
+from __future__ import print_function, division
+
+import numpy as np
+import scipy.signal as sp
+import random
+
+import nixio as nix
+import matplotlib.pyplot as plt
+
+COLORS_BLUE_AND_RED = (
+    'dodgerblue', 'red'
+)
+
+COLORS_BLUE_GRADIENT = (
+    "#034980", "#055DA1", "#1B70E0", "#3786ED", "#4A95F7",
+    "#0C3663", "#1B4775", "#205082", "#33608F", "#51779E",
+    "#23B0DB", "#29CDFF", "#57D8FF", "#8FE5FF"
+)
+
+
+class Plotter(object):
+    """
+    Plotter class for nix data arrays.
+    """
+
+    def __init__(self, width=800, height=600, dpi=90, lines=1, cols=1, facecolor="white",
+                 defaultcolors=COLORS_BLUE_GRADIENT):
+        """
+
+
+        :param width:       Width of the image in pixels
+        :param height:      Height of the image in pixels
+        :param dpi:         DPI of the image (default 90)
+        :param lines:       Number of vertical subplots
+        :param cols:        Number of horizontal subplots
+        :param facecolor:   The background color of the plot
+        :param defaultcolors: Defaultcolors that are assigned to lines in each subplot.
+        """
+
+        self.__width = width
+        self.__height = height
+        self.__dpi = dpi
+        self.__lines = lines
+        self.__cols = cols
+        self.__facecolor = facecolor
+        self.__defaultcolors = defaultcolors
+
+        self.__subplot_data = tuple()
+        for i in range(self.subplot_count):
+            self.__subplot_data += ([], )
+
+        self.__last_figure = None
+
+    # properties
+
+    @property
+    def subplot_count(self):
+        return self.__cols * self.__lines
+
+    @property
+    def subplot_data(self):
+        return self.__subplot_data
+
+    @property
+    def defaultcolors(self):
+        return self.__defaultcolors
+
+    @property
+    def last_figure(self):
+        assert self.__last_figure is not None, "No figure available (method plot has to be called at least once)"
+        return self.__last_figure
+
+    # methods
+
+    def save(self, name):
+        """
+        Saves the last figure to the specified location.
+
+        :param name:    The name of the figure file
+        """
+        self.last_figure.savefig(name)
+
+    def add(self, array, subplot=0, color=None, xlim=None, downsample=None, labels=None):
+        """
+        Add a new data array to the plot
+
+        :param array:       The data array to plot
+        :param subplot:     The index of the subplot where the array should be added (starting with 0)
+        :param color:       The color of the array to plot (if None the next default colors will be assigned)
+        :param xlim:        Start and end of the x-axis limits.
+        :param downsample:  True if the array should be sampled down
+        :param labels:      Data array with labels that should be added to each data point of the array to plot
+        """
+        color = self.__mk_color(color, subplot)
+        pdata = PlottingData(array, color, subplot, xlim, downsample, labels)
+        self.subplot_data[subplot].append(pdata)
+
+    def plot(self, width=None, height=None, dpi=None, lines=None, cols=None, facecolor=None):
+        """
+        Plots all data arrays added to the plotter.
+
+        :param width:       Width of the image in pixels
+        :param height:      Height of the image in pixels
+        :param dpi:         DPI of the image (default 90)
+        :param lines:       Number of vertical subplots
+        :param cols:        Number of horizontal subplots
+        :param facecolor:   The background color of the plot
+        """
+        # defaults
+        width = width or self.__width
+        height = height or self.__height
+        dpi = dpi or self.__dpi
+        lines = lines or self.__lines
+        cols = cols or self.__cols
+        facecolor = facecolor or self.__facecolor
+
+        # plot
+        figure, axis_all = plot_make_figure(width, height, dpi, cols, lines, facecolor)
+
+        for subplot, pdata_list in enumerate(self.subplot_data):
+            axis = axis_all[subplot]
+            pdata_list.sort()
+
+            event_like = Plotter.__count_event_like(pdata_list)
+            signal_like = Plotter.__count_signal_like(pdata_list)
+
+            for i, pdata in enumerate(pdata_list):
+                d1type = pdata.array.dimensions[0].dimension_type
+                shape = pdata.array.shape
+                nd = len(shape)
+
+                if nd == 1:
+                    if d1type == nix.DimensionType.Set:
+                        second_y = signal_like > 0
+                        hint = (i + 1.0) / (event_like + 1.0) if event_like > 0 else None
+                        plot_array_1d_set(pdata.array, axis, color=pdata.color, xlim=pdata.xlim, labels=pdata.labels,
+                                          second_y=second_y, hint=hint)
+                    else:
+                        plot_array_1d(pdata.array, axis, color=pdata.color, xlim=pdata.xlim,
+                                      downsample=pdata.downsample)
+                elif nd == 2:
+                    if d1type == nix.DimensionType.Set:
+                        plot_array_2d_set(pdata.array, axis, color=pdata.color, xlim=pdata.xlim,
+                                          downsample=pdata.downsample)
+                    else:
+                        plot_array_2d(pdata.array, axis, color=pdata.color, xlim=pdata.xlim,
+                                      downsample=pdata.downsample)
+                else:
+                    raise Exception('Unsupported data')
+
+            axis.legend()
+
+        self.__last_figure = figure
+
+    # private methods
+
+    def __mk_color(self, color, subplot):
+        """
+        If color is None, select one from the defaults or create a random color.
+        """
+        if color is None:
+            color_count = len(self.defaultcolors)
+            count = len(self.subplot_data[subplot])
+            color = self.defaultcolors[count if count < color_count else color_count - 1]
+
+        if color == "random":
+            color = "#%02x%02x%02x" % (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255))
+
+        return color
+
+    @staticmethod
+    def __count_signal_like(pdata_list):
+        sig_types = (nix.DimensionType.Range, nix.DimensionType.Sample)
+        count = 0
+
+        for pdata in pdata_list:
+            dims = pdata.array.dimensions
+            nd = len(dims)
+
+            if nd == 1 and dims[0].dimension_type in sig_types:
+                count += 1
+            elif nd == 2 and dims[0].dimension_type == nix.DimensionType.Set and dims[1].dimension_type in sig_types:
+                count += 1
+
+        return count
+
+    @staticmethod
+    def __count_image_like(pdata_list):
+        sig_types = (nix.DimensionType.Range, nix.DimensionType.Sample)
+        count = 0
+
+        for pdata in pdata_list:
+            dims = pdata.array.dimensions
+            nd = len(dims)
+
+            if nd == 2 and dims[0].dimension_type in sig_types and dims[1].dimension_type in sig_types:
+                count += 1
+
+        return count
+
+    @staticmethod
+    def __count_event_like(pdata_list):
+        count = 0
+
+        for pdata in pdata_list:
+            dims = pdata.array.dimensions
+            nd = len(dims)
+
+            if dims[0].dimension_type == nix.DimensionType.Set:
+                count += 1
+
+        return count
+
+
+class PlottingData(object):
+
+    def __init__(self, array, color, subplot=0, xlim=None, downsample=False, labels=None):
+        self.array = array
+        self.dimensions = array.dimensions
+        self.shape = array.shape
+        self.rank = len(array.shape)
+        self.color = color
+        self.subplot = subplot
+        self.xlim = xlim
+        self.downsample = downsample
+        self.labels = labels
+
+    def __cmp__(self, other):
+        weights = lambda dims: [(1 if d.dimension_type == nix.DimensionType.Sample else 0) for d in dims]
+        return cmp(weights(self.array.dimensions), weights(other.array.dimensions))
+
+    def __lt__(self, other):
+        return self.__cmp__(other) < 0
+
+
+def plot_make_figure(width, height, dpi, cols, lines, facecolor):
+    axis_all = []
+    figure = plt.figure(facecolor=facecolor, figsize=(width / dpi, height / dpi), dpi=90)
+    figure.subplots_adjust(wspace=0.3, hspace=0.3, left=0.1, right=0.9, bottom=0.05, top=0.95)
+
+    for subplot in range(cols * lines):
+
+        axis = figure.add_subplot(lines, cols, subplot+1)
+        axis.tick_params(direction='out')
+        axis.spines['top'].set_color('none')
+        axis.spines['right'].set_color('none')
+        axis.xaxis.set_ticks_position('bottom')
+        axis.yaxis.set_ticks_position('left')
+
+        axis_all.append(axis)
+
+    return figure, axis_all
+
+
+def plot_array_1d(array, axis, color=None, xlim=None, downsample=None, hint=None, labels=None):
+    dim = array.dimensions[0]
+
+    assert dim.dimension_type in (nix.DimensionType.Sample, nix.DimensionType.Range), "Unsupported data"
+
+    y = array[:]
+    if dim.dimension_type == nix.DimensionType.Sample:
+        x_start = dim.offset or 0
+        x = np.arange(0, array.shape[0]) * dim.sampling_interval + x_start
+    else:
+        x = np.array(dim.ticks)
+    
+    if downsample is not None:
+        x = sp.decimate(x, downsample)
+        y = sp.decimate(y, downsample)
+    if xlim is not None:
+        y = y[(x >= xlim[0]) & (x <= xlim[1])]
+        x = x[(x >= xlim[0]) & (x <= xlim[1])]
+       
+    axis.plot(x, y, color, label=array.name)
+    axis.set_xlabel('%s [%s]' % (dim.label, dim.unit))
+    axis.set_ylabel('%s [%s]' % (array.label, array.unit))
+    axis.set_xlim([np.min(x), np.max(x)])
+
+
+def plot_array_1d_set(array, axis, color=None, xlim=None, hint=None, labels=None, second_y=False):
+    dim = array.dimensions[0]
+
+    assert dim.dimension_type == nix.DimensionType.Set, "Unsupported data"
+
+    x = array[:]
+    z = np.ones_like(x) * 0.8 * (hint or 0.5) + 0.1
+    if second_y:
+        ax2 = axis.twinx()
+        ax2.set_ylim([0, 1])
+        ax2.scatter(x, z, 50, color, linewidths=2, label=array.name, marker="|")
+        ax2.set_yticks([])
+
+        if labels is not None:
+            for i, v in enumerate(labels[:]):
+                ax2.annotate(str(v), (x[i], z[i]))
+
+    else:
+        #x = array[xlim or Ellipsis]
+        axis.set_ylim([0, 1])
+        axis.scatter(x, z, 50, color, linewidths=2, label=array.name, marker="|")
+        axis.set_xlabel('%s [%s]' % (array.label, array.unit))
+        axis.set_ylabel(array.name)
+        axis.set_yticks([])
+
+        if labels is not None:
+            for i, v in enumerate(labels[:]):
+                axis.annotate(str(v), (x[i], z[i]))
+
+
+def plot_array_2d(array, axis, color=None, xlim=None, downsample=None, hint=None, labels=None):
+    d1 = array.dimensions[0]
+    d2 = array.dimensions[1]
+
+    d1_type = d1.dimension_type
+    d2_type = d2.dimension_type
+
+    assert d1_type == nix.DimensionType.Sample, "Unsupported data"
+    assert d2_type == nix.DimensionType.Sample, "Unsupported data"
+
+    z = array[:]
+    x_start = d1.offset or 0
+    y_start = d2.offset or 0
+    x_end = x_start + array.shape[0] * d1.sampling_interval
+    y_end = y_start + array.shape[1] * d2.sampling_interval
+
+    axis.imshow(z, origin='lower', extent=[x_start, x_end, y_start, y_end])
+    axis.set_xlabel('%s [%s]' % (d1.label, d1.unit))
+    axis.set_ylabel('%s [%s]' % (d2.label, d2.unit))
+    axis.set_title(array.name)
+    bar = plt.colorbar()
+    bar.label('%s [%s]' % (array.label, array.unit))
+
+
+def plot_array_2d_set(array, axis, color=None, xlim=None, downsample=None, hint=None, labels=None):
+    d1 = array.dimensions[0]
+    d2 = array.dimensions[1]
+
+    d1_type = d1.dimension_type
+    d2_type = d2.dimension_type
+
+    assert d1_type == nix.DimensionType.Set, "Unsupported data"
+    assert d2_type == nix.DimensionType.Sample, "Unsupported data"
+
+    x_start = d2.offset or 0
+    x_one = x_start + np.arange(0, array.shape[1]) * d2.sampling_interval
+    x = np.tile(x_one.reshape(array.shape[1], 1), array.shape[0])
+    y = array[:]
+    axis.plot(x, y.T, color=color)
+    axis.set_title(array.name)
+    axis.set_xlabel('%s [%s]' % (d2.label, d2.unit))
+    axis.set_ylabel('%s [%s]' % (array.label, array.unit))
+
+    if d1.labels is not None:
+        axis.legend(d1.labels)
+

+ 98 - 0
utils/video_player.py

@@ -0,0 +1,98 @@
+# !/usr/bin/env python
+#  -*- coding: utf-8 -*-
+from __future__ import print_function, division
+
+import nixio as nix
+import math
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+import matplotlib
+matplotlib.use('TkAgg')
+
+class Playback(object):
+    
+    def __init__(self, fig, video_array, tracking_tag=None, show_orientation=False):
+        self.figure = fig
+        self.axis = fig.add_subplot(111)
+        self.im = None
+
+        self.data = video_array
+        self.height, self.width, self.channels, self.nframes = self.data.shape
+        dim = video_array.dimensions[-1]
+        ticks = dim.ticks
+        self.interval = np.mean(np.diff(ticks))
+        
+        self.tag = tracking_tag
+        if self.tag is not None:
+            self.positions = self.tag.positions
+            self.orientations = self.tag.features[0].data
+            self.tracked_indices = self.__track_indices(ticks, self.positions[:,3])
+            self.x = self.positions[:,0]
+            self.y = self.positions[:,1]
+            self.track_counter = 0
+            self.draw_orientation = show_orientation
+    
+    def __track_indices(self, ticks, times):
+        indices = np.zeros_like(times)
+        for i,t in enumerate(times):
+            indices[i] = np.argmin(np.abs(np.asarray(ticks) - t*1000))   
+        return indices
+
+    def __draw_circ(self, frame, x_pos, y_pos):
+        radius = 8
+        y, x = np.ogrid[-radius: radius, -radius: radius]
+        index = x**2 + y**2 <= radius**2
+        frame[y_pos-radius:y_pos+radius, x_pos-radius:x_pos+radius, 0][index] = 255
+        frame[y_pos-radius:y_pos+radius, x_pos-radius:x_pos+radius, 1][index] = 0
+        frame[y_pos-radius:y_pos+radius, x_pos-radius:x_pos+radius, 2][index] = 0
+        return frame
+     
+    def __draw_line(self, frame, x, y, phi):
+        length = 20
+        dx = math.sin(phi/360.*2*np.pi) * length
+        dy = math.cos(phi/360.*2*np.pi) * length
+        cv2.line(frame, (int(x-dx),int(y+dy)),
+                 (int(x+dx), int(y-dy)), (250,255,0), 2)
+        return frame
+
+    def grab_frame(self, i):
+        frame = self.data[:,:,:,i]
+        if self.tag is not None:
+            if i in self.tracked_indices:
+                frame = self.__draw_circ(frame, self.x[self.track_counter], 
+                                         self.y[self.track_counter]) 
+                if self.draw_orientation: 
+                    frame = self.__draw_line(frame,self.x[self.track_counter], 
+                                             self.y[self.track_counter],
+                                             self.orientations[self.track_counter])
+                self.track_counter += 1
+        if self.im == None:
+            im = self.axis.imshow(frame)
+        else:
+            im.set_data(frame)
+        return im, 
+
+    def start(self):
+        ani = animation.FuncAnimation(self.figure, self.grab_frame,
+                                      range(1,self.nframes,1), interval=self.interval, 
+                                      repeat=False, blit=True)
+        plt.show()
+
+
+if __name__ == '__main__':
+    import nix
+    import numpy as np
+    import matplotlib
+    import matplotlib.pyplot as plt
+        
+    nix_file = nix.File.open('../data/tracking_data.h5', nix.FileMode.ReadOnly)
+    b = nix_file.blocks[0]
+    video = [a for a in b.data_arrays if a.name == "video"][0]
+    tag = [t for t in b.multi_tags if t.name == "tracking"][0]
+    
+    fig = plt.figure(facecolor='white')
+    pb = Playback(fig, video, tracking_tag=tag, show_orientation=True)
+    pb.start()
+    nix_file.close()