{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "271f7630", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from IPython.core.interactiveshell import InteractiveShell\n", "import glob\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import confusion_matrix\n", "from sklearn.metrics import classification_report\n", "from sklearn.metrics import plot_confusion_matrix\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.feature_selection import SelectFromModel\n", "from sklearn.metrics import accuracy_score\n", "from sklearn.metrics import f1_score\n", "from sklearn.metrics import balanced_accuracy_score\n", "%matplotlib inline\n", "plt.rcParams['figure.dpi'] = 100 # adjust fig size in notebook\n", "InteractiveShell.ast_node_interactivity = \"all\" # allows for multiple outputs per cell to be shown in notebook" ] }, { "cell_type": "markdown", "id": "4171d5c8", "metadata": {}, "source": [ "## Data Loading and Subsetting" ] }, { "cell_type": "code", "execution_count": 2, "id": "b5a0d9de", "metadata": {}, "outputs": [], "source": [ "# load feature data\n", "df = pd.read_csv(\"C:/Users/franz/Documents/Bachelor Studium/BachelorArbeit/JupyterLabBA/files/DLCfiles_final_results/res_file.csv\")\n", "df = df[df[\"distance\"] != \"not_specified\"]\n", "\n", "# subsetting data\n", "MOPstroke_BL = df[df[\"group\"] == \"MOPstroke_BL\"].copy()\n", "MOPstroke_P3 = df[df[\"group\"] == \"MOPstroke_P3\"].copy()\n", "MOPstroke_P28 = df[df[\"group\"] == \"MOPstroke_P28\"].copy()\n", "MOSstroke_BL = df[df[\"group\"] == \"MOSstroke_BL\"].copy()\n", "MOSstroke_P3 = df[df[\"group\"] == \"MOSstroke_P3\"].copy()\n", "MOSstroke_P28 = df[df[\"group\"] == \"MOSstroke_P28\"].copy()\n", "WMstroke_BL = df[df[\"group\"] == \"WMstroke_BL\"].copy()\n", "WMstroke_P3 = df[df[\"group\"] == \"WMstroke_P3\"].copy()\n", "WMstroke_P28 = df[df[\"group\"] == \"WMstroke_P28\"].copy()\n", "MOPMOSsham_BL = df[df[\"group\"] == \"MOPMOSsham_BL\"].copy()\n", "MOPMOSsham_P3 = df[df[\"group\"] == \"MOPMOSsham_P3\"].copy()\n", "MOPMOSsham_P28 = df[df[\"group\"] == \"MOPMOSsham_P28\"].copy()\n", "WMsham_BL = df[df[\"group\"] == \"WMsham_BL\"].copy()\n", "WMsham_P3 = df[df[\"group\"] == \"WMsham_P3\"].copy()\n", "WMsham_P28 = df[df[\"group\"] == \"WMsham_P28\"].copy()" ] }, { "cell_type": "code", "execution_count": 3, "id": "998a31e1", "metadata": {}, "outputs": [], "source": [ "feature_names = [\n", " \"time\", \"distance\", \"average_speed\",\n", " \"fp_cycle_dur_median\", \"fp_stance_dur_median\", \"fp_swing_dur_median\", \"fp_sw_st_ratio_median\", \"fp_stride_len_median\",\n", " \"hp_cycle_dur_median\", \"hp_stance_dur_median\", \"hp_swing_dur_median\", \"hp_sw_st_ratio_median\", \"hp_stride_len_median\",\n", " \"paw_dist_min\", \"paw_dist_max\", \"paw_dist_mean\", \"paw_dist_std\",\n", " \"fp_vel_min\", \"fp_vel_max\", \"fp_vel_mean\", \"fp_vel_std\", \"hp_vel_min\", \"hp_vel_max\", \"hp_vel_mean\", \"hp_vel_std\",\n", " \"tilt_up_min\", \"tilt_up_max\", \"tilt_up_mean\", \"tilt_up_std\", \"tilt_low_min\", \"tilt_low_max\", \"tilt_low_mean\", \"tilt_low_std\",\n", " \"snout_h_min\", \"snout_h_max\", \"snout_h_mean\", \"snout_h_std\",\n", " \"midb_h_min\", \"midb_h_max\", \"midb_h_mean\", \"midb_h_std\",\n", " \"no_paw_contact\", \"single_paw_contact\", \"double_paw_contact\",\n", " \"ank_h_min\", \"ank_h_max\", \"ank_h_mean\", \"ank_h_median\", \"ank_h_std\", \"ank_ang_min\", \"ank_ang_max\", \"ank_ang_mean\", \"ank_ang_median\", \"ank_ang_ROM\", \"ank_ang_std\", \"ank_angv_min\", \"ank_angv_max\", \"ank_angv_mean\", \"ank_angv_std\", \"ank_pos_min\", \"ank_pos_max\", \"ank_pos_mean\", \"ank_pos_std\", \"ank_anginit_min\", \"ank_anginit_max\", \"ank_anginit_mean\", \"ank_anginit_median\", \"ank_anginit_std\", \"ank_angpsw_min\", \"ank_angpsw_max\", \"ank_angpsw_mean\", \"ank_angpsw_median\", \"ank_angpsw_std\",\n", " \"kn_h_min\", \"kn_h_max\", \"kn_h_mean\", \"kn_h_median\", \"kn_h_std\", \"kn_ang_min\", \"kn_ang_max\", \"kn_ang_mean\", \"kn_ang_median\", \"kn_ang_ROM\", \"kn_ang_std\", \"kn_angv_min\", \"kn_angv_max\", \"kn_angv_mean\", \"kn_angv_std\", \"kn_pos_min\", \"kn_pos_max\", \"kn_pos_mean\", \"kn_pos_std\", \"kn_anginit_min\", \"kn_anginit_max\", \"kn_anginit_mean\", \"kn_anginit_median\", \"kn_anginit_std\", \"kn_angpsw_min\", \"kn_angpsw_max\", \"kn_angpsw_mean\", \"kn_angpsw_median\", \"kn_angpsw_std\",\n", " \"hip_h_min\", \"hip_h_max\", \"hip_h_mean\", \"hip_h_median\", \"hip_h_std\", \"hip_ang_min\", \"hip_ang_max\", \"hip_ang_mean\", \"hip_ang_median\", \"hip_ang_ROM\", \"hip_ang_std\", \"hip_angv_min\", \"hip_angv_max\", \"hip_angv_mean\", \"hip_angv_std\", \"hip_pos_min\", \"hip_pos_max\", \"hip_pos_mean\", \"hip_pos_std\", \"hip_anginit_min\", \"hip_anginit_max\", \"hip_anginit_mean\", \"hip_anginit_median\", \"hip_anginit_std\", \"hip_angpsw_min\", \"hip_angpsw_max\", \"hip_angpsw_mean\", \"hip_angpsw_median\", \"hip_angpsw_std\",\n", " \"wr_h_min\", \"wr_h_max\", \"wr_h_mean\", \"wr_h_median\", \"wr_h_std\", \"wr_ang_min\", \"wr_ang_max\", \"wr_ang_mean\", \"wr_ang_median\", \"wr_ang_ROM\", \"wr_ang_std\", \"wr_angv_min\", \"wr_angv_max\", \"wr_angv_mean\", \"wr_angv_std\", \"wr_pos_min\", \"wr_pos_max\", \"wr_pos_mean\", \"wr_pos_std\", \"wr_anginit_min\", \"wr_anginit_max\", \"wr_anginit_mean\", \"wr_anginit_median\", \"wr_anginit_std\", \"wr_angpsw_min\", \"wr_angpsw_max\", \"wr_angpsw_mean\", \"wr_angpsw_median\", \"wr_angpsw_std\",\n", " \"el_h_min\", \"el_h_max\", \"el_h_mean\", \"el_h_median\", \"el_h_std\", \"el_ang_min\", \"el_ang_max\", \"el_ang_mean\", \"el_ang_median\", \"el_ang_ROM\", \"el_ang_std\", \"el_angv_min\", \"el_angv_max\", \"el_angv_mean\", \"el_angv_std\", \"el_pos_min\", \"el_pos_max\", \"el_pos_mean\", \"el_pos_std\", \"el_anginit_min\", \"el_anginit_max\", \"el_anginit_mean\", \"el_anginit_median\", \"el_anginit_std\", \"el_angpsw_min\", \"el_angpsw_max\", \"el_angpsw_mean\", \"el_angpsw_median\", \"el_angpsw_std\",\n", " \"sh_h_min\", \"sh_h_max\", \"sh_h_mean\", \"sh_h_median\", \"sh_h_std\", \"sh_ang_min\", \"sh_ang_max\", \"sh_ang_mean\", \"sh_ang_median\", \"sh_ang_ROM\", \"sh_ang_std\", \"sh_angv_min\", \"sh_angv_max\", \"sh_angv_mean\", \"sh_angv_std\", \"sh_pos_min\", \"sh_pos_max\", \"sh_pos_mean\", \"sh_pos_std\", \"sh_anginit_min\", \"sh_anginit_max\", \"sh_anginit_mean\", \"sh_anginit_median\", \"sh_anginit_std\", \"sh_angpsw_min\", \"sh_angpsw_max\", \"sh_angpsw_mean\", \"sh_angpsw_median\", \"sh_angpsw_std\",\n", " \"hdrop_num_abs\", \"hdrop_num_rel\",\n", " \"under_beam\"]" ] }, { "cell_type": "code", "execution_count": 4, "id": "2d7b493d", "metadata": {}, "outputs": [], "source": [ "# name two or three files to compare with classifier\n", "df1 = MOPstroke_P3\n", "df2 = MOPMOSsham_P3\n", "df3 = pd.DataFrame()" ] }, { "cell_type": "code", "execution_count": 5, "id": "442cd87f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | time | \n", "fp_cycle_dur_median | \n", "fp_stance_dur_median | \n", "fp_swing_dur_median | \n", "fp_sw_st_ratio_median | \n", "fp_stride_len_median | \n", "hp_cycle_dur_median | \n", "hp_stance_dur_median | \n", "hp_swing_dur_median | \n", "hp_sw_st_ratio_median | \n", "... | \n", "sh_anginit_mean | \n", "sh_anginit_median | \n", "sh_anginit_std | \n", "sh_angpsw_min | \n", "sh_angpsw_max | \n", "sh_angpsw_mean | \n", "sh_angpsw_median | \n", "sh_angpsw_std | \n", "hdrop_num_abs | \n", "under_beam | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "100.000000 | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "... | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "101.000000 | \n", "
mean | \n", "8.569460 | \n", "0.412788 | \n", "0.069316 | \n", "0.330655 | \n", "4.897041 | \n", "0.160896 | \n", "0.482551 | \n", "0.063783 | \n", "0.418767 | \n", "9.509027 | \n", "... | \n", "220.932931 | \n", "222.140768 | \n", "17.251572 | \n", "189.510422 | \n", "246.704918 | \n", "219.029805 | \n", "220.039472 | \n", "16.216653 | \n", "1.811881 | \n", "0.051994 | \n", "
std | \n", "5.172731 | \n", "0.254325 | \n", "0.015501 | \n", "0.247497 | \n", "3.485659 | \n", "0.163586 | \n", "0.452034 | \n", "0.013200 | \n", "0.446884 | \n", "11.704700 | \n", "... | \n", "8.888797 | \n", "9.552478 | \n", "4.921159 | \n", "10.224294 | \n", "9.745654 | \n", "8.049792 | \n", "8.641654 | \n", "3.751567 | \n", "2.504847 | \n", "0.081752 | \n", "
min | \n", "3.765273 | \n", "0.193091 | \n", "0.045137 | \n", "0.128727 | \n", "1.666667 | \n", "-0.019821 | \n", "0.221667 | \n", "0.038031 | \n", "0.152614 | \n", "3.267857 | \n", "... | \n", "195.964708 | \n", "196.871671 | \n", "8.015501 | \n", "162.994557 | \n", "222.802952 | \n", "201.183966 | \n", "196.708327 | \n", "7.801642 | \n", "0.000000 | \n", "0.000000 | \n", "
25% | \n", "5.627632 | \n", "0.271453 | \n", "0.064231 | \n", "0.196953 | \n", "3.000000 | \n", "0.044324 | \n", "0.294882 | \n", "0.054167 | \n", "0.237876 | \n", "5.563866 | \n", "... | \n", "216.375162 | \n", "215.788337 | \n", "13.338026 | \n", "182.895660 | \n", "241.514277 | \n", "213.950086 | \n", "216.237298 | \n", "13.178168 | \n", "0.000000 | \n", "0.000000 | \n", "
50% | \n", "7.350081 | \n", "0.333333 | \n", "0.066667 | \n", "0.250000 | \n", "4.000000 | \n", "0.131568 | \n", "0.359567 | \n", "0.061905 | \n", "0.299535 | \n", "6.957738 | \n", "... | \n", "221.016359 | \n", "223.754013 | \n", "17.035766 | \n", "186.846959 | \n", "246.307548 | \n", "219.075362 | \n", "220.463486 | \n", "16.641529 | \n", "1.000000 | \n", "0.017544 | \n", "
75% | \n", "9.033333 | \n", "0.433333 | \n", "0.069105 | \n", "0.357765 | \n", "5.500000 | \n", "0.220495 | \n", "0.484211 | \n", "0.072094 | \n", "0.421053 | \n", "9.146875 | \n", "... | \n", "226.354558 | \n", "228.801225 | \n", "20.925703 | \n", "193.564390 | \n", "252.753739 | \n", "224.454297 | \n", "225.649237 | \n", "19.646506 | \n", "3.000000 | \n", "0.055172 | \n", "
max | \n", "42.509025 | \n", "1.631500 | \n", "0.133333 | \n", "1.485228 | \n", "27.250000 | \n", "1.083275 | \n", "3.906600 | \n", "0.101502 | \n", "3.807584 | \n", "103.983810 | \n", "... | \n", "239.566316 | \n", "240.441019 | \n", "33.949624 | \n", "220.114148 | \n", "268.970073 | \n", "240.642112 | \n", "241.787432 | \n", "25.013106 | \n", "13.000000 | \n", "0.410188 | \n", "
8 rows × 218 columns
\n", "\n", " | time | \n", "fp_cycle_dur_median | \n", "fp_stance_dur_median | \n", "fp_swing_dur_median | \n", "fp_sw_st_ratio_median | \n", "fp_stride_len_median | \n", "hp_cycle_dur_median | \n", "hp_stance_dur_median | \n", "hp_swing_dur_median | \n", "hp_sw_st_ratio_median | \n", "... | \n", "sh_anginit_mean | \n", "sh_anginit_median | \n", "sh_anginit_std | \n", "sh_angpsw_min | \n", "sh_angpsw_max | \n", "sh_angpsw_mean | \n", "sh_angpsw_median | \n", "sh_angpsw_std | \n", "hdrop_num_abs | \n", "under_beam | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "... | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "83.000000 | \n", "
mean | \n", "6.222911 | \n", "0.309528 | \n", "0.072473 | \n", "0.225063 | \n", "3.370152 | \n", "0.193170 | \n", "0.316402 | \n", "0.058356 | \n", "0.258046 | \n", "5.898642 | \n", "... | \n", "223.165525 | \n", "223.815682 | \n", "14.491601 | \n", "195.040818 | \n", "245.178300 | \n", "221.084253 | \n", "221.097157 | \n", "14.039303 | \n", "0.710843 | \n", "0.013900 | \n", "
std | \n", "2.099432 | \n", "0.090808 | \n", "0.019437 | \n", "0.093464 | \n", "1.789558 | \n", "0.139702 | \n", "0.091626 | \n", "0.011285 | \n", "0.089563 | \n", "2.288775 | \n", "... | \n", "11.880388 | \n", "12.651813 | \n", "4.031649 | \n", "14.365911 | \n", "11.658140 | \n", "11.930484 | \n", "12.811183 | \n", "3.790966 | \n", "2.471941 | \n", "0.045644 | \n", "
min | \n", "3.375587 | \n", "0.189285 | \n", "0.031935 | \n", "0.092980 | \n", "1.000000 | \n", "-0.007645 | \n", "0.189263 | \n", "0.040351 | \n", "0.135775 | \n", "2.965789 | \n", "... | \n", "197.137017 | \n", "195.638895 | \n", "7.228378 | \n", "168.859958 | \n", "219.366591 | \n", "198.863748 | \n", "199.416563 | \n", "8.699910 | \n", "0.000000 | \n", "0.000000 | \n", "
25% | \n", "4.866214 | \n", "0.249921 | \n", "0.064049 | \n", "0.163635 | \n", "2.000000 | \n", "0.083155 | \n", "0.250915 | \n", "0.051652 | \n", "0.197152 | \n", "4.304436 | \n", "... | \n", "213.990719 | \n", "214.149501 | \n", "11.676275 | \n", "183.546411 | \n", "237.482189 | \n", "212.143198 | \n", "212.835242 | \n", "10.860017 | \n", "0.000000 | \n", "0.000000 | \n", "
50% | \n", "5.533333 | \n", "0.287411 | \n", "0.066667 | \n", "0.200000 | \n", "3.000000 | \n", "0.166551 | \n", "0.288865 | \n", "0.056164 | \n", "0.233616 | \n", "5.171875 | \n", "... | \n", "222.355070 | \n", "221.843710 | \n", "13.835482 | \n", "189.553503 | \n", "244.588448 | \n", "218.782297 | \n", "218.980809 | \n", "12.956647 | \n", "0.000000 | \n", "0.000000 | \n", "
75% | \n", "6.833590 | \n", "0.356065 | \n", "0.081653 | \n", "0.263161 | \n", "4.500000 | \n", "0.295055 | \n", "0.372115 | \n", "0.063228 | \n", "0.311286 | \n", "7.202478 | \n", "... | \n", "232.538214 | \n", "232.795688 | \n", "16.909589 | \n", "207.006850 | \n", "255.574905 | \n", "231.022127 | \n", "232.224749 | \n", "16.832864 | \n", "0.000000 | \n", "0.000000 | \n", "
max | \n", "11.326717 | \n", "0.550972 | \n", "0.157494 | \n", "0.500000 | \n", "10.000000 | \n", "0.645501 | \n", "0.583193 | \n", "0.099884 | \n", "0.483309 | \n", "13.336508 | \n", "... | \n", "246.930219 | \n", "247.688365 | \n", "24.023483 | \n", "233.153303 | \n", "269.994174 | \n", "248.686375 | \n", "249.724347 | \n", "23.334560 | \n", "15.000000 | \n", "0.284091 | \n", "
8 rows × 218 columns
\n", "