{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.8"},"colab":{"name":"classifier_for2_likelihood_interpolation_included-Copy1.ipynb","provenance":[],"collapsed_sections":[]},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"id":"dtPGLizdwlkz","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630016461031,"user_tz":-120,"elapsed":3754,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"ebe5fb2f-7436-4afe-cd9b-9e2343fb3515"},"source":["import os\n","#os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n","#os.environ[\"CUDA_VISIBLE_DEVICES\"]= \"7\"\n","import numpy as np\n","import glob\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import random \n","from collections import Counter\n","from sklearn.model_selection import train_test_split\n","from imblearn.over_sampling import RandomOverSampler\n","\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import torch.nn.functional as F\n","from torch.utils import data\n","from torch.utils.data import DataLoader\n","from torch.utils.data import Dataset\n","device = 'cuda' \n","from sklearn.utils import shuffle\n","from sklearn.preprocessing import StandardScaler\n","from scipy.interpolate import interp1d\n","from IPython.core.interactiveshell import InteractiveShell\n","InteractiveShell.ast_node_interactivity = \"all\" # allows for multiple outputs per cell to be shown in notebook\n","%matplotlib inline\n","!pip install captum"],"execution_count":32,"outputs":[{"output_type":"stream","text":["Requirement already satisfied: captum in /usr/local/lib/python3.7/dist-packages (0.4.0)\n","Requirement already 
satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from captum) (1.19.5)\n","Requirement already satisfied: torch>=1.2 in /usr/local/lib/python3.7/dist-packages (from captum) (1.9.0+cu102)\n","Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from captum) (3.2.2)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.2->captum) (3.7.4.3)\n","Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->captum) (0.10.0)\n","Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->captum) (2.8.2)\n","Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->captum) (1.3.1)\n","Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->captum) (2.4.7)\n","Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from cycler>=0.10->matplotlib->captum) (1.15.0)\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"zorGTnC7GuWM","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630016461032,"user_tz":-120,"elapsed":16,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"f41a8a6a-79ab-446b-e3c2-763158bc0f6f"},"source":["from google.colab import drive\n","drive.mount('/content/drive')"],"execution_count":33,"outputs":[{"output_type":"stream","text":["Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"pUg_ghYjwlk9"},"source":["### 
Hyperparameters"]},{"cell_type":"code","metadata":{"id":"0zU4GDx9wlk9","executionInfo":{"status":"ok","timestamp":1630016461033,"user_tz":-120,"elapsed":10,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":["n_classes = 2\n","lr = 1e-3\n","epochs = 100\n","batch_size = 128\n","episode_len = 100\n","episode_overlap = 80"],"execution_count":34,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XlcY8N1Jwlk9"},"source":["### Data Setup"]},{"cell_type":"code","metadata":{"id":"0uU673yuH_TV","executionInfo":{"status":"ok","timestamp":1630016461034,"user_tz":-120,"elapsed":11,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":["videos_fps_file = \"/content/drive/MyDrive/KinematicAnalyses/DLCfiles_prep/fps_file.csv\"\n","dir1 = \"/content/drive/MyDrive/KinematicAnalyses/DLCfiles_prep/MOPstroke_P3\"\n","dir2 = \"/content/drive/MyDrive/KinematicAnalyses/DLCfiles_prep/MOPMOSsham_P3\"\n","out_dir = \"/content/drive/MyDrive/KinematicAnalyses/Classifier_res/\"\n","labels_plot = [\"MOp stroke P3\", \"MOp_MOs sham P3\"]"],"execution_count":35,"outputs":[]},{"cell_type":"code","metadata":{"id":"v9hUFHh5wlk-","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630016461034,"user_tz":-120,"elapsed":10,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"5e9a1628-19bd-40a3-cbb5-a474c9977217"},"source":["# make dict with fps for every file {file_name: fps}\n","\n","fps_dict = dict()\n","with open(videos_fps_file, \"r\") as fps_file:\n"," for num, line in enumerate(fps_file):\n"," if num > 1:\n"," line = line.split(\",\")\n"," key = line[1].replace(\"_cropped\", \"\").strip(\"\\\"\")\n"," value = line[2].strip(\"\\\"\")\n"," fps_dict[key] = float(value)\n","\n","# find the highest frame rate in all files that are to compare\n","all_files = glob.glob(dir1 + 
\"/*_preprocessed_crop.csv\")\n","all_files.extend(glob.glob(dir2 + \"/*_preprocessed_crop.csv\"))\n","print(len(all_files))\n","max_fps = 0\n","for file in all_files:\n"," file = file.split(\"/\")[-1].split(\"_preprocessed\")[0]\n"," if fps_dict[file] > max_fps:\n"," max_fps = fps_dict[file] \n"," \n","# load file and interpolate the number of data points to adjust the frame rate \n","def load_and_interpolate(file_name, fps):\n"," df = pd.read_csv(file_name)\n"," new_df = pd.DataFrame()\n"," if fps < max_fps:\n"," for col in df:\n"," if col.split(\"_\")[0] != \"beam\":\n"," f = interp1d(range(len(df)), df[col], kind='cubic')\n"," index_new = np.linspace(0, len(df)-1, num=round(max_fps*(len(df[col])/fps)), endpoint=True)\n"," new_df[col] = f(index_new)\n"," else:\n"," for col in df:\n"," if col.split(\"_\")[0] != \"beam\":\n"," new_df[col]=df[col]\n"," return(new_df)"],"execution_count":36,"outputs":[{"output_type":"stream","text":["185\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"yqd3VYHVwlk-"},"source":["#### Read group 1"]},{"cell_type":"code","metadata":{"scrolled":true,"id":"FWVY1rS3wlk-","executionInfo":{"status":"ok","timestamp":1630016467188,"user_tz":-120,"elapsed":6161,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":["file_directory = dir1\n","file_names = glob.glob(file_directory+\"/*preprocessed_crop.csv\")\n","len_MOPstroke_P3 = []\n","MOPstroke_P3 = []\n","for num, file_name in enumerate(file_names):\n"," file = file_name.split(\"/\")[-1].split(\"_preprocessed\")[0]\n"," df = load_and_interpolate(file_name, fps_dict[file])\n"," len_MOPstroke_P3.append(len(df))\n"," MOPstroke_P3.append(df.values)\n"],"execution_count":37,"outputs":[]},{"cell_type":"code","metadata":{"id":"SSkeAwrHwlk-","colab":{"base_uri":"https://localhost:8080/","height":384},"executionInfo":{"status":"ok","timestamp":1630016467195,"user_tz":-120,"elapsed":27,"user":{"displayName":"Franziska 
Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"585650e6-c21f-4a22-abb5-4394aaf342d8"},"source":["print(\"# videos in MOP Group:\", len(len_MOPstroke_P3))\n","print(\"shortest video in MOP Group is of length\", min(len_MOPstroke_P3))\n","print(\"longest video in MOP Group is of length\", max(len_MOPstroke_P3))\n","plt.hist(len_MOPstroke_P3)\n","plt.show()"],"execution_count":38,"outputs":[{"output_type":"stream","text":["# videos in MOP Group: 102\n","shortest video in MOP Group is of length 172\n","longest video in MOP Group is of length 2224\n"],"name":"stdout"},{"output_type":"execute_result","data":{"text/plain":["(array([64., 25., 6., 3., 2., 0., 0., 0., 1., 1.]),\n"," array([ 172. , 377.2, 582.4, 787.6, 992.8, 1198. , 1403.2, 1608.4,\n"," 1813.6, 2018.8, 2224. ]),\n"," )"]},"metadata":{},"execution_count":38},{"output_type":"display_data","data":{"image/png":"iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANxElEQVR4nO3dXYxc513H8e8PO2lRGojdLJYVR6yhVivf5EWrkKpRJBKapgnCRqqiVAhWxZJvWpQKELj0pkhcOEi0BAlVMklgQaFJlDay1Yi2xqSqkMDtunmPG+wER7Xlly1JaMpFS9I/F3NMVut1dnZ3ZieP9/uRRnPOc57x+c+jo5/OPnPOcaoKSVJ7fmbUBUiSlsYAl6RGGeCS1CgDXJIaZYBLUqPWruTOLr/88hofH1/JXUpS8w4dOvSDqhqb276iAT4+Ps709PRK7lKSmpfk5fnanUKRpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGreidmMsxvuuxkez32O7bR7JfSVqIZ+CS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRvUV4EkuS/JIku8lOZzkg0nWJ9mf5Ej3vm7YxUqS3tLvGfg9wNeq6gPAVcBhYBdwoKq2AAe6dUnSClkwwJP8PHAjcB9AVf2kql4DtgFTXbcpYPuwipQknaufM/DNwAzwt0meSHJvkkuADVV1sutzCtgwrCIlSefqJ8DXAtcCX6yqa4D/Yc50SVUVUPN9OMnOJNNJpmdmZpZbrySp00+AHweOV9XBbv0ReoF+OslGgO79zHwfrqo9VTVRVRNjY2ODqFmSRB8BXlWngO8neX/XdDPwPLAPmOzaJoG9Q6lQkjSvfp8H/nvAA0kuBl4CPkEv/B9OsgN4GbhjOCVKkubTV4BX1ZPAxDybbh5sOZKkfnknpiQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDH
BJapQBLkmNMsAlqVFr++mU5BjwOvAm8EZVTSRZDzwEjAPHgDuq6tXhlClJmmsxZ+C/WlVXV9VEt74LOFBVW4AD3bokaYUsZwplGzDVLU8B25dfjiSpX/0GeAHfSHIoyc6ubUNVneyWTwEb5vtgkp1JppNMz8zMLLNcSdJZfc2BAzdU1YkkvwDsT/K92RurqpLUfB+sqj3AHoCJiYl5+0iSFq+vM/CqOtG9nwEeBa4DTifZCNC9nxlWkZKkcy0Y4EkuSXLp2WXgFuBZYB8w2XWbBPYOq0hJ0rn6mULZADya5Gz/f6yqryX5DvBwkh3Ay8AdwytTkjTXggFeVS8BV83T/l/AzcMoSpK0MO/ElKRGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRfQd4kjVJnkjy1W59c5KDSY4meSjJxcMrU5I012LOwO8CDs9avxv4QlW9D3gV2DHIwiRJb6+vAE+yCbgduLdbD3AT8EjXZQrYPowCJUnz6/cM/C+BPwJ+2q2/F3itqt7o1o8DV8z3wSQ7k0wnmZ6ZmVlWsZKktywY4El+HThTVYeWsoOq2lNVE1U1MTY2tpR/QpI0j7V99PkQ8BtJbgPeDfwccA9wWZK13Vn4JuDE8MqUJM214Bl4VX2mqjZV1ThwJ/AvVfVbwOPAx7puk8DeoVUpSTrHcq4D/2Pg95McpTcnft9gSpIk9aOfKZT/V1XfBL7ZLb8EXDf4kiRJ/fBOTElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMWDPAk707y7SRPJXkuyZ927ZuTHExyNMlDSS4efrmSpLP6OQP/MXBTVV0FXA3cmuR64G7gC1X1PuBVYMfwypQkzbVggFfPj7rVi7pXATcBj3TtU8D2oVQoSZpXX3PgSdYkeRI4A+wHXgReq6o3ui7HgSvO89mdSaaTTM/MzAyiZkkSfQZ4Vb1ZVVcDm4DrgA/0u4Oq2lNVE1U1MTY2tsQyJUlzLeoqlKp6DXgc+CBwWZK13aZNwIkB1yZJehv9XIUyluSybvlngQ8Dh+kF+ce6bpPA3mEVKUk619qFu7ARmEqyhl7gP1xVX03yPPBgkj8DngDuG2KdkqQ5FgzwqnoauGae9pfozYdLkkbAOzElqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNaqfZ6GsauO7HhvZvo/tvn1k+5b0zucZuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIatWCAJ7kyyeNJnk/yXJK7uvb1SfYnOdK9rxt+uZKks/o5A38D+IOq2gpcD3wyyVZgF3CgqrYAB7p1SdIKWTDAq+pkVX23W34dOAxcAWwDprpuU8D2YRUpSTrXoubAk4wD1wAHgQ1VdbLbdArYcJ7P7EwynWR6ZmZmGaVKkmbrO8CTvAf4MvDpqvrh7G1VVUDN97mq2lNVE1U1MTY2tqxiJUlv6SvAk1xEL7wfqKqvdM2nk2zstm8EzgynREnSfPq5CiXAfcDhqvr8rE37gMlueRLYO/jyJEnn08//ifkh4LeBZ5I82bX9CbAbeDjJDuBl4I7hlChJms+CAV5V/wrkPJtvHmw5kqR+eSemJDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY
1aMMCT3J/kTJJnZ7WtT7I/yZHufd1wy5QkzdXPGfjfAbfOadsFHKiqLcCBbl2StIIWDPCq+hbwypzmbcBUtzwFbB9wXZKkBSx1DnxDVZ3slk8BG87XMcnOJNNJpmdmZpa4O0nSXMv+EbOqCqi32b6nqiaqamJsbGy5u5MkdZYa4KeTbATo3s8MriRJUj+WGuD7gMlueRLYO5hyJEn96ucywi8B/wa8P8nxJDuA3cCHkxwBfq1blyStoLULdaiqj59n080DrkWStAjeiSlJjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMWvA5cozO+67GR7PfY7ttHsl9Ji+MZuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNco7MXWOUd0BCt4FKi2GZ+CS1CgDXJIa5RSK3lF8gJfUP8/AJalRBrgkNcoAl6RGLWsOPMmtwD3AGuDeqto9kKokXdBGeanqKAzrN5Yln4EnWQP8NfBRYCvw8SRbB1WYJOntLWcK5TrgaFW9VFU/AR4Etg2mLEnSQpYzhXIF8P1Z68eBX5nbKclOYGe3+qMkLyxjn627HPjBqItowIqPU+5eyb0NjMdTf0Y+TgM4vn5xvsahXwdeVXuAPcPeTwuSTFfVxKjreKdznPrjOPXnQh6n5UyhnACunLW+qWuTJK2A5QT4d4AtSTYnuRi4E9g3mLIkSQtZ8hRKVb2R5FPA1+ldRnh/VT03sMouTE4l9cdx6o/j1J8LdpxSVaOuQZK0BN6JKUmNMsAlqVEG+AAlOZbkmSRPJpnu2tYn2Z/kSPe+rmtPkr9KcjTJ00muHW31w5Pk/iRnkjw7q23R45Jksut/JMnkKL7LsJ1nrD6X5ER3XD2Z5LZZ2z7TjdULST4yq/3Wru1okl0r/T2GKcmVSR5P8nyS55Lc1bWvvmOqqnwN6AUcAy6f0/bnwK5ueRdwd7d8G/BPQIDrgYOjrn+I43IjcC3w7FLHBVgPvNS9r+uW1436u63QWH0O+MN5+m4FngLeBWwGXqR3QcGabvmXgIu7PltH/d0GOEYbgWu75UuB/+jGYtUdU56BD982YKpbngK2z2r/++r5d+CyJBtHUeCwVdW3gFfmNC92XD4C7K+qV6rqVWA/cOvwq19Z5xmr89kGPFhVP66q/wSO0nvExQX9mIuqOllV3+2WXwcO07szfNUdUwb4YBXwjSSHukcIAGyoqpPd8ilgQ7c836MIrliZMt8RFjsuq328PtX9+X//2akBHCuSjAPXAAdZhceUAT5YN1TVtfSe0PjJJDfO3li9v9u8bnMOx2VBXwR+GbgaOAn8xWjLeWdI8h7gy8Cnq+qHs7etlmPKAB+gqjrRvZ8BHqX3p+zps1Mj3fuZrvtqfxTBYsdl1Y5XVZ2uqjer6qfA39A7rmAVj1WSi+iF9wNV9ZWuedUdUwb4gCS5JMmlZ5eBW4Bn6T1e4Oyv25PA3m55H/A73S/k1wP/PevPv9VgsePydeCWJOu6KYRburYL3pzfRn6T3nEFvbG6M8m7kmwGtgDf5gJ/zEWSAPcBh6vq87M2rb5jatS/ol4oL3q/+D/VvZ4DPtu1vxc4ABwB/hlY37WH3n+I8SLwDDAx6u8wxLH5Er0//f+X3jzjjqWMC/C79H6oOwp8YtTfawXH6h+6sXiaXhhtnNX/s91YvQB8dFb7bfSuznjx7LF4obyAG+hNjzwNPNm9bluNx5S30ktSo5xCkaRGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUf8HRQGdJFmog8sAAAAASUVORK5CYII=\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"t5nTSM5qwlk_"},"source":["#### Read group 2"]},{"cell_type":"code","metadata":{"scrolled":true,"id":"WDINeTZCwlk_","executionInfo":{"status":"ok","timestamp":1630016472569,"user_tz":-120,"elapsed":5383,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":["file_directory = dir2\n","file_names = glob.glob(file_directory+\"/*preprocessed_crop.csv\")\n","len_sham_P3 = []\n","sham_P3 = []\n","for num, file_name in enumerate(file_names):\n"," file = file_name.split(\"/\")[-1].split(\"_preprocessed\")[0]\n"," df = load_and_interpolate(file_name, fps_dict[file])\n"," len_sham_P3.append(len(df))\n"," sham_P3.append(df.values)"],"execution_count":39,"outputs":[]},{"cell_type":"code","metadata":{"id":"WBALd-TCwlk_","colab":{"base_uri":"https://localhost:8080/","height":384},"executionInfo":{"status":"ok","timestamp":1630016472570,"user_tz":-120,"elapsed":42,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"7544030e-0706-4ef3-8f68-dad0759295c4"},"source":["print(\"# videos in MOP Group:\", len(len_sham_P3))\n","print(\"shortest video in MOP Group is of length\", min(len_sham_P3))\n","print(\"longest video in MOP Group is of length\", max(len_sham_P3))\n","plt.hist(len_sham_P3)\n","plt.show()"],"execution_count":40,"outputs":[{"output_type":"stream","text":["# videos in MOP Group: 83\n","shortest video in MOP Group is of length 154\n","longest video in MOP Group is of length 518\n"],"name":"stdout"},{"output_type":"execute_result","data":{"text/plain":["(array([ 9., 14., 22., 12., 8., 1., 4., 5., 2., 6.]),\n"," array([154. , 190.4, 226.8, 263.2, 299.6, 336. , 372.4, 408.8, 445.2,\n"," 481.6, 518. 
]),\n"," )"]},"metadata":{},"execution_count":40},{"output_type":"display_data","data":{"image/png":"iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAM9UlEQVR4nO3df6zd9V3H8edbQGYGcXS9aZoBXrYQTWO0a66IGSE43OSHsSwhBuKP/kFSo5BsUaPFJYp/mHQm29TEbHZSqW5jm9sIzXBzWEiIiTJvtwJliO22LtKU3iLZxv6ZAm//OJ+7nt3e23N7z7nn+33L85Gc3O/38/223xefm/Pie77nfHsiM5Ek1fNDXQeQJK2NBS5JRVngklSUBS5JRVngklTU+dM82MaNG3N2dnaah5Sk8g4ePPhCZs4sHZ9qgc/OzjI/Pz/NQ0pSeRHxzeXGvYQiSUVZ4JJUlAUuSUVZ4JJUlAUuSUVZ4JJUlAUuSUVZ4JJUlAUuSUVN9U5MnZvZXQ91ctxju2/u5LiSzo1n4JJUlAUuSUVZ4JJUlAUuSUVZ4JJUlAUuSUVZ4JJUlAUuSUVZ4JJUlAUuSUVZ4JJUlAUuSUVZ4JJUlAUuSUWNLPCIuCwiHo2Ir0bE0xHx7ja+ISIejogj7ecl6x9XkrRoNWfgLwO/m5lbgKuBOyNiC7ALOJCZVwIH2rokaUpGFnhmnsjML7fll4BngDcB24F9bbd9wC3rFVKSdKZzugYeEbPAW4HHgU2ZeaJteh7YNNFkkqSzWnWBR8RFwGeA92Tmd4a3ZWYCucKf2xkR8xExf+rUqbHCSpJOW1WBR8QFDMr7Y5n52TZ8MiI2t+2bgYXl/mxm7snMucycm5mZmURmSRKr+xRKAPcCz2TmB4Y27Qd2tOUdwIOTjydJWslqvpX+bcCvA09FxKE29ofAbuBTEXEH8E3gV9YnoiRpOSMLPDP/BYgVNl8/2TiSpNXyTkxJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiLHBJKsoCl6SiRhZ4ROyNiIWIODw0dk9EHI+IQ+1x0/rGlCQttZoz8PuAG5YZ/2Bmbm2Pf5xsLEnSKCMLPDMfA16cQhZJ0jkY5xr4XRHxZLvEcsnEEkmSVmWtBf4h4C3AVuAE8P6VdoyInRExHxHzp06dWuPhJElLranAM/NkZr6Sma8CHwGuOsu+ezJzLjPnZmZm1ppTkrTEmgo8IjYPrb4LOLzSvpKk9XH+qB0i4n7gOmBjRDwH/DFwXURsBRI4BvzmOmaUJC1jZIFn5u3LDN+7DlkkSefAOzElqSgLXJKKssAlqSgLXJKKGvkmpmB210NdR5CkM3gGLklFWeCSVJQFLklFWeCSVJQFLklFWeCSVJQFLklFWeCSVJQFLklFWeCSVJQFLklFWeCSVJQFLklFWeCSVJQFLklFWeCSVJQFLklFWeCSVJRfqaYzdPkVcsd239zZsaVqPAOXpKIscEkqygKXpKIscEkqygKXpKIscEkqygKXpKIscEkqygKXpKIscEkqygKXpKIscEkqygKXpKIscEkqygKXpKJGFnhE7I2IhYg4PDS2ISIejogj7ecl6xtTkrTUas7A7wNuWDK2CziQmVcCB9q6JGmKRhZ4Zj4GvLhkeDuwry3vA26ZcC5J0ghrvQa+KTNPtOXngU0r7RgROyNiPiLmT506tcb
DSZKWGvtNzMxMIM+yfU9mzmXm3MzMzLiHkyQ1ay3wkxGxGaD9XJhcJEnSaqy1wPcDO9ryDuDBycSRJK3Waj5GeD/wr8CPR8RzEXEHsBt4R0QcAX6hrUuSpuj8UTtk5u0rbLp+wlkkSefAOzElqSgLXJKKssAlqaiR18D7YnbXQ11HkKRe8QxckoqywCWpKAtckoqywCWpKAtckoqywCWpKAtckoqywCWpqDI38ui1oasbto7tvrmT40rj8AxckoqywCWpKAtckoqywCWpKAtckoqywCWpKAtckoqywCWpKAtckoqywCWpKAtckoqywCWpKAtckoqywCWpKAtckoqywCWpKAtckoqywCWpKL9STdJrRldf2Qfr87V9noFLUlEWuCQVZYFLUlEWuCQVZYFLUlEWuCQVNdbHCCPiGPAS8ArwcmbOTSKUJGm0SXwO/Ocz84UJ/D2SpHPgJRRJKmrcM/AEvhgRCfx1Zu5ZukNE7AR2Alx++eVjHk7SpPx/uyvxtWjcM/BrMnMbcCNwZ0Rcu3SHzNyTmXOZOTczMzPm4SRJi8Yq8Mw83n4uAA8AV00ilCRptDUXeES8PiIuXlwG3gkcnlQwSdLZjXMNfBPwQEQs/j0fz8wvTCSVJGmkNRd4Zn4d+OkJZpEknQM/RihJRVngklSUBS5JRfmValLHuryhRrV5Bi5JRVngklSUBS5JRVngklSUBS5JRVngklSUBS5JRVngklSUBS5JRVngklSUBS5JRVngklSUBS5JRVngklSUBS5JRVngklSUBS5JRVngklSUX6kmaer8GrnJ8AxckoqywCWpKAtckoqywCWpKAtckoqywCWpKAtckoqywCWpKG/kkfDGEtXkGbgkFWWBS1JRFrgkFWWBS1JRFrgkFWWBS1JRYxV4RNwQEc9GxNGI2DWpUJKk0dZc4BFxHvBXwI3AFuD2iNgyqWCSpLMb5wz8KuBoZn49M/8H+ASwfTKxJEmjjHMn5puA/xpafw742aU7RcROYGdb/W5EPHsOx9gIvLDmhNNTIWeFjGDOSTPn5IyVMd431rF/bLnBdb+VPjP3AHvW8mcjYj4z5yYcaeIq5KyQEcw5aeacnD5mHOcSynHgsqH1S9uYJGkKxinwfweujIgrIuKHgduA/ZOJJUkaZc2XUDLz5Yi4C/gn4Dxgb2Y+PbFkA2u69NKBCjkrZARzTpo5J6d3GSMzu84gSVoD78SUpKIscEkqqrMCj4i9EbEQEYeHxu6JiOMRcag9bhradne7Zf/ZiPjFKea8LCIejYivRsTTEfHuNr4hIh6OiCPt5yVtPCLiL1vWJyNiW8c5ezWnEfG6iPhSRDzRcv5JG78iIh5veT7Z3hgnIi5s60fb9tkOM94XEd8YmsutbbyT3/lQ3vMi4isR8bm23pu5HJGzd/MZEcci4qmWZ76N9eq5/gMys5MHcC2wDTg8NHYP8HvL7LsFeAK4ELgC+Bpw3pRybga2teWLgf9sef4M2NXGdwHva8s3AZ8HArgaeLzjnL2a0zYvF7XlC4DH2zx9CritjX8Y+K22/NvAh9vybcAnO8x4H3DrMvt38jsfOv7vAB8HPtfWezOXI3L2bj6BY8DGJWO9eq4PPzo7A8/Mx4AXV7n7duATmfm9zPwGcJTBrfzrLjNPZOaX2/JLwDMM7kLdDuxru+0DbhnK+nc58G/AGyJic4c5V9LJnLZ5+W5bvaA9Eng78Ok2vnQ+F+f508D1EREdZVxJJ79zgIi4FLgZ+Ju2HvRoLlfKOUJn83mWPL15rg/r4zXwu9rLkb2LL1VY/rb9s5XTumgvOd/K4IxsU2aeaJueBza15c6zLskJPZvT9lL6ELAAPMzg7P9bmfnyMlm+n7Nt/zbwxmlnzMzFufzTNpcfjIgLl2ZcJv96+3Pg94FX2/ob6dlcrpBzUd/mM4EvRsTBGPwzINDj53rfCvxDwFuArcAJ4P3dxjktIi4CPgO8JzO/M7wtB6+nevF5zGVy9m5
OM/OVzNzK4O7dq4Cf6DjSGZZmjIifBO5mkPVngA3AH3QYkYj4JWAhMw92mWOUs+Ts1Xw212TmNgb/yuqdEXHt8MY+PdehZwWemSfbE+dV4COcfknf6W37EXEBg1L8WGZ+tg2fXHy51H4udJ11uZx9ndOW7VvAo8DPMXj5uXhj2XCW7+ds238U+O8OMt7QLlNlZn4P+Fu6n8u3Ab8cEccY/Gugbwf+gv7N5Rk5I+KjPZxPMvN4+7kAPNAy9e65vqhXBb7k+tG7gMVPqOwHbmvvol8BXAl8aUqZArgXeCYzPzC0aT+woy3vAB4cGv+N9g711cC3h15+TT1n3+Y0ImYi4g1t+UeAdzC4Xv8ocGvbbel8Ls7zrcAj7Sxo2hn/Y+hJHAyugw7P5dR/55l5d2ZempmzDN6UfCQzf5UezeVZcv5a3+YzIl4fERcvLgPvbJl69Vz/AdN4p3S5B3A/g5f0/8vg2tEdwN8DTwFPtsnZPLT/exlcK30WuHGKOa9h8JLpSeBQe9zE4NrhAeAI8M/AhrZ/MPiii6+1/5a5jnP2ak6BnwK+0vIcBv6ojb+Zwf9AjgL/AFzYxl/X1o+27W/uMOMjbS4PAx/l9CdVOvmdL8l8Hac/3dGbuRyRs1fz2ebtifZ4GnhvG+/Vc3344a30klRUry6hSJJWzwKXpKIscEkqygKXpKIscEkqygKXpKIscEkq6v8AwZQTWEXD8hkAAAAASUVORK5CYII=\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"pIl1eDudwlk_"},"source":["Suppose we have an array with 3 rows and 6 columns. Columns indicate x1, y1, likelihood1, x2, y2, likelihood2 (like the data structure we have). By indexing we can separate x, y and likelihood columns. Later we use this idea in MyData_Q.\n","\n","a = np.array([[1,2,3,10,13,16],[4,5,6,11,14,17], [7,8,9,12,15,18]])\n","\n","\n"," - a[:,0::3] returns x1, x2\n","\n"," - a[:,1::3] returns y1, y2\n","\n"," - a[:,2::3] returns likelihood1, likelihood2\n"]},{"cell_type":"markdown","metadata":{"id":"6fhP34zxwlk_"},"source":["### Slice each video into episodes of length n"]},{"cell_type":"code","metadata":{"id":"FjuBQCddwllA","executionInfo":{"status":"ok","timestamp":1630016472571,"user_tz":-120,"elapsed":32,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":["def sliced_vidoes(input_, n=episode_len, m=episode_overlap):\n"," #input is a list of lists\n"," #n: group size\n"," #m: overlap size\n"," input_sliced = []\n"," for k in range(len(input_)):\n"," input_sliced.extend([input_[k][i:i+n] for i in range(0, len(input_[k]), n-m)][:-n//(n-m)])\n"," return input_sliced"],"execution_count":41,"outputs":[]},{"cell_type":"code","metadata":{"id":"QYISTwltwllA","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630016472572,"user_tz":-120,"elapsed":31,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"a8b2e561-81be-4ba0-851d-63b651ad0b7c"},"source":["stroke_0 = sliced_vidoes(MOPstroke_P3)\n","print(\"number of episodes for MOPstroke_P3 group:\", len(stroke_0))\n","stroke_1 = sliced_vidoes(sham_P3)\n","print(\"number of episodes for sham_P3 group:\", len(stroke_1))"],"execution_count":42,"outputs":[{"output_type":"stream","text":["number of episodes for MOPstroke_P3 group: 1630\n","number of episodes for sham_P3 group: 
805\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"zw7BcB-hwllA"},"source":["### Split the dataset into training and test sets\n","\n","Episodes are split at random into 90% training and 10% test data. **Caveat:** consecutive episodes of a video overlap by `episode_overlap` of their `episode_len` frames, and all episodes are pooled before splitting, so nearly identical windows from the same video can end up in both sets. The test accuracy is therefore likely optimistic; splitting by video (e.g. with scikit-learn's `GroupShuffleSplit`) would likely give a less biased estimate.\n","\n","Afterwards, a `StandardScaler` is fit on the training frames only and applied to both sets. The likelihood columns (every third column, `2::3`) are restored to their raw values because `MyData_Q` thresholds them at 0.6 and 0.8. The final `reshape` hardcodes `100` (= `episode_len`) frames and `45` feature columns; update these numbers if either changes."]},{"cell_type":"code","metadata":{"id":"YXX3Jc4MwllA","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630016472573,"user_tz":-120,"elapsed":24,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"dc1ebc7d-ca07-4d6f-cd75-8b1a197c5ed4"},"source":["all_data = stroke_0 + stroke_1\n","all_labels = len(stroke_0)*[0] + len(stroke_1)*[1]\n","\n","data_lst, labels = shuffle(all_data, all_labels, random_state=10)\n","\n","# I pass an int to random_state such that the output is reproducible across multiple function calls\n","X_train, X_test, label_train, label_test = train_test_split(data_lst, labels, \n"," test_size=0.1, \n"," random_state=0\n"," )\n","\n","\n","print('len(train_set):', len(X_train))\n","print('len(test_set):', len(X_test))\n","print(\"training:\", sorted(Counter(label_train).items()))\n","print(\"test:\", sorted(Counter(label_test).items()))\n","#print(X_train)"],"execution_count":43,"outputs":[{"output_type":"stream","text":["len(train_set): 2191\n","len(test_set): 244\n","training: [(0, 1459), (1, 732)]\n","test: [(0, 171), (1, 73)]\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"IZZ8EohAXQlO","executionInfo":{"status":"ok","timestamp":1630016472574,"user_tz":-120,"elapsed":19,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":["sc_x = StandardScaler()\n","train_concat = np.concatenate(X_train, axis=0)\n","test_concat = np.concatenate(X_test, axis=0)\n","train_scaled = sc_x.fit_transform(train_concat)\n","test_scaled = sc_x.transform(test_concat)\n","train_scaled[:, 2::3] = train_concat[:, 2::3]\n","test_scaled[:, 2::3] = test_concat [:, 2::3]\n","X_train = train_scaled.reshape(len(X_train), 100, 45)\n","X_test = test_scaled.reshape(len(X_test), 100, 
45)"],"execution_count":44,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"O467s0XFwllA"},"source":["### balance the classes in each batch"]},{"cell_type":"markdown","metadata":{"id":"HrkvEC7UwllB"},"source":["If you have class imbalance, use a weighted sampler, so that you see all classes with equal probability\n","\n","The weighted sampling should only be used for training, to balance the classes in each batch, which hopefully helps the training.\n","\n","The validation and test accuracy is calculated on the complete datasets without any sampling (shuffle is also not needed here, since the order of the data won’t change the metrics)."]},{"cell_type":"code","metadata":{"id":"bVgaxbc20rw7","executionInfo":{"status":"ok","timestamp":1630016472574,"user_tz":-120,"elapsed":18,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":["class MyData_Q(Dataset):\n","\n"," def __init__(self, train_data, train_label):\n"," self.train_data = train_data\n"," self.train_label = train_label\n","\n"," def __len__(self):\n"," return len(self.train_data)\n","\n"," def __getitem__(self, item):\n"," seq_len = len(self.train_data[item])\n"," X = self.train_data[item][:, 0::3]\n"," Y = self.train_data[item][:, 1::3]\n"," LH = self.train_data[item][:, 2::3] # likelihood\n"," LH_T = np.transpose(LH, (1, 0))\n"," \n"," #likelihood<0.6 -->quality index=0\n"," #0.6<=likelihood<0.8 -->quality index=1\n"," #0.8<=likelihood<=1 --> quality index=2\n"," \n"," quality_indices = 2 * np.logical_and(0.8<=LH_T, LH_T<=1.0) + 1 * np.logical_and(0.6<=LH_T, LH_T<0.8) \n"," return np.transpose(X, (1, 0)), np.transpose(Y, (1, 0)), quality_indices, self.train_label[item]\n"],"execution_count":45,"outputs":[]},{"cell_type":"code","metadata":{"id":"wvD_CXJ_wllB","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630016472575,"user_tz":-120,"elapsed":19,"user":{"displayName":"Franziska 
Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"be16132c-b624-46c1-cefa-d61cc2532a98"},"source":["def Sampler(w_labels):\n"," count = Counter(w_labels) \n"," print(count)\n"," class_count = torch.from_numpy(np.array([count[i] for i in range(len(count))])) \n"," weight = 1./class_count \n"," samples_weight = np.array([weight[t] for t in w_labels])\n"," samples_weight=torch.from_numpy(samples_weight) # a sequence of weights, not necessary summing up to one\n"," #Sample elements from [0,..,len(weights)-1] with given probabilities (weights).\n"," sampler = torch.utils.data.WeightedRandomSampler(samples_weight, \n"," len(samples_weight), \n"," replacement=True,) \n"," return sampler\n","\n","\n","train_sampler = Sampler(label_train)\n","train_data = MyData_Q(X_train, label_train)\n","train_loader = DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)\n","\n","test_data = MyData_Q(X_test, label_test)\n","test_loader = DataLoader(test_data, batch_size=len(test_data))"],"execution_count":46,"outputs":[{"output_type":"stream","text":["Counter({0: 1459, 1: 732})\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"-Uxj2wkmtdUo","executionInfo":{"status":"ok","timestamp":1630016472576,"user_tz":-120,"elapsed":16,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":["x, y, q, t = next(iter(train_loader))"],"execution_count":47,"outputs":[]},{"cell_type":"code","metadata":{"id":"I8yT6_SuwllB","executionInfo":{"status":"ok","timestamp":1630016473490,"user_tz":-120,"elapsed":45,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":["class classifier_conv(nn.Module):\n"," def __init__(self, in_channel, n_classes, hidden_unit=128):\n"," super(classifier_conv, self).__init__()\n"," self.n_classes = n_classes\n"," self.hidden_unit = hidden_unit\n"," self.encoder = nn.Sequential(\n"," nn.Conv1d(in_channel, hidden_unit, kernel_size=10, 
stride=5, padding=3, bias=False),\n"," nn.BatchNorm1d(hidden_unit),\n"," nn.ReLU(inplace=True),\n"," nn.Conv1d(hidden_unit, hidden_unit, kernel_size=8, stride=4, padding=2, bias=False),\n"," nn.BatchNorm1d(hidden_unit),\n"," nn.ReLU(inplace=True))\n"," self.fc = nn.Linear(hidden_unit*5, self.n_classes)\n"," \n"," def _weights_init(m):\n"," if isinstance(m, nn.Linear):\n"," nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n"," if isinstance(m, nn.Conv1d):\n"," nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n"," elif isinstance(m, nn.BatchNorm1d):\n"," nn.init.constant_(m.weight, 1)\n"," nn.init.constant_(m.bias, 0)\n","\n"," self.apply(_weights_init)\n","\n"," def forward(self, x):\n"," bs = x.size(0) \n"," x = self.encoder(x) #[bs, 38, 100] [bs, input_channels, L=100]\n"," #print(x.size()) # [bs, 128, 5]\n"," x = self.fc(x.view(bs, -1)) #[bs, 128*5]\n"," return x"],"execution_count":48,"outputs":[]},{"cell_type":"code","metadata":{"id":"k8NNGh-WwllC","executionInfo":{"status":"ok","timestamp":1630016473492,"user_tz":-120,"elapsed":44,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":[
"# Sanity check of the quality one-hot encoding on one test batch.\n",
"# Each value of x is copied into the channel selected by its quality index q (one of nq); the other channels are zero.\n",
"x, y, q, t = next(iter(test_loader))\n",
"\n",
"nq = 3\n",
"onehot_quality = F.one_hot(q.view(1,-1).to(torch.int64), nq).view(q.size(0), q.size(1), q.size(2), nq)\n",
"onehot_quality = onehot_quality.permute(0, 1, 3, 2)  # [bs, nf, nq, l]\n",
"bs, nf, _, l = onehot_quality.size()\n",
"\n",
"X_encoded = torch.mul(x.unsqueeze(2).float(), onehot_quality.float()).reshape(bs, nf*nq, l)\n",
"assert X_encoded.shape == (bs, nf*nq, l)"
],"execution_count":49,"outputs":[]},{"cell_type":"code","metadata":{"id":"c_0gwQ1CwllC","executionInfo":{"status":"ok","timestamp":1630016473493,"user_tz":-120,"elapsed":44,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":[
"class classifier_fc(nn.Module):\n",
"    \"\"\"Fully connected baseline: flatten -> two ReLU hidden layers -> linear head.\n",
"\n",
"    Args:\n",
"        in_channel: number of features after flattening all non-batch input dims.\n",
"        n_classes: number of output logits.\n",
"        hidden_unit: width of the first hidden layer; the second uses hidden_unit // 2.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, in_channel, n_classes, hidden_unit=32):\n",
"        super(classifier_fc, self).__init__()\n",
"        self.n_classes = n_classes\n",
"        self.hidden_unit = hidden_unit\n",
"        self.encoder = nn.Sequential(\n",
"            nn.Linear(in_channel, hidden_unit),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.Linear(hidden_unit, hidden_unit//2),\n",
"            nn.ReLU(inplace=True),\n",
"        )\n",
"        self.fc = nn.Linear(hidden_unit//2, self.n_classes)\n",
"        self.apply(self._weights_init)\n",
"\n",
"    @staticmethod\n",
"    def _weights_init(m):\n",
"        # Kaiming init matches the ReLU activations; this model has no BatchNorm layers.\n",
"        if isinstance(m, nn.Linear):\n",
"            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
"\n",
"    def forward(self, x):\n",
"        bs = x.size(0)\n",
"        # reshape (not view) so non-contiguous inputs are accepted as well.\n",
"        x = self.encoder(x.reshape(bs, -1))  # [bs, in_channel]\n",
"        return self.fc(x)  # [bs, n_classes]"
],"execution_count":50,"outputs":[]},{"cell_type":"code","metadata":{"id":"LHt8adetwllC","executionInfo":{"status":"ok","timestamp":1630016473494,"user_tz":-120,"elapsed":43,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":[
"class E2EStateClassifier(torch.nn.Module):\n",
"    \"\"\"Recurrent (GRU or LSTM) sequence classifier.\n",
"\n",
"    forward() takes x of shape [bs, in_channel, seq_len], runs the RNN over the\n",
"    time axis and maps the RNN output at the last time step to `encoding_size`\n",
"    features and then to `output_size` logits.\n",
"\n",
"    Args:\n",
"        hidden_size: RNN hidden size per direction.\n",
"        in_channel: number of input features per time step.\n",
"        encoding_size: width of the intermediate encoding layer.\n",
"        output_size: number of output logits.\n",
"        cell_type: 'GRU' or 'LSTM'; anything else raises ValueError.\n",
"        num_layers, dropout, bidirectional: passed through to the RNN.\n",
"        device: device the sub-modules are moved to at construction.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, hidden_size, in_channel, encoding_size, output_size, cell_type='GRU', num_layers=1, dropout=0,\n",
"                 bidirectional=True, device=device):\n",
"        super(E2EStateClassifier, self).__init__()\n",
"        self.hidden_size = hidden_size\n",
"        self.in_channel = in_channel\n",
"        self.num_layers = num_layers\n",
"        self.cell_type = cell_type\n",
"        self.encoding_size = encoding_size\n",
"        self.bidirectional = bidirectional\n",
"        self.output_size = output_size\n",
"        self.device = device\n",
"\n",
"        rnn_types = {'GRU': torch.nn.GRU, 'LSTM': torch.nn.LSTM}\n",
"        if cell_type not in rnn_types:\n",
"            raise ValueError('Cell type not defined, must be one of the following {GRU, LSTM}')\n",
"\n",
"        num_directions = int(self.bidirectional) + 1\n",
"        # Registration order fc -> nn -> rnn is kept so module names and init order are unchanged.\n",
"        self.fc = torch.nn.Sequential(torch.nn.Linear(self.hidden_size*num_directions, self.encoding_size)).to(self.device)\n",
"        self.nn = torch.nn.Sequential(torch.nn.Linear(self.encoding_size, self.output_size)).to(self.device)\n",
"        self.rnn = rnn_types[cell_type](input_size=self.in_channel, hidden_size=self.hidden_size, num_layers=num_layers,\n",
"                                        batch_first=False, dropout=dropout, bidirectional=bidirectional).to(self.device)\n",
"\n",
"    def forward(self, x):\n",
"        x = x.permute(2, 0, 1)  # [bs, in_channel, L] -> [L, bs, in_channel] (batch_first=False)\n",
"        # No initial state passed: GRU/LSTM start from zeros on x's device and dtype.\n",
"        out, _ = self.rnn(x)  # out: [L, bs, num_directions*hidden_size]\n",
"        # NOTE(review): for a bidirectional RNN the backward half of out[-1] has seen only the last\n",
"        # time step; the final hidden state h_n may summarise better. Kept to match the trained model.\n",
"        # No squeeze(0) here: it would drop the batch dimension when bs == 1.\n",
"        encodings = self.fc(out[-1])  # [bs, encoding_size]\n",
"        return self.nn(encodings)  # [bs, output_size]"
],"execution_count":51,"outputs":[]},{"cell_type":"code","metadata":{"id":"6c-MHRkdwllD","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630016473495,"user_tz":-120,"elapsed":42,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"e68cdca9-9c28-4482-d781-43572074e816"},"source":[
"# Input channels after onehot_encode: (15 x-features + 15 y-features) * 3 quality levels.\n",
"# Alternatives tried: classifier_fc((15+15)*3*100, n_classes, hidden_unit=10), classifier_conv(30, n_classes, hidden_unit=128).\n",
"Classifier = E2EStateClassifier(hidden_size=128, in_channel=(15+15)*3, encoding_size=10, output_size=n_classes, device=device, cell_type=\"GRU\")\n",
"Classifier = Classifier.to(device)\n",
"Classifier.train()"
],"execution_count":52,"outputs":[{"output_type":"execute_result","data":{"text/plain":["E2EStateClassifier(\n"," (fc): Sequential(\n"," (0): 
Linear(in_features=256, out_features=10, bias=True)\n"," )\n"," (nn): Sequential(\n"," (0): Linear(in_features=10, out_features=2, bias=True)\n"," )\n"," (rnn): GRU(90, 128, bidirectional=True)\n",")"]},"metadata":{},"execution_count":52}]},{"cell_type":"markdown","metadata":{"id":"4kvYz3XjwllD"},"source":["### Optimizer"]},{"cell_type":"markdown","metadata":{"id":"q7GE-xL5wllD"},"source":[
"Weight decay is a regularization technique that adds a small penalty, usually proportional to the squared L2 norm of all model weights, to the loss function.\n",
"PyTorch's `weight_decay` argument applies this penalty to every parameter passed to the optimizer, biases included. With `optim.Adam` it is added to the gradient (classic L2 regularization); `optim.AdamW` implements decoupled weight decay instead."
]},{"cell_type":"code","metadata":{"id":"6uP5xmcpwllD","executionInfo":{"status":"ok","timestamp":1630016473496,"user_tz":-120,"elapsed":37,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":[
"optimizer = optim.Adam(Classifier.parameters(), lr=lr, weight_decay=1e-6)"
],"execution_count":53,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ghVTMC7ewllD"},"source":["### Criterion"]},{"cell_type":"code","metadata":{"id":"vR30tnGXwllD","executionInfo":{"status":"ok","timestamp":1630016473496,"user_tz":-120,"elapsed":35,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":[
"criterion = nn.CrossEntropyLoss()"
],"execution_count":54,"outputs":[]},{"cell_type":"code","metadata":{"id":"1MAqWmlHwllE","executionInfo":{"status":"ok","timestamp":1630016473497,"user_tz":-120,"elapsed":30,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":[
"def onehot_encode(x, q, nq=3):\n",
"    \"\"\"Spread each value of x over nq channels selected by its quality index.\n",
"\n",
"    Args:\n",
"        x: tensor [bs, nf, l] of feature values.\n",
"        q: tensor [bs, nf, l] of integer quality indices in [0, nq).\n",
"        nq: number of quality levels.\n",
"\n",
"    Returns:\n",
"        Float tensor [bs, nf*nq, l]; channel f*nq + k holds x[:, f, :] where\n",
"        q[:, f, :] == k and 0 elsewhere.\n",
"    \"\"\"\n",
"    onehot_quality = F.one_hot(q.to(torch.int64), nq)  # [bs, nf, l, nq]\n",
"    onehot_quality = onehot_quality.permute(0, 1, 3, 2)  # [bs, nf, nq, l]\n",
"    bs, nf, _, l = onehot_quality.size()\n",
"    x_encoded = torch.mul(x.unsqueeze(2).float(), onehot_quality.float()).reshape(bs, nf*nq, l)\n",
"    return x_encoded"
],"execution_count":55,"outputs":[]},{"cell_type":"code","metadata":{"id":"gx3NCD6zwllE","executionInfo":{"status":"ok","timestamp":1630016473498,"user_tz":-120,"elapsed":29,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":[
"def train(epoch, Classifier, train_loader, device, optimizer):\n",
"    \"\"\"Run one training epoch (uses the global `criterion`).\n",
"\n",
"    Each batch (x, y, q, target) is quality one-hot encoded, x and y are\n",
"    concatenated along the channel axis and one optimizer step is taken.\n",
"\n",
"    Returns:\n",
"        (mean loss over batches, accuracy over all samples of the epoch).\n",
"    \"\"\"\n",
"    Classifier.train()\n",
"    total_loss = 0.0\n",
"    n_correct = 0\n",
"    n_seen = 0\n",
"    for x, y, q, target in train_loader:\n",
"        x, y, q, target = x.to(device), y.to(device), q.to(device), target.to(device)\n",
"        input_ = torch.cat((onehot_encode(x, q), onehot_encode(y, q)), dim=1)\n",
"\n",
"        optimizer.zero_grad()\n",
"        logit = Classifier(input_.float())\n",
"        loss = criterion(logit, target)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        total_loss += loss.item()\n",
"        n_correct += (logit.argmax(dim=1) == target.view(-1)).sum().item()\n",
"        n_seen += target.size(0)\n",
"\n",
"    total_loss = total_loss / max(len(train_loader), 1)\n",
"    # Sample-weighted so a small final batch does not count as much as a full one.\n",
"    total_acc = n_correct / max(n_seen, 1)\n",
"    # Report epoch averages rather than the values of the last batch only.\n",
"    print(f'Train epoch {epoch} ----> loss_train:{total_loss}, accuracy_train:{total_acc*100}')\n",
"    return total_loss, total_acc"
],"execution_count":56,"outputs":[]},{"cell_type":"code","metadata":{"id":"VRDQ8yAswllE","executionInfo":{"status":"ok","timestamp":1630016473499,"user_tz":-120,"elapsed":28,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":[
"def test(epoch, Classifier, test_loader, device):\n",
"    \"\"\"Evaluate on every batch of test_loader (uses the global `criterion`).\n",
"\n",
"    Returns:\n",
"        (mean per-sample loss, accuracy) as Python floats.\n",
"    \"\"\"\n",
"    Classifier.eval()\n",
"    loss_sum = 0.0\n",
"    n_correct = 0\n",
"    n_seen = 0\n",
"    with torch.no_grad():\n",
"        for x, y, q, target in test_loader:\n",
"            x, y, q, target = x.to(device), y.to(device), q.to(device), target.to(device)\n",
"            input_ = torch.cat((onehot_encode(x, q), onehot_encode(y, q)), dim=1)\n",
"            logit = Classifier(input_.float())\n",
"            batch_size = target.size(0)\n",
"            # criterion (default reduction='mean') returns the batch mean; re-weight for an exact per-sample mean.\n",
"            loss_sum += criterion(logit, target).item() * batch_size\n",
"            n_correct += (logit.argmax(dim=1) == target.view(-1)).sum().item()\n",
"            n_seen += batch_size\n",
"\n",
"    loss = loss_sum / max(n_seen, 1)\n",
"    acc = n_correct / max(n_seen, 1)\n",
"    print(f'Test epoch {epoch} ----> loss_test:{loss}, accuracy_test:{acc*100}')\n",
"    return loss, acc"
],"execution_count":57,"outputs":[]},{"cell_type":"code","metadata":{"tags":[],"id":"CSNhvkQ4wllE","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630016603306,"user_tz":-120,"elapsed":129834,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"73b0de01-825c-4666-bedd-bb7818b88570"},"source":[
"os.makedirs('checkpoint', exist_ok=True)  # target directory for the optional checkpoint below\n",
"\n",
"train_loss_lst, train_acc_lst = [], []\n",
"test_loss_lst, test_acc_lst = [], []\n",
"for epoch in range(epochs):\n",
"    train_loss, train_acc = train(epoch, Classifier, train_loader, device, optimizer)\n",
"    train_loss_lst.append(train_loss)\n",
"    train_acc_lst.append(train_acc)\n",
"\n",
"    test_loss, test_acc = test(epoch, Classifier, test_loader, device)\n",
"    test_loss_lst.append(test_loss)\n",
"    test_acc_lst.append(test_acc)\n",
"\n",
"# Optional: persist the trained weights.\n",
"# torch.save({'model': Classifier.state_dict()}, 'checkpoint/classifier.pt')"
],"execution_count":58,"outputs":[{"output_type":"stream","text":["Train epoch 0 ----> loss_train:0.50072181224823, accuracy_train:86.66666666666667\n","Test epoch 0 ----> loss_test:0.5565370917320251, accuracy_test:68.44262295081968\n","Train epoch 1 ----> loss_train:0.5144981145858765, accuracy_train:86.66666666666667\n","Test epoch 1 ----> loss_test:0.4894435703754425, accuracy_test:76.22950819672131\n","Train epoch 2 ----> loss_train:0.638618528842926, accuracy_train:46.666666666666664\n","Test epoch 2 ----> loss_test:0.5202725529670715, accuracy_test:73.36065573770492\n","Train epoch 3 ----> loss_train:0.498960942029953, accuracy_train:80.0\n","Test epoch 3 ----> 
loss_test:0.501947283744812, accuracy_test:76.63934426229508\n","Train epoch 4 ----> loss_train:0.434619277715683, accuracy_train:86.66666666666667\n","Test epoch 4 ----> loss_test:0.4861672520637512, accuracy_test:78.27868852459017\n","Train epoch 5 ----> loss_train:0.5089713335037231, accuracy_train:80.0\n","Test epoch 5 ----> loss_test:0.43879497051239014, accuracy_test:82.37704918032787\n","Train epoch 6 ----> loss_train:0.5331366062164307, accuracy_train:66.66666666666666\n","Test epoch 6 ----> loss_test:0.41701582074165344, accuracy_test:83.60655737704919\n","Train epoch 7 ----> loss_train:0.556184709072113, accuracy_train:73.33333333333333\n","Test epoch 7 ----> loss_test:0.48609787225723267, accuracy_test:78.27868852459017\n","Train epoch 8 ----> loss_train:0.3908292353153229, accuracy_train:80.0\n","Test epoch 8 ----> loss_test:0.5047057271003723, accuracy_test:80.73770491803278\n","Train epoch 9 ----> loss_train:0.3973643481731415, accuracy_train:73.33333333333333\n","Test epoch 9 ----> loss_test:0.38848280906677246, accuracy_test:84.42622950819673\n","Train epoch 10 ----> loss_train:0.662800669670105, accuracy_train:66.66666666666666\n","Test epoch 10 ----> loss_test:0.4600295424461365, accuracy_test:78.27868852459017\n","Train epoch 11 ----> loss_train:0.6342880725860596, accuracy_train:53.333333333333336\n","Test epoch 11 ----> loss_test:0.4014802873134613, accuracy_test:83.60655737704919\n","Train epoch 12 ----> loss_train:0.23578602075576782, accuracy_train:93.33333333333333\n","Test epoch 12 ----> loss_test:0.4313466548919678, accuracy_test:79.91803278688525\n","Train epoch 13 ----> loss_train:0.2125960737466812, accuracy_train:100.0\n","Test epoch 13 ----> loss_test:0.4224534034729004, accuracy_test:81.55737704918032\n","Train epoch 14 ----> loss_train:0.19668173789978027, accuracy_train:100.0\n","Test epoch 14 ----> loss_test:0.3864031732082367, accuracy_test:85.65573770491804\n","Train epoch 15 ----> loss_train:0.3244040906429291, 
accuracy_train:86.66666666666667\n","Test epoch 15 ----> loss_test:0.425106406211853, accuracy_test:81.9672131147541\n","Train epoch 16 ----> loss_train:0.3993677496910095, accuracy_train:80.0\n","Test epoch 16 ----> loss_test:0.4017372131347656, accuracy_test:80.73770491803278\n","Train epoch 17 ----> loss_train:0.23289324343204498, accuracy_train:93.33333333333333\n","Test epoch 17 ----> loss_test:0.387134850025177, accuracy_test:84.8360655737705\n","Train epoch 18 ----> loss_train:0.610648512840271, accuracy_train:60.0\n","Test epoch 18 ----> loss_test:0.36716607213020325, accuracy_test:86.0655737704918\n","Train epoch 19 ----> loss_train:0.23094287514686584, accuracy_train:93.33333333333333\n","Test epoch 19 ----> loss_test:0.3503781259059906, accuracy_test:83.60655737704919\n","Train epoch 20 ----> loss_train:0.2411852777004242, accuracy_train:86.66666666666667\n","Test epoch 20 ----> loss_test:0.3412906527519226, accuracy_test:86.88524590163934\n","Train epoch 21 ----> loss_train:0.5193917751312256, accuracy_train:73.33333333333333\n","Test epoch 21 ----> loss_test:0.3138628304004669, accuracy_test:88.11475409836066\n","Train epoch 22 ----> loss_train:0.19448810815811157, accuracy_train:93.33333333333333\n","Test epoch 22 ----> loss_test:0.3313533365726471, accuracy_test:86.88524590163934\n","Train epoch 23 ----> loss_train:0.414404034614563, accuracy_train:73.33333333333333\n","Test epoch 23 ----> loss_test:0.3215879797935486, accuracy_test:86.0655737704918\n","Train epoch 24 ----> loss_train:0.3939465880393982, accuracy_train:73.33333333333333\n","Test epoch 24 ----> loss_test:0.5650485157966614, accuracy_test:69.26229508196722\n","Train epoch 25 ----> loss_train:0.3674493730068207, accuracy_train:80.0\n","Test epoch 25 ----> loss_test:0.46158814430236816, accuracy_test:77.04918032786885\n","Train epoch 26 ----> loss_train:0.21139144897460938, accuracy_train:100.0\n","Test epoch 26 ----> loss_test:0.3762306571006775, 
accuracy_test:87.70491803278688\n","Train epoch 27 ----> loss_train:0.24360625445842743, accuracy_train:86.66666666666667\n","Test epoch 27 ----> loss_test:0.34165677428245544, accuracy_test:84.42622950819673\n","Train epoch 28 ----> loss_train:0.314240962266922, accuracy_train:93.33333333333333\n","Test epoch 28 ----> loss_test:0.493245005607605, accuracy_test:78.68852459016394\n","Train epoch 29 ----> loss_train:0.48229771852493286, accuracy_train:66.66666666666666\n","Test epoch 29 ----> loss_test:0.3717788755893707, accuracy_test:81.9672131147541\n","Train epoch 30 ----> loss_train:0.3252904713153839, accuracy_train:86.66666666666667\n","Test epoch 30 ----> loss_test:0.35426631569862366, accuracy_test:84.42622950819673\n","Train epoch 31 ----> loss_train:0.30262789130210876, accuracy_train:86.66666666666667\n","Test epoch 31 ----> loss_test:0.351441353559494, accuracy_test:83.60655737704919\n","Train epoch 32 ----> loss_train:0.21227210760116577, accuracy_train:93.33333333333333\n","Test epoch 32 ----> loss_test:0.3321661353111267, accuracy_test:85.65573770491804\n","Train epoch 33 ----> loss_train:0.2600800693035126, accuracy_train:86.66666666666667\n","Test epoch 33 ----> loss_test:0.3565691113471985, accuracy_test:84.01639344262296\n","Train epoch 34 ----> loss_train:0.16956046223640442, accuracy_train:93.33333333333333\n","Test epoch 34 ----> loss_test:0.3794771730899811, accuracy_test:83.60655737704919\n","Train epoch 35 ----> loss_train:0.2076396346092224, accuracy_train:86.66666666666667\n","Test epoch 35 ----> loss_test:0.26697811484336853, accuracy_test:85.65573770491804\n","Train epoch 36 ----> loss_train:0.1492958962917328, accuracy_train:93.33333333333333\n","Test epoch 36 ----> loss_test:0.2497561275959015, accuracy_test:89.75409836065575\n","Train epoch 37 ----> loss_train:0.09298346936702728, accuracy_train:100.0\n","Test epoch 37 ----> loss_test:0.2428228259086609, accuracy_test:88.11475409836066\n","Train epoch 38 ----> 
loss_train:0.03686772659420967, accuracy_train:100.0\n","Test epoch 38 ----> loss_test:0.2709486782550812, accuracy_test:86.47540983606558\n","Train epoch 39 ----> loss_train:0.04874860867857933, accuracy_train:100.0\n","Test epoch 39 ----> loss_test:0.3264549970626831, accuracy_test:86.0655737704918\n","Train epoch 40 ----> loss_train:0.052517570555210114, accuracy_train:100.0\n","Test epoch 40 ----> loss_test:0.26367226243019104, accuracy_test:88.52459016393442\n","Train epoch 41 ----> loss_train:0.18150322139263153, accuracy_train:93.33333333333333\n","Test epoch 41 ----> loss_test:0.22230619192123413, accuracy_test:90.57377049180327\n","Train epoch 42 ----> loss_train:0.06426363438367844, accuracy_train:93.33333333333333\n","Test epoch 42 ----> loss_test:0.21696746349334717, accuracy_test:91.39344262295081\n","Train epoch 43 ----> loss_train:0.1148812472820282, accuracy_train:93.33333333333333\n","Test epoch 43 ----> loss_test:0.24720720946788788, accuracy_test:90.1639344262295\n","Train epoch 44 ----> loss_train:0.2639178931713104, accuracy_train:93.33333333333333\n","Test epoch 44 ----> loss_test:0.3643527925014496, accuracy_test:84.8360655737705\n","Train epoch 45 ----> loss_train:0.15969489514827728, accuracy_train:93.33333333333333\n","Test epoch 45 ----> loss_test:0.3031688630580902, accuracy_test:88.9344262295082\n","Train epoch 46 ----> loss_train:0.01594209112226963, accuracy_train:100.0\n","Test epoch 46 ----> loss_test:0.23534666001796722, accuracy_test:91.80327868852459\n","Train epoch 47 ----> loss_train:0.14267171919345856, accuracy_train:93.33333333333333\n","Test epoch 47 ----> loss_test:0.21389374136924744, accuracy_test:92.62295081967213\n","Train epoch 48 ----> loss_train:0.05601683631539345, accuracy_train:100.0\n","Test epoch 48 ----> loss_test:0.20240157842636108, accuracy_test:91.80327868852459\n","Train epoch 49 ----> loss_train:0.08283224701881409, accuracy_train:93.33333333333333\n","Test epoch 49 ----> loss_test:0.22367405891418457, 
accuracy_test:92.62295081967213\n","Train epoch 50 ----> loss_train:0.4089038074016571, accuracy_train:86.66666666666667\n","Test epoch 50 ----> loss_test:0.23002968728542328, accuracy_test:88.9344262295082\n","Train epoch 51 ----> loss_train:0.04266216233372688, accuracy_train:100.0\n","Test epoch 51 ----> loss_test:0.268976628780365, accuracy_test:86.47540983606558\n","Train epoch 52 ----> loss_train:0.016194330528378487, accuracy_train:100.0\n","Test epoch 52 ----> loss_test:0.20723119378089905, accuracy_test:92.62295081967213\n","Train epoch 53 ----> loss_train:0.1983640193939209, accuracy_train:86.66666666666667\n","Test epoch 53 ----> loss_test:0.2365255355834961, accuracy_test:90.98360655737704\n","Train epoch 54 ----> loss_train:0.052192993462085724, accuracy_train:100.0\n","Test epoch 54 ----> loss_test:0.22560298442840576, accuracy_test:90.98360655737704\n","Train epoch 55 ----> loss_train:0.018273767083883286, accuracy_train:100.0\n","Test epoch 55 ----> loss_test:0.300938218832016, accuracy_test:89.34426229508196\n","Train epoch 56 ----> loss_train:0.029254388064146042, accuracy_train:100.0\n","Test epoch 56 ----> loss_test:0.21089258790016174, accuracy_test:93.0327868852459\n","Train epoch 57 ----> loss_train:0.003264013212174177, accuracy_train:100.0\n","Test epoch 57 ----> loss_test:0.20047633349895477, accuracy_test:92.21311475409836\n","Train epoch 58 ----> loss_train:0.11391778290271759, accuracy_train:93.33333333333333\n","Test epoch 58 ----> loss_test:0.18879885971546173, accuracy_test:94.67213114754098\n","Train epoch 59 ----> loss_train:0.057573847472667694, accuracy_train:93.33333333333333\n","Test epoch 59 ----> loss_test:0.4323604106903076, accuracy_test:90.98360655737704\n","Train epoch 60 ----> loss_train:0.16537858545780182, accuracy_train:93.33333333333333\n","Test epoch 60 ----> loss_test:0.3042074143886566, accuracy_test:91.39344262295081\n","Train epoch 61 ----> loss_train:0.029696408659219742, accuracy_train:100.0\n","Test epoch 61 
----> loss_test:0.17089226841926575, accuracy_test:93.44262295081968\n","Train epoch 62 ----> loss_train:0.09398171305656433, accuracy_train:93.33333333333333\n","Test epoch 62 ----> loss_test:0.2242358773946762, accuracy_test:93.0327868852459\n","Train epoch 63 ----> loss_train:0.04352137818932533, accuracy_train:100.0\n","Test epoch 63 ----> loss_test:0.3910592496395111, accuracy_test:92.21311475409836\n","Train epoch 64 ----> loss_train:0.09287381917238235, accuracy_train:93.33333333333333\n","Test epoch 64 ----> loss_test:0.27094778418540955, accuracy_test:91.80327868852459\n","Train epoch 65 ----> loss_train:0.10542603582143784, accuracy_train:93.33333333333333\n","Test epoch 65 ----> loss_test:0.2800958454608917, accuracy_test:93.85245901639344\n","Train epoch 66 ----> loss_train:0.5572029948234558, accuracy_train:80.0\n","Test epoch 66 ----> loss_test:0.432410329580307, accuracy_test:87.70491803278688\n","Train epoch 67 ----> loss_train:0.06798195093870163, accuracy_train:100.0\n","Test epoch 67 ----> loss_test:0.3009948134422302, accuracy_test:88.11475409836066\n","Train epoch 68 ----> loss_train:0.015737658366560936, accuracy_train:100.0\n","Test epoch 68 ----> loss_test:0.22030843794345856, accuracy_test:91.39344262295081\n","Train epoch 69 ----> loss_train:0.022785933688282967, accuracy_train:100.0\n","Test epoch 69 ----> loss_test:0.17838546633720398, accuracy_test:94.67213114754098\n","Train epoch 70 ----> loss_train:0.011274651624262333, accuracy_train:100.0\n","Test epoch 70 ----> loss_test:0.1593770980834961, accuracy_test:95.08196721311475\n","Train epoch 71 ----> loss_train:0.0003209850692655891, accuracy_train:100.0\n","Test epoch 71 ----> loss_test:0.17241452634334564, accuracy_test:95.08196721311475\n","Train epoch 72 ----> loss_train:0.003324038116261363, accuracy_train:100.0\n","Test epoch 72 ----> loss_test:0.18065917491912842, accuracy_test:93.85245901639344\n","Train epoch 73 ----> loss_train:0.002590508433058858, 
accuracy_train:100.0\n","Test epoch 73 ----> loss_test:0.14827600121498108, accuracy_test:95.49180327868852\n","Train epoch 74 ----> loss_train:0.0017337917815893888, accuracy_train:100.0\n","Test epoch 74 ----> loss_test:0.19378729164600372, accuracy_test:95.08196721311475\n","Train epoch 75 ----> loss_train:0.004031042102724314, accuracy_train:100.0\n","Test epoch 75 ----> loss_test:0.18208518624305725, accuracy_test:95.08196721311475\n","Train epoch 76 ----> loss_train:0.0017815598985180259, accuracy_train:100.0\n","Test epoch 76 ----> loss_test:0.18251711130142212, accuracy_test:95.08196721311475\n","Train epoch 77 ----> loss_train:0.008133621886372566, accuracy_train:100.0\n","Test epoch 77 ----> loss_test:0.17510253190994263, accuracy_test:95.90163934426229\n","Train epoch 78 ----> loss_train:0.0011927328305318952, accuracy_train:100.0\n","Test epoch 78 ----> loss_test:0.17269931733608246, accuracy_test:94.67213114754098\n","Train epoch 79 ----> loss_train:0.00044058248749934137, accuracy_train:100.0\n","Test epoch 79 ----> loss_test:0.17271451652050018, accuracy_test:95.08196721311475\n","Train epoch 80 ----> loss_train:0.0008920478285290301, accuracy_train:100.0\n","Test epoch 80 ----> loss_test:0.1864911913871765, accuracy_test:94.26229508196722\n","Train epoch 81 ----> loss_train:0.000158652663230896, accuracy_train:100.0\n","Test epoch 81 ----> loss_test:0.17999692261219025, accuracy_test:94.67213114754098\n","Train epoch 82 ----> loss_train:0.00014910721802152693, accuracy_train:100.0\n","Test epoch 82 ----> loss_test:0.19288823008537292, accuracy_test:94.67213114754098\n","Train epoch 83 ----> loss_train:0.005794026888906956, accuracy_train:100.0\n","Test epoch 83 ----> loss_test:0.2012569010257721, accuracy_test:94.67213114754098\n","Train epoch 84 ----> loss_train:9.496918210061267e-05, accuracy_train:100.0\n","Test epoch 84 ----> loss_test:0.20834201574325562, accuracy_test:93.85245901639344\n","Train epoch 85 ----> loss_train:0.0023733123671263456, 
accuracy_train:100.0\n","Test epoch 85 ----> loss_test:0.21430033445358276, accuracy_test:93.85245901639344\n","Train epoch 86 ----> loss_train:4.4585376599570736e-05, accuracy_train:100.0\n","Test epoch 86 ----> loss_test:0.2628557085990906, accuracy_test:93.85245901639344\n","Train epoch 87 ----> loss_train:8.006035204743966e-05, accuracy_train:100.0\n","Test epoch 87 ----> loss_test:0.20230112969875336, accuracy_test:93.85245901639344\n","Train epoch 88 ----> loss_train:0.00029515137430280447, accuracy_train:100.0\n","Test epoch 88 ----> loss_test:0.1977137178182602, accuracy_test:94.26229508196722\n","Train epoch 89 ----> loss_train:0.0018616248853504658, accuracy_train:100.0\n","Test epoch 89 ----> loss_test:0.224327951669693, accuracy_test:94.67213114754098\n","Train epoch 90 ----> loss_train:0.00030044352752156556, accuracy_train:100.0\n","Test epoch 90 ----> loss_test:0.19583286345005035, accuracy_test:93.85245901639344\n","Train epoch 91 ----> loss_train:0.00015955293201841414, accuracy_train:100.0\n","Test epoch 91 ----> loss_test:0.21602772176265717, accuracy_test:94.26229508196722\n","Train epoch 92 ----> loss_train:0.001483868807554245, accuracy_train:100.0\n","Test epoch 92 ----> loss_test:0.24502329528331757, accuracy_test:93.85245901639344\n","Train epoch 93 ----> loss_train:0.00026790419360622764, accuracy_train:100.0\n","Test epoch 93 ----> loss_test:0.23526188731193542, accuracy_test:93.44262295081968\n","Train epoch 94 ----> loss_train:2.0343308278825134e-05, accuracy_train:100.0\n","Test epoch 94 ----> loss_test:0.2672439217567444, accuracy_test:93.85245901639344\n","Train epoch 95 ----> loss_train:0.00010300715075572953, accuracy_train:100.0\n","Test epoch 95 ----> loss_test:0.2671581506729126, accuracy_test:93.44262295081968\n","Train epoch 96 ----> loss_train:0.00018610352708492428, accuracy_train:100.0\n","Test epoch 96 ----> loss_test:0.2703568935394287, accuracy_test:93.85245901639344\n","Train epoch 97 ----> 
loss_train:4.7753997932886705e-05, accuracy_train:100.0\n","Test epoch 97 ----> loss_test:0.2562352120876312, accuracy_test:93.44262295081968\n","Train epoch 98 ----> loss_train:1.0251977755615371e-06, accuracy_train:100.0\n","Test epoch 98 ----> loss_test:0.2602320909500122, accuracy_test:93.44262295081968\n","Train epoch 99 ----> loss_train:1.0649290516084875e-06, accuracy_train:100.0\n","Test epoch 99 ----> loss_test:0.2653295397758484, accuracy_test:93.85245901639344\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"Pax3tx6kwllE","colab":{"base_uri":"https://localhost:8080/","height":675},"executionInfo":{"status":"ok","timestamp":1630016604092,"user_tz":-120,"elapsed":825,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"c3f5b8e2-caa8-4fae-90a9-595092b6a3bf"},"source":[
"# Learning curves: one panel per metric, x-axis is the epoch index.\n",
"# float() accepts both Python numbers and 0-d tensors (e.g. losses kept as tensors).\n",
"panels = [\n",
"    (train_loss_lst, \"Training Loss\", \"loss\", [0, 1]),\n",
"    (test_loss_lst, \"Test Loss\", \"loss\", [0, 1]),\n",
"    (train_acc_lst, \"Training Accuracy\", \"accuracy\", [0.5, 1]),\n",
"    (test_acc_lst, \"Test Accuracy\", \"accuracy\", [0.5, 1]),\n",
"]\n",
"fig, axes = plt.subplots(2, 2, figsize=[15, 10])\n",
"for ax, (values, title, ylabel, ylim) in zip(axes.flat, panels):\n",
"    ax.plot([float(v) for v in values])\n",
"    ax.set(title=title, xlabel=\"epoch\", ylabel=ylabel, ylim=ylim)\n",
"#fig.savefig(out_dir + str(labels_plot) + \"performance_plot.pdf\")\n",
"plt.show()"
],"execution_count":59,"outputs":[{"output_type":"execute_result","data":{"text/plain":["
"]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":["(0.0, 1.0)"]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":["[]"]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":["Text(0.5, 1.0, 'Training Loss')"]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":["(0.0, 1.0)"]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":["[]"]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":["Text(0.5, 1.0, 'Test Loss')"]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":["(0.5, 1.0)"]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":["[]"]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":["Text(0.5, 1.0, 'Training Accuracy')"]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":["(0.5, 1.0)"]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":["[]"]},"metadata":{},"execution_count":59},{"output_type":"execute_result","data":{"text/plain":["Text(0.5, 1.0, 'Test Accuracy')"]},"metadata":{},"execution_count":59},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"code","metadata":{"id":"LwJ3_N0mc8rS","executionInfo":{"status":"ok","timestamp":1630016604094,"user_tz":-120,"elapsed":17,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":[""],"execution_count":59,"outputs":[]},{"cell_type":"code","metadata":{"id":"9mdYmfnpwllF","executionInfo":{"status":"ok","timestamp":1630016604095,"user_tz":-120,"elapsed":15,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}}},"source":["from captum.attr import IntegratedGradients\n","\n","features_name = []\n","for col in df.columns:\n"," if col.split('.')[1]=='x':\n"," features_name.append(col)\n","for col in df.columns:\n"," if col.split('.')[1]=='y':\n"," features_name.append(col)\n","\n","ig = IntegratedGradients(Classifier)\n","\n","test_loader = DataLoader(test_data, batch_size=100)\n","x, y, q, target=next(iter(test_loader))\n","x, y, q = x.to(device), y.to(device), q.to(device)\n","x_encoded = onehot_encode(x, q)\n","y_encoded = onehot_encode(y, q)\n","test_input_tensor = torch.cat((x_encoded, y_encoded), dim=1).float()\n","\n","nq = 3\n","onehot_quality = F.one_hot(q.view(1,-1), nq).view(q.size(0), q.size(1), q.size(2), nq).permute(0, 1, 3, 2)\n"],"execution_count":60,"outputs":[]},{"cell_type":"code","metadata":{"id":"3RRuJ0trwllF","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1630016605538,"user_tz":-120,"elapsed":1456,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"577dcd6c-77f8-4476-e95a-a1d2abb1a457"},"source":["importance_data = []\n","torch.backends.cudnn.enabled=False\n","test_input_tensor.requires_grad = True\n","\n","for t in range(2):\n"," attr, delta = ig.attribute(test_input_tensor, target=t, return_convergence_delta=True)\n"," print(attr.shape)\n"," attr_x = attr[:, :45, :]\n"," attr_y = attr[:, 45:, :]\n"," 
print(attr_x.shape)\n"," attr_x = torch.reshape(attr_x, [attr_x.size(0), 15, 3, 100])\n"," print(attr_x.shape)\n"," print(attr_y.shape)\n"," attr_y = torch.reshape(attr_y, [attr_x.size(0), 15, 3, 100])\n"," print(attr_y.shape)\n"," output_x = (torch.mul(attr_x, onehot_quality.float()).sum(dim=2)).squeeze()\n"," output_y = (torch.mul(attr_y, onehot_quality.float()).sum(dim=2)).squeeze()\n"," print(output_x.shape)\n"," print(output_y.shape)\n"," importances = []\n"," importances.extend(output_x.detach().cpu().mean(axis=2).mean(axis=0).numpy())\n"," importances.extend(output_y.detach().cpu().mean(axis=2).mean(axis=0).numpy())\n"," importance_data.append(importances)"],"execution_count":61,"outputs":[{"output_type":"stream","text":["torch.Size([100, 90, 100])\n","torch.Size([100, 45, 100])\n","torch.Size([100, 15, 3, 100])\n","torch.Size([100, 45, 100])\n","torch.Size([100, 15, 3, 100])\n","torch.Size([100, 15, 100])\n","torch.Size([100, 15, 100])\n","torch.Size([100, 90, 100])\n","torch.Size([100, 45, 100])\n","torch.Size([100, 15, 3, 100])\n","torch.Size([100, 45, 100])\n","torch.Size([100, 15, 3, 100])\n","torch.Size([100, 15, 100])\n","torch.Size([100, 15, 100])\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"U4o1Z1_swllF","colab":{"base_uri":"https://localhost:8080/","height":386},"executionInfo":{"status":"ok","timestamp":1630016606472,"user_tz":-120,"elapsed":937,"user":{"displayName":"Franziska Ziolkowski","photoUrl":"","userId":"15466942233857614154"}},"outputId":"f02e975b-41b7-485f-dc1f-fbf203300867"},"source":["importance_df = pd.DataFrame(importance_data, columns = features_name).T\n","importance_df.columns = labels_plot\n","#importance_df.to_csv(out_dir + str(labels_plot) + \"importance_score.csv\")\n","importance_df.plot.bar(figsize=(15,5))\n","plt.legend(loc=\"upper left\")\n","plt.ylabel(\"importance score\")\n","plt.tight_layout()\n","#plt.savefig(out_dir + str(labels_plot) + 
\"importance_plot.pdf\")\n","plt.show()"],"execution_count":62,"outputs":[{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":62},{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":62},{"output_type":"execute_result","data":{"text/plain":["Text(0, 0.5, 'importance score')"]},"metadata":{},"execution_count":62},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]}]}