|
@@ -0,0 +1,204 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 5,
|
|
|
+ "id": "0642d602",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import sys, os\n",
|
|
|
+ "sys.path.append(os.path.join(os.getcwd(), '..'))\n",
|
|
|
+ "sys.path.append(os.path.join(os.getcwd(), '..', '..'))\n",
|
|
|
+ "sys.path.append(os.path.join(os.getcwd(), '..', '..'))\n",
|
|
|
+ "sys.path.append(os.path.join(os.getcwd(), '..', '..', 'analysis'))\n",
|
|
|
+ "sys.path.append(os.path.join(os.getcwd(), '..', '..', 'session'))\n",
|
|
|
+ "\n",
|
|
|
+ "import numpy as np\n",
|
|
|
+ "from imports import *\n",
|
|
|
+ "from matplotlib.patches import ConnectionPatch\n",
|
|
|
+ "from scipy.stats import pearsonr"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 6,
|
|
|
+ "id": "3c41fad7",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def get_spike_counts(spk_times, pulse_times, hw=0.25, bin_count=51):\n",
|
|
|
+ " collected = []\n",
|
|
|
+ " for t_pulse in pulse_times:\n",
|
|
|
+ " selected = spk_times[(spk_times > t_pulse - hw) & (spk_times < t_pulse + hw)]\n",
|
|
|
+ " collected += [x for x in selected - t_pulse]\n",
|
|
|
+ " collected = np.array(collected)\n",
|
|
|
+ "\n",
|
|
|
+ " bins = np.linspace(-hw, hw, bin_count)\n",
|
|
|
+ " counts, _ = np.histogram(collected, bins=bins)\n",
|
|
|
+ " counts = (counts / len(pulse_times))# * 1/((2. * hw)/float(bin_count - 1))\n",
|
|
|
+ " \n",
|
|
|
+ " return bins, counts"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 9,
|
|
|
+ "id": "1027884b",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def plot_tgt_bgr_psth(example_units):\n",
|
|
|
+ " unit_count = np.array([len(vals) for vals in example_units.values()]).sum()\n",
|
|
|
+ " rows = int(np.ceil(unit_count/3))\n",
|
|
|
+ " fig = plt.figure(figsize=(15, rows*4))\n",
|
|
|
+ " count = 0\n",
|
|
|
+ "\n",
|
|
|
+ " for session, unit_ids in example_units.items():\n",
|
|
|
+ " # read AEP events\n",
|
|
|
+ " animal = session.split('_')[0]\n",
|
|
|
+ " aeps_file = os.path.join(source, animal, session, 'AEPs.h5')\n",
|
|
|
+ " with h5py.File(aeps_file, 'r') as f:\n",
|
|
|
+ " aeps_events = np.array(f['aeps_events'])\n",
|
|
|
+ "\n",
|
|
|
+ " # read single units\n",
|
|
|
+ " spike_times = {}\n",
|
|
|
+ " h5_file = os.path.join(source, animal, session, session + '.h5')\n",
|
|
|
+ " with h5py.File(h5_file, 'r') as f:\n",
|
|
|
+ " cfg = json.loads(f['processed'].attrs['parameters'])\n",
|
|
|
+ " for unit_id in unit_ids:\n",
|
|
|
+ " spike_times[unit_id] = np.array(f['units'][unit_id][H5NAMES.spike_times['name']])\n",
|
|
|
+ "\n",
|
|
|
+ " for unit_id in unit_ids:\n",
|
|
|
+ " bins, counts_bgr = get_spike_counts(spike_times[unit_id], aeps_events[aeps_events[:, 1] == 1][:, 0])\n",
|
|
|
+ " bins, counts_tgt = get_spike_counts(spike_times[unit_id], aeps_events[aeps_events[:, 1] == 2][:, 0])\n",
|
|
|
+ "\n",
|
|
|
+ " ax = fig.add_subplot(rows, 3, count+1)\n",
|
|
|
+ " tgt_dur, bgr_dur = cfg['sound']['sounds']['target']['duration'], cfg['sound']['sounds']['background']['duration']\n",
|
|
|
+ " label_tgt = \"Tgt: %.2f\" % tgt_dur\n",
|
|
|
+ " label_bgr = \"Bgr: %.2f\" % bgr_dur\n",
|
|
|
+ " ax.hist(bins[:-1], bins=bins, weights=counts_tgt, edgecolor='black', color='tab:orange', alpha=0.9, label=label_tgt)\n",
|
|
|
+ " ax.hist(bins[:-1], bins=bins, weights=counts_bgr, edgecolor='black', color='black', alpha=0.5, label=label_bgr)\n",
|
|
|
+ " ax.axvline(0, color='black')\n",
|
|
|
+ " ax.axvline(tgt_dur, color='tab:orange', ls='--')\n",
|
|
|
+ " ax.axvline(tgt_dur - 0.25, color='tab:orange', ls='--')\n",
|
|
|
+ " ax.axvline(bgr_dur, color='black', ls='--', alpha=0.5)\n",
|
|
|
+ " ax.axvline(bgr_dur - 0.25, color='black', ls='--', alpha=0.5)\n",
|
|
|
+ " #ax.set_xlabel('Pulse onset, s', fontsize=14)\n",
|
|
|
+ " ax.axvspan(0, 0.05, alpha=0.3, color='gray')\n",
|
|
|
+ " ax.set_title(\"%s : %s\" % (session[21:], unit_id), fontsize=14)\n",
|
|
|
+ " ax.legend(loc='upper right', prop={'size': 10})\n",
|
|
|
+ " if count % 3 == 0:\n",
|
|
|
+ " ax.set_ylabel(\"Firing Rate, Hz\", fontsize=14)\n",
|
|
|
+ " count += 1\n",
|
|
|
+ " \n",
|
|
|
+ " return fig"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 10,
|
|
|
+ "id": "d97b4e42",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def plot_psth_by_metric(area, m_name, example_units):\n",
|
|
|
+ " unit_count = np.array([len(vals) for vals in example_units.values()]).sum()\n",
|
|
|
+ " rows = int(np.ceil(unit_count/3))\n",
|
|
|
+ " fig = plt.figure(figsize=(15, rows*4))\n",
|
|
|
+ " count = 0\n",
|
|
|
+ "\n",
|
|
|
+ " for session, unit_ids in example_units.items():\n",
|
|
|
+ " # read AEP events\n",
|
|
|
+ " animal = session.split('_')[0]\n",
|
|
|
+ " aeps_file = os.path.join(source, animal, session, 'AEPs.h5')\n",
|
|
|
+ " with h5py.File(aeps_file, 'r') as f:\n",
|
|
|
+ " aeps_events = np.array(f['aeps_events'])\n",
|
|
|
+ " aeps = np.array(f[area]['aeps'])\n",
|
|
|
+ "\n",
|
|
|
+ " # TODO find better way. Remove outliers\n",
|
|
|
+ " aeps[aeps > 5000] = 5000\n",
|
|
|
+ " aeps[aeps < -5000] = -5000\n",
|
|
|
+ "\n",
|
|
|
+ " # read single units\n",
|
|
|
+ " spike_times = {}\n",
|
|
|
+ " h5_file = os.path.join(source, animal, session, session + '.h5')\n",
|
|
|
+ " with h5py.File(h5_file, 'r') as f:\n",
|
|
|
+ " for unit_id in unit_ids:\n",
|
|
|
+ " spike_times[unit_id] = np.array(f['units'][unit_id][H5NAMES.spike_times['name']])\n",
|
|
|
+ "\n",
|
|
|
+ " # load metrics\n",
|
|
|
+ " AEP_metrics_lims = {}\n",
|
|
|
+ " AEP_metrics_raw = {}\n",
|
|
|
+ " AEP_metrics_norm = {}\n",
|
|
|
+ " with h5py.File(aeps_file, 'r') as f:\n",
|
|
|
+ " grp = f[area]\n",
|
|
|
+ " for metric_name in grp['raw']:\n",
|
|
|
+ " AEP_metrics_raw[metric_name] = np.array(grp['raw'][metric_name])\n",
|
|
|
+ " AEP_metrics_norm[metric_name] = np.array(grp['norm'][metric_name])\n",
|
|
|
+ " AEP_metrics_lims[metric_name] = [int(x) for x in grp['raw'][metric_name].attrs['limits'].split(',')]\n",
|
|
|
+ "\n",
|
|
|
+ " # separate high / low AEP metric states\n",
|
|
|
+ " predictor = AEP_metrics_norm[m_name]\n",
|
|
|
+ " low_state_idxs = np.where(predictor < predictor.mean())[0]\n",
|
|
|
+ " high_state_idxs = np.where(predictor > predictor.mean())[0]\n",
|
|
|
+ " aeps_low_mean = aeps[low_state_idxs].mean(axis=0)\n",
|
|
|
+ " aeps_high_mean = aeps[high_state_idxs].mean(axis=0)\n",
|
|
|
+ "\n",
|
|
|
+ " for unit_id in unit_ids:\n",
|
|
|
+ " bins, counts_low = get_spike_counts(spike_times[unit_id], aeps_events[low_state_idxs][:, 0])\n",
|
|
|
+ " bins, counts_high = get_spike_counts(spike_times[unit_id], aeps_events[high_state_idxs][:, 0])\n",
|
|
|
+ "\n",
|
|
|
+ " vals_max = np.array([counts_high.max(), counts_low.max()]).max()\n",
|
|
|
+ " aep_low_profile = (1/10) * vals_max * (aeps_low_mean/500)\n",
|
|
|
+ " aep_high_profile = (1/10) * vals_max * (aeps_high_mean/500)\n",
|
|
|
+ "\n",
|
|
|
+ " ax = fig.add_subplot(rows, 3, count+1)\n",
|
|
|
+ " ax.hist(bins[:-1], bins=bins, weights=counts_high, edgecolor='black', color='red', alpha=0.8, label='%s >' % m_name)\n",
|
|
|
+ " ax.hist(bins[:-1], bins=bins, weights=counts_low, edgecolor='black', color='black', alpha=0.5, label='%s <' % m_name)\n",
|
|
|
+ " for x_l, x_r in [(-0.25, -0.051), (0.0, 0.199)]:\n",
|
|
|
+ " ax.plot(np.linspace(x_l, x_r, len(aeps_low_mean)), aep_high_profile, color='red', lw=2)\n",
|
|
|
+ " ax.plot(np.linspace(x_l, x_r, len(aeps_high_mean)), aep_low_profile, color='black', lw=2)\n",
|
|
|
+ " ax.axvline(0, color='black', ls='--')\n",
|
|
|
+ " #ax.set_xlabel('Pulse onset, s', fontsize=14)\n",
|
|
|
+ " ax.axvspan(0, 0.05, alpha=0.3, color='gray')\n",
|
|
|
+ " ax.set_title(\"%s : %s\" % (session[21:], unit_id), fontsize=14)\n",
|
|
|
+ " ax.legend(loc='upper right', prop={'size': 10})\n",
|
|
|
+ " if count % 3 == 0:\n",
|
|
|
+ " ax.set_ylabel(\"Firing Rate, Hz\", fontsize=14)\n",
|
|
|
+ " count += 1\n",
|
|
|
+ "\n",
|
|
|
+ " return fig"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "04c3c93b",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": []
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3 (ipykernel)",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.8.10"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 5
|
|
|
+}
|