{ "cells": [
 { "cell_type": "code", "execution_count": 1, "id": "271f7630", "metadata": {}, "outputs": [], "source": [
  "import pandas as pd\n",
  "import matplotlib.pyplot as plt\n",
  "import numpy as np\n",
  "from IPython.core.interactiveshell import InteractiveShell\n",
  "import glob\n",
  "from sklearn.ensemble import RandomForestClassifier\n",
  "from sklearn.metrics import confusion_matrix\n",
  "from sklearn.metrics import classification_report\n",
  "from sklearn.metrics import plot_confusion_matrix\n",
  "from sklearn.preprocessing import StandardScaler\n",
  "from sklearn.preprocessing import LabelEncoder\n",
  "from sklearn.model_selection import train_test_split\n",
  "from sklearn.feature_selection import SelectFromModel\n",
  "from sklearn.metrics import accuracy_score\n",
  "from sklearn.metrics import f1_score\n",
  "from sklearn.metrics import balanced_accuracy_score\n",
  "%matplotlib inline\n",
  "plt.rcParams['figure.dpi'] = 100 # adjust fig size in notebook\n",
  "InteractiveShell.ast_node_interactivity = \"all\" # allows for multiple outputs per cell to be shown in notebook" ] },
 { "cell_type": "markdown", "id": "4171d5c8", "metadata": {}, "source": [ "## Data Loading and Subsetting" ] },
 { "cell_type": "code", "execution_count": 2, "id": "b5a0d9de", "metadata": {}, "outputs": [], "source": [
  "# load feature data\n",
  "# NOTE: machine-specific absolute path - adjust RES_FILE when running on another machine\n",
  "RES_FILE = \"C:/Users/franz/Documents/Bachelor Studium/BachelorArbeit/JupyterLabBA/files/DLCfiles_final_results/res_file.csv\"\n",
  "df = pd.read_csv(RES_FILE)\n",
  "# drop rows whose distance column holds the not_specified placeholder\n",
  "df = df[df[\"distance\"] != \"not_specified\"]\n",
  "\n",
  "# subsetting data: one independent copy of the rows of each value of the group column\n",
  "# GROUP_NAMES is the single source of truth; its order must match the unpacking targets below\n",
  "GROUP_NAMES = [\n",
  "    \"MOPstroke_BL\", \"MOPstroke_P3\", \"MOPstroke_P28\",\n",
  "    \"MOSstroke_BL\", \"MOSstroke_P3\", \"MOSstroke_P28\",\n",
  "    \"WMstroke_BL\", \"WMstroke_P3\", \"WMstroke_P28\",\n",
  "    \"MOPMOSsham_BL\", \"MOPMOSsham_P3\", \"MOPMOSsham_P28\",\n",
  "    \"WMsham_BL\", \"WMsham_P3\", \"WMsham_P28\",\n",
  "]\n",
  "(MOPstroke_BL, MOPstroke_P3, MOPstroke_P28,\n",
  " MOSstroke_BL, MOSstroke_P3, MOSstroke_P28,\n",
  " WMstroke_BL, WMstroke_P3, WMstroke_P28,\n",
  " MOPMOSsham_BL, MOPMOSsham_P3, MOPMOSsham_P28,\n",
  " WMsham_BL, WMsham_P3, WMsham_P28) = [df[df[\"group\"] == name].copy() for name in GROUP_NAMES]" ] },
 { "cell_type": "code", "execution_count": 3, "id": "998a31e1", "metadata": {}, "outputs": [], "source": [
  "feature_names = [\n",
  " \"time\", \"distance\", \"average_speed\",\n",
  " \"fp_cycle_dur_median\", \"fp_stance_dur_median\", \"fp_swing_dur_median\", \"fp_sw_st_ratio_median\", \"fp_stride_len_median\",\n",
  " \"hp_cycle_dur_median\", \"hp_stance_dur_median\", \"hp_swing_dur_median\", \"hp_sw_st_ratio_median\", \"hp_stride_len_median\",\n",
  " \"paw_dist_min\", \"paw_dist_max\", \"paw_dist_mean\", \"paw_dist_std\",\n",
  " \"fp_vel_min\", \"fp_vel_max\", \"fp_vel_mean\", \"fp_vel_std\", \"hp_vel_min\", \"hp_vel_max\", \"hp_vel_mean\", \"hp_vel_std\",\n",
  " \"tilt_up_min\", \"tilt_up_max\", \"tilt_up_mean\", \"tilt_up_std\", \"tilt_low_min\", \"tilt_low_max\", \"tilt_low_mean\", \"tilt_low_std\",\n",
  " \"snout_h_min\", \"snout_h_max\", \"snout_h_mean\", \"snout_h_std\",\n",
  " \"midb_h_min\", \"midb_h_max\", \"midb_h_mean\", \"midb_h_std\",\n",
  " \"no_paw_contact\", \"single_paw_contact\", \"double_paw_contact\",\n",
  " \"ank_h_min\", \"ank_h_max\", \"ank_h_mean\", \"ank_h_median\", \"ank_h_std\", \"ank_ang_min\", \"ank_ang_max\", \"ank_ang_mean\", \"ank_ang_median\", \"ank_ang_ROM\", \"ank_ang_std\", \"ank_angv_min\", \"ank_angv_max\", \"ank_angv_mean\", \"ank_angv_std\", \"ank_pos_min\", \"ank_pos_max\", \"ank_pos_mean\", \"ank_pos_std\", \"ank_anginit_min\", \"ank_anginit_max\", \"ank_anginit_mean\", \"ank_anginit_median\", \"ank_anginit_std\", \"ank_angpsw_min\", \"ank_angpsw_max\", \"ank_angpsw_mean\", \"ank_angpsw_median\", \"ank_angpsw_std\",\n",
  " \"kn_h_min\", \"kn_h_max\", \"kn_h_mean\", \"kn_h_median\", \"kn_h_std\", \"kn_ang_min\", \"kn_ang_max\", \"kn_ang_mean\", \"kn_ang_median\", \"kn_ang_ROM\", \"kn_ang_std\", \"kn_angv_min\", \"kn_angv_max\", \"kn_angv_mean\", \"kn_angv_std\", \"kn_pos_min\", \"kn_pos_max\", \"kn_pos_mean\", \"kn_pos_std\", \"kn_anginit_min\", \"kn_anginit_max\", \"kn_anginit_mean\", \"kn_anginit_median\", \"kn_anginit_std\", \"kn_angpsw_min\", \"kn_angpsw_max\", \"kn_angpsw_mean\", \"kn_angpsw_median\", \"kn_angpsw_std\",\n",
  " \"hip_h_min\", \"hip_h_max\", \"hip_h_mean\", \"hip_h_median\", \"hip_h_std\", \"hip_ang_min\", \"hip_ang_max\", \"hip_ang_mean\", \"hip_ang_median\", \"hip_ang_ROM\", \"hip_ang_std\", \"hip_angv_min\", \"hip_angv_max\", \"hip_angv_mean\", \"hip_angv_std\", \"hip_pos_min\", \"hip_pos_max\", \"hip_pos_mean\", \"hip_pos_std\", \"hip_anginit_min\", \"hip_anginit_max\", \"hip_anginit_mean\", \"hip_anginit_median\", \"hip_anginit_std\", \"hip_angpsw_min\", \"hip_angpsw_max\", \"hip_angpsw_mean\", \"hip_angpsw_median\", \"hip_angpsw_std\",\n",
  " \"wr_h_min\", \"wr_h_max\", \"wr_h_mean\", \"wr_h_median\", \"wr_h_std\", \"wr_ang_min\", \"wr_ang_max\", \"wr_ang_mean\", \"wr_ang_median\", \"wr_ang_ROM\", \"wr_ang_std\", \"wr_angv_min\", \"wr_angv_max\", \"wr_angv_mean\", \"wr_angv_std\", \"wr_pos_min\", \"wr_pos_max\", \"wr_pos_mean\", \"wr_pos_std\", \"wr_anginit_min\", \"wr_anginit_max\", \"wr_anginit_mean\", \"wr_anginit_median\", \"wr_anginit_std\", \"wr_angpsw_min\", \"wr_angpsw_max\", \"wr_angpsw_mean\", \"wr_angpsw_median\", \"wr_angpsw_std\",\n",
  " \"el_h_min\", \"el_h_max\", \"el_h_mean\", \"el_h_median\", \"el_h_std\", \"el_ang_min\", \"el_ang_max\", \"el_ang_mean\", \"el_ang_median\", \"el_ang_ROM\", \"el_ang_std\", \"el_angv_min\", \"el_angv_max\", \"el_angv_mean\", \"el_angv_std\", \"el_pos_min\", \"el_pos_max\", \"el_pos_mean\", \"el_pos_std\", \"el_anginit_min\", \"el_anginit_max\", \"el_anginit_mean\", \"el_anginit_median\", \"el_anginit_std\", \"el_angpsw_min\", \"el_angpsw_max\", \"el_angpsw_mean\", \"el_angpsw_median\", \"el_angpsw_std\",\n",
  " \"sh_h_min\", \"sh_h_max\", \"sh_h_mean\", \"sh_h_median\", \"sh_h_std\", \"sh_ang_min\", \"sh_ang_max\", \"sh_ang_mean\", \"sh_ang_median\", \"sh_ang_ROM\", \"sh_ang_std\", \"sh_angv_min\", \"sh_angv_max\", \"sh_angv_mean\", \"sh_angv_std\", \"sh_pos_min\", \"sh_pos_max\", \"sh_pos_mean\", \"sh_pos_std\", \"sh_anginit_min\", \"sh_anginit_max\", \"sh_anginit_mean\", \"sh_anginit_median\", \"sh_anginit_std\", \"sh_angpsw_min\", \"sh_angpsw_max\", \"sh_angpsw_mean\", \"sh_angpsw_median\", \"sh_angpsw_std\",\n",
  " \"hdrop_num_abs\", \"hdrop_num_rel\",\n",
  " \"under_beam\"]" ] },
 { "cell_type": "code", "execution_count": 4, "id": "2d7b493d", "metadata": {}, "outputs": [], "source": [
  "# name two or three files to compare with classifier\n",
  "df1 = MOPstroke_P3\n",
  "df2 = MOPMOSsham_P3\n",
  "df3 = pd.DataFrame()" ] },
 { "cell_type": "code", "execution_count": 5, "id": "442cd87f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
timefp_cycle_dur_medianfp_stance_dur_medianfp_swing_dur_medianfp_sw_st_ratio_medianfp_stride_len_medianhp_cycle_dur_medianhp_stance_dur_medianhp_swing_dur_medianhp_sw_st_ratio_median...sh_anginit_meansh_anginit_mediansh_anginit_stdsh_angpsw_minsh_angpsw_maxsh_angpsw_meansh_angpsw_mediansh_angpsw_stdhdrop_num_absunder_beam
count101.000000101.000000101.000000101.000000101.000000100.000000101.000000101.000000101.000000101.000000...101.000000101.000000101.000000101.000000101.000000101.000000101.000000101.000000101.000000101.000000
mean8.5694600.4127880.0693160.3306554.8970410.1608960.4825510.0637830.4187679.509027...220.932931222.14076817.251572189.510422246.704918219.029805220.03947216.2166531.8118810.051994
std5.1727310.2543250.0155010.2474973.4856590.1635860.4520340.0132000.44688411.704700...8.8887979.5524784.92115910.2242949.7456548.0497928.6416543.7515672.5048470.081752
min3.7652730.1930910.0451370.1287271.666667-0.0198210.2216670.0380310.1526143.267857...195.964708196.8716718.015501162.994557222.802952201.183966196.7083277.8016420.0000000.000000
25%5.6276320.2714530.0642310.1969533.0000000.0443240.2948820.0541670.2378765.563866...216.375162215.78833713.338026182.895660241.514277213.950086216.23729813.1781680.0000000.000000
50%7.3500810.3333330.0666670.2500004.0000000.1315680.3595670.0619050.2995356.957738...221.016359223.75401317.035766186.846959246.307548219.075362220.46348616.6415291.0000000.017544
75%9.0333330.4333330.0691050.3577655.5000000.2204950.4842110.0720940.4210539.146875...226.354558228.80122520.925703193.564390252.753739224.454297225.64923719.6465063.0000000.055172
max42.5090251.6315000.1333331.48522827.2500001.0832753.9066000.1015023.807584103.983810...239.566316240.44101933.949624220.114148268.970073240.642112241.78743225.01310613.0000000.410188
\n", "

8 rows × 218 columns

\n", "
" ], "text/plain": [ " time fp_cycle_dur_median fp_stance_dur_median \\\n", "count 101.000000 101.000000 101.000000 \n", "mean 8.569460 0.412788 0.069316 \n", "std 5.172731 0.254325 0.015501 \n", "min 3.765273 0.193091 0.045137 \n", "25% 5.627632 0.271453 0.064231 \n", "50% 7.350081 0.333333 0.066667 \n", "75% 9.033333 0.433333 0.069105 \n", "max 42.509025 1.631500 0.133333 \n", "\n", " fp_swing_dur_median fp_sw_st_ratio_median fp_stride_len_median \\\n", "count 101.000000 101.000000 100.000000 \n", "mean 0.330655 4.897041 0.160896 \n", "std 0.247497 3.485659 0.163586 \n", "min 0.128727 1.666667 -0.019821 \n", "25% 0.196953 3.000000 0.044324 \n", "50% 0.250000 4.000000 0.131568 \n", "75% 0.357765 5.500000 0.220495 \n", "max 1.485228 27.250000 1.083275 \n", "\n", " hp_cycle_dur_median hp_stance_dur_median hp_swing_dur_median \\\n", "count 101.000000 101.000000 101.000000 \n", "mean 0.482551 0.063783 0.418767 \n", "std 0.452034 0.013200 0.446884 \n", "min 0.221667 0.038031 0.152614 \n", "25% 0.294882 0.054167 0.237876 \n", "50% 0.359567 0.061905 0.299535 \n", "75% 0.484211 0.072094 0.421053 \n", "max 3.906600 0.101502 3.807584 \n", "\n", " hp_sw_st_ratio_median ... sh_anginit_mean sh_anginit_median \\\n", "count 101.000000 ... 101.000000 101.000000 \n", "mean 9.509027 ... 220.932931 222.140768 \n", "std 11.704700 ... 8.888797 9.552478 \n", "min 3.267857 ... 195.964708 196.871671 \n", "25% 5.563866 ... 216.375162 215.788337 \n", "50% 6.957738 ... 221.016359 223.754013 \n", "75% 9.146875 ... 226.354558 228.801225 \n", "max 103.983810 ... 
239.566316 240.441019 \n", "\n", " sh_anginit_std sh_angpsw_min sh_angpsw_max sh_angpsw_mean \\\n", "count 101.000000 101.000000 101.000000 101.000000 \n", "mean 17.251572 189.510422 246.704918 219.029805 \n", "std 4.921159 10.224294 9.745654 8.049792 \n", "min 8.015501 162.994557 222.802952 201.183966 \n", "25% 13.338026 182.895660 241.514277 213.950086 \n", "50% 17.035766 186.846959 246.307548 219.075362 \n", "75% 20.925703 193.564390 252.753739 224.454297 \n", "max 33.949624 220.114148 268.970073 240.642112 \n", "\n", " sh_angpsw_median sh_angpsw_std hdrop_num_abs under_beam \n", "count 101.000000 101.000000 101.000000 101.000000 \n", "mean 220.039472 16.216653 1.811881 0.051994 \n", "std 8.641654 3.751567 2.504847 0.081752 \n", "min 196.708327 7.801642 0.000000 0.000000 \n", "25% 216.237298 13.178168 0.000000 0.000000 \n", "50% 220.463486 16.641529 1.000000 0.017544 \n", "75% 225.649237 19.646506 3.000000 0.055172 \n", "max 241.787432 25.013106 13.000000 0.410188 \n", "\n", "[8 rows x 218 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
timefp_cycle_dur_medianfp_stance_dur_medianfp_swing_dur_medianfp_sw_st_ratio_medianfp_stride_len_medianhp_cycle_dur_medianhp_stance_dur_medianhp_swing_dur_medianhp_sw_st_ratio_median...sh_anginit_meansh_anginit_mediansh_anginit_stdsh_angpsw_minsh_angpsw_maxsh_angpsw_meansh_angpsw_mediansh_angpsw_stdhdrop_num_absunder_beam
count83.00000083.00000083.00000083.00000083.00000083.00000083.00000083.00000083.00000083.000000...83.00000083.00000083.00000083.00000083.00000083.00000083.00000083.00000083.00000083.000000
mean6.2229110.3095280.0724730.2250633.3701520.1931700.3164020.0583560.2580465.898642...223.165525223.81568214.491601195.040818245.178300221.084253221.09715714.0393030.7108430.013900
std2.0994320.0908080.0194370.0934641.7895580.1397020.0916260.0112850.0895632.288775...11.88038812.6518134.03164914.36591111.65814011.93048412.8111833.7909662.4719410.045644
min3.3755870.1892850.0319350.0929801.000000-0.0076450.1892630.0403510.1357752.965789...197.137017195.6388957.228378168.859958219.366591198.863748199.4165638.6999100.0000000.000000
25%4.8662140.2499210.0640490.1636352.0000000.0831550.2509150.0516520.1971524.304436...213.990719214.14950111.676275183.546411237.482189212.143198212.83524210.8600170.0000000.000000
50%5.5333330.2874110.0666670.2000003.0000000.1665510.2888650.0561640.2336165.171875...222.355070221.84371013.835482189.553503244.588448218.782297218.98080912.9566470.0000000.000000
75%6.8335900.3560650.0816530.2631614.5000000.2950550.3721150.0632280.3112867.202478...232.538214232.79568816.909589207.006850255.574905231.022127232.22474916.8328640.0000000.000000
max11.3267170.5509720.1574940.50000010.0000000.6455010.5831930.0998840.48330913.336508...246.930219247.68836524.023483233.153303269.994174248.686375249.72434723.33456015.0000000.284091
\n", "

8 rows × 218 columns

\n", "
" ], "text/plain": [ " time fp_cycle_dur_median fp_stance_dur_median \\\n", "count 83.000000 83.000000 83.000000 \n", "mean 6.222911 0.309528 0.072473 \n", "std 2.099432 0.090808 0.019437 \n", "min 3.375587 0.189285 0.031935 \n", "25% 4.866214 0.249921 0.064049 \n", "50% 5.533333 0.287411 0.066667 \n", "75% 6.833590 0.356065 0.081653 \n", "max 11.326717 0.550972 0.157494 \n", "\n", " fp_swing_dur_median fp_sw_st_ratio_median fp_stride_len_median \\\n", "count 83.000000 83.000000 83.000000 \n", "mean 0.225063 3.370152 0.193170 \n", "std 0.093464 1.789558 0.139702 \n", "min 0.092980 1.000000 -0.007645 \n", "25% 0.163635 2.000000 0.083155 \n", "50% 0.200000 3.000000 0.166551 \n", "75% 0.263161 4.500000 0.295055 \n", "max 0.500000 10.000000 0.645501 \n", "\n", " hp_cycle_dur_median hp_stance_dur_median hp_swing_dur_median \\\n", "count 83.000000 83.000000 83.000000 \n", "mean 0.316402 0.058356 0.258046 \n", "std 0.091626 0.011285 0.089563 \n", "min 0.189263 0.040351 0.135775 \n", "25% 0.250915 0.051652 0.197152 \n", "50% 0.288865 0.056164 0.233616 \n", "75% 0.372115 0.063228 0.311286 \n", "max 0.583193 0.099884 0.483309 \n", "\n", " hp_sw_st_ratio_median ... sh_anginit_mean sh_anginit_median \\\n", "count 83.000000 ... 83.000000 83.000000 \n", "mean 5.898642 ... 223.165525 223.815682 \n", "std 2.288775 ... 11.880388 12.651813 \n", "min 2.965789 ... 197.137017 195.638895 \n", "25% 4.304436 ... 213.990719 214.149501 \n", "50% 5.171875 ... 222.355070 221.843710 \n", "75% 7.202478 ... 232.538214 232.795688 \n", "max 13.336508 ... 
246.930219 247.688365 \n", "\n", " sh_anginit_std sh_angpsw_min sh_angpsw_max sh_angpsw_mean \\\n", "count 83.000000 83.000000 83.000000 83.000000 \n", "mean 14.491601 195.040818 245.178300 221.084253 \n", "std 4.031649 14.365911 11.658140 11.930484 \n", "min 7.228378 168.859958 219.366591 198.863748 \n", "25% 11.676275 183.546411 237.482189 212.143198 \n", "50% 13.835482 189.553503 244.588448 218.782297 \n", "75% 16.909589 207.006850 255.574905 231.022127 \n", "max 24.023483 233.153303 269.994174 248.686375 \n", "\n", " sh_angpsw_median sh_angpsw_std hdrop_num_abs under_beam \n", "count 83.000000 83.000000 83.000000 83.000000 \n", "mean 221.097157 14.039303 0.710843 0.013900 \n", "std 12.811183 3.790966 2.471941 0.045644 \n", "min 199.416563 8.699910 0.000000 0.000000 \n", "25% 212.835242 10.860017 0.000000 0.000000 \n", "50% 218.980809 12.956647 0.000000 0.000000 \n", "75% 232.224749 16.832864 0.000000 0.000000 \n", "max 249.724347 23.334560 15.000000 0.284091 \n", "\n", "[8 rows x 218 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df1.describe()\n", "df2.describe()\n", "#df3.describe()" ] }, { "cell_type": "markdown", "id": "78f941a6", "metadata": {}, "source": [ "## Data pre-processing" ] }, { "cell_type": "code", "execution_count": 6, "id": "c9b13464", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "1" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "0.0" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "MOPstroke_P3 100\n", "MOPMOSsham_P3 82\n", "Name: group, dtype: int64" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# check for NaN values\n", "df1.isnull().sum().sum()\n", "df2.isnull().sum().sum()\n", 
"df3.isnull().sum().sum()\n", "\n", "# drop rows with missing values\n", "df1.dropna(inplace=True)\n", "df2.dropna(inplace=True)\n", "df3.dropna(inplace=True)\n", "\n", "# concat dataframes\n", "df = pd.concat([df1, df2, df3])\n", "df[\"group\"].value_counts()\n", "\n", "# remove unnecessary columns\n", "df = df.drop(['file_name', 'reason_end'], axis=1)" ] }, { "cell_type": "code", "execution_count": 7, "id": "38e8620a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LabelEncoder()" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# separate data in response and feature variables\n", "X = df.drop(\"group\", axis=1)\n", "y = df[\"group\"]\n", "# make sure all series in X are numeric\n", "X = X.apply(pd.to_numeric)\n", "# encode labels in y as numerals\n", "le = LabelEncoder()\n", "le.fit([\"MOPstroke_BL\", \"MOPstroke_P3\", \"MOPstroke_P28\", \"MOSstroke_BL\", \"MOSstroke_P3\", \"MOSstroke_P28\", \"WMstroke_BL\", \"WMstroke_P3\", \"WMstroke_P28\", \"MOPMOSsham_BL\", \"MOPMOSsham_P3\", \"MOPMOSsham_P28\", \"WMsham_BL\", \"WMsham_P3\", \"WMsham_P28\"])\n", "y = le.transform(y)\n" ] }, { "cell_type": "markdown", "id": "9220446a", "metadata": {}, "source": [ "## Random Forest Classifer" ] }, { "cell_type": "code", "execution_count": 8, "id": "f3537230", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iterations: 10/100\n", "iterations: 20/100\n", "iterations: 30/100\n", "iterations: 40/100\n", "iterations: 50/100\n", "iterations: 60/100\n", "iterations: 70/100\n", "iterations: 80/100\n", "iterations: 90/100\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 10\u001b[0m 
\u001b[0mX_test\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msc\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_test\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[1;31m# Random Forest Classifier\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 12\u001b[1;33m \u001b[0mrfc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mRandomForestClassifier\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mn_estimators\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m2000\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbootstrap\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 13\u001b[0m \u001b[0mpred_rfc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mrfc\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_test\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 14\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0maccuracy_score\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_test\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpred_rfc\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[0macc_score\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\envs\\KinematicAnalyses\\lib\\site-packages\\sklearn\\ensemble\\_forest.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[0;32m 385\u001b[0m \u001b[1;31m# parallel_backend contexts set at a higher level,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 386\u001b[0m \u001b[1;31m# since correctness does not rely on using 
threads.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 387\u001b[1;33m trees = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,\n\u001b[0m\u001b[0;32m 388\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0m_joblib_parallel_args\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprefer\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'threads'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 389\u001b[0m delayed(_parallel_build_trees)(\n", "\u001b[1;32m~\\Anaconda3\\envs\\KinematicAnalyses\\lib\\site-packages\\joblib\\parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, iterable)\u001b[0m\n\u001b[0;32m 1042\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_iterating\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_original_iterator\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1043\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1044\u001b[1;33m \u001b[1;32mwhile\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdispatch_one_batch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1045\u001b[0m \u001b[1;32mpass\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1046\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\envs\\KinematicAnalyses\\lib\\site-packages\\joblib\\parallel.py\u001b[0m in \u001b[0;36mdispatch_one_batch\u001b[1;34m(self, iterator)\u001b[0m\n\u001b[0;32m 857\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[1;32mFalse\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 858\u001b[0m 
\u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 859\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dispatch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtasks\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 860\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 861\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\envs\\KinematicAnalyses\\lib\\site-packages\\joblib\\parallel.py\u001b[0m in \u001b[0;36m_dispatch\u001b[1;34m(self, batch)\u001b[0m\n\u001b[0;32m 775\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_lock\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 776\u001b[0m \u001b[0mjob_idx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_jobs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 777\u001b[1;33m \u001b[0mjob\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mapply_async\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcb\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 778\u001b[0m \u001b[1;31m# A job can complete so quickly than its callback is\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 779\u001b[0m \u001b[1;31m# called before we get here, causing self._jobs to\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 
"\u001b[1;32m~\\Anaconda3\\envs\\KinematicAnalyses\\lib\\site-packages\\joblib\\_parallel_backends.py\u001b[0m in \u001b[0;36mapply_async\u001b[1;34m(self, func, callback)\u001b[0m\n\u001b[0;32m 206\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mapply_async\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 207\u001b[0m \u001b[1;34m\"\"\"Schedule a func to be run\"\"\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 208\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mImmediateResult\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 209\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mcallback\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 210\u001b[0m \u001b[0mcallback\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\envs\\KinematicAnalyses\\lib\\site-packages\\joblib\\_parallel_backends.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, batch)\u001b[0m\n\u001b[0;32m 570\u001b[0m \u001b[1;31m# Don't delay the application, to avoid keeping the input\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 571\u001b[0m \u001b[1;31m# arguments in memory\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 572\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mbatch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 573\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 574\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mget\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\envs\\KinematicAnalyses\\lib\\site-packages\\joblib\\parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 260\u001b[0m \u001b[1;31m# change the default number of processes to -1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 261\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 262\u001b[1;33m return [func(*args, **kwargs)\n\u001b[0m\u001b[0;32m 263\u001b[0m for func, args, kwargs in self.items]\n\u001b[0;32m 264\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\envs\\KinematicAnalyses\\lib\\site-packages\\joblib\\parallel.py\u001b[0m in \u001b[0;36m\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 260\u001b[0m \u001b[1;31m# change the default number of processes to -1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 261\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mparallel_backend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mn_jobs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_n_jobs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 262\u001b[1;33m return [func(*args, **kwargs)\n\u001b[0m\u001b[0;32m 263\u001b[0m for func, args, kwargs in self.items]\n\u001b[0;32m 264\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\envs\\KinematicAnalyses\\lib\\site-packages\\sklearn\\utils\\fixes.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 220\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 221\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mconfig_context\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 222\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[1;32m~\\Anaconda3\\envs\\KinematicAnalyses\\lib\\site-packages\\sklearn\\ensemble\\_forest.py\u001b[0m in \u001b[0;36m_parallel_build_trees\u001b[1;34m(tree, forest, X, y, sample_weight, tree_idx, n_trees, verbose, class_weight, n_samples_bootstrap)\u001b[0m\n\u001b[0;32m 167\u001b[0m indices=indices)\n\u001b[0;32m 168\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 169\u001b[1;33m 
\u001b[0mtree\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcurr_sample_weight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcheck_input\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 170\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 171\u001b[0m \u001b[0mtree\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcheck_input\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\envs\\KinematicAnalyses\\lib\\site-packages\\sklearn\\tree\\_classes.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, X, y, sample_weight, check_input, X_idx_sorted)\u001b[0m\n\u001b[0;32m 901\u001b[0m \"\"\"\n\u001b[0;32m 902\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 903\u001b[1;33m super().fit(\n\u001b[0m\u001b[0;32m 904\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 905\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\envs\\KinematicAnalyses\\lib\\site-packages\\sklearn\\tree\\_classes.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, X, y, sample_weight, check_input, X_idx_sorted)\u001b[0m\n\u001b[0;32m 392\u001b[0m min_impurity_split)\n\u001b[0;32m 
393\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 394\u001b[1;33m \u001b[0mbuilder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbuild\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtree_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 395\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 396\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mn_outputs_\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mis_classifier\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [
    "# Repeated random hold-out: train one forest per split and keep the best one.\n",
    "# NOTE(review): the winner is picked by accuracy on its own test split, so the\n",
    "# metrics reported for it afterwards are optimistically biased -- confirm whether\n",
    "# a separate validation split or cross-validation is wanted instead.\n",
    "n_iterations = 100\n",
    "best_model = None\n",
    "best_scaler = None\n",
    "data_split = None\n",
    "acc_score = -np.inf  # guarantees that the first iteration is always recorded\n",
    "for i in range(n_iterations):\n",
    "    if (i+1) % 10 == 0:\n",
    "        print(\"iterations: {}/{}\".format(i+1, n_iterations))\n",
    "    # split data in train and test set (seeded per iteration for reproducibility)\n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=i)\n",
    "    # Scaling data: the scaler is fitted on the training part only\n",
    "    sc = StandardScaler()\n",
    "    X_train = sc.fit_transform(X_train)\n",
    "    X_test = sc.transform(X_test)\n",
    "    # Random Forest Classifier; n_jobs=-1 builds the 2000 trees on all cores\n",
    "    rfc = RandomForestClassifier(n_estimators=2000, bootstrap=True, n_jobs=-1, random_state=i).fit(X_train, y_train)\n",
    "    pred_rfc = rfc.predict(X_test)\n",
    "    current_acc = accuracy_score(y_test, pred_rfc)\n",
    "    if current_acc > acc_score:\n",
    "        best_model = rfc\n",
    "        best_scaler = sc\n",
    "        acc_score = current_acc\n",
    "        # X_train and X_test are stored in their already scaled form\n",
    "        data_split = [X_train, X_test, y_train, y_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f607b36f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the best model together with the scaler and the split it was trained on.\n",
    "rfc = best_model\n",
    "sc = best_scaler\n",
    "# data_split already holds the scaled arrays, so no scaler is fitted again here\n",
    "X_train, X_test, y_train, y_test = data_split\n",
    "pred_rfc = rfc.predict(X_test)\n",
    "\n",
    "print(classification_report(y_test, pred_rfc))\n",
    "print(accuracy_score(y_test, pred_rfc))\n",
    "print(balanced_accuracy_score(y_test, pred_rfc))\n",
    "print(f1_score(y_test, pred_rfc, average='weighted'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "689a6a69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank the features by importance and mark the ones SelectFromModel keeps.\n",
    "selector = SelectFromModel(rfc, prefit=True)\n",
    "f_imp_df = pd.DataFrame({\"features\": feature_names, \"importance_scores\": rfc.feature_importances_, \"support\": selector.get_support()})\n",
    "f_imp_df_sorted = f_imp_df.sort_values('importance_scores', ascending=False)\n",
    "num_features = f_imp_df_sorted[\"support\"].sum()\n",
    "f_imp_df_sorted.to_csv(\"C:/Users/franz/Documents/Bachelor Studium/BachelorArbeit/JupyterLabBA/files/RandomForest/feature_importances{}.csv\".format(le.inverse_transform(rfc.classes_)), index = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20deb056",
   "metadata": {},
   "source": [
    "## Plot Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "122cb9d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_feature_importances(scores, labels):\n",
    "    \"\"\"Draw a bar plot of the importance scores of the selected features.\n",
    "\n",
    "    scores: DataFrame with the columns \"features\", \"importance_scores\" and the\n",
    "        boolean column \"support\"; only rows whose \"support\" is True are drawn.\n",
    "    labels: class names; they are joined with \" vs. \" in the figure title.\n",
    "\n",
    "    Reads the notebook globals y_test, pred_rfc, le and rfc for the accuracy in\n",
    "    the title and for the output file name. Saves the figure as PDF and shows it.\n",
    "    \"\"\"\n",
    "    selected = scores[scores[\"support\"]]\n",
    "    fig, ax = plt.subplots(figsize=(10,5))\n",
    "    # the title handles any number of classes; the count is taken from the rows actually plotted\n",
    "    title_labels = \" vs. \".join(str(label) for label in labels)\n",
    "    fig.suptitle(\"{}: Feature importance scores of {} selected features (acc. {:.3f})\".format(title_labels, len(selected), accuracy_score(y_test, pred_rfc)))\n",
    "    ax.bar(selected[\"features\"], selected[\"importance_scores\"])\n",
    "    ax.tick_params(axis='x', labelrotation=90)\n",
    "    ax.set_ylabel(\"feature importance score\")\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(\"C:/Users/franz/Documents/Bachelor Studium/BachelorArbeit/JupyterLabBA/files/RandomForest/plot_feature_importance_{}.pdf\".format(le.inverse_transform(rfc.classes_)))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "871642d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class_labels = le.inverse_transform(rfc.classes_)\n",
    "# confusion matrix normalised over the true classes (this one is saved to disk)\n",
    "plot_confusion_matrix(rfc, X_test, y_test, display_labels=class_labels, cmap=plt.cm.Blues, normalize='true')\n",
    "plt.savefig(\"C:/Users/franz/Documents/Bachelor Studium/BachelorArbeit/JupyterLabBA/files/RandomForest/confusion_matrix_{}.png\".format(class_labels))\n",
    "# raw counts: normalize accepts 'true', 'pred', 'all' or None; the string 'false' is rejected\n",
    "plot_confusion_matrix(rfc, X_test, y_test, display_labels=class_labels, cmap=plt.cm.Blues, normalize=None)\n",
    "plot_feature_importances(scores=f_imp_df_sorted, labels=class_labels)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}