|
@@ -0,0 +1,375 @@
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""
|
|
|
+Created on Mon May 11 16:55:37 2020
|
|
|
+
|
|
|
+@author: Cecilia
|
|
|
+"""
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+"""
|
|
|
+Version 1: original plotting (bar plot).
|
|
|
+
|
|
|
+Version 2: use line plot.
|
|
|
+
|
|
|
+Version 3: change the appearance of the plots in version 2.
|
|
|
+
|
|
|
+
|
|
|
+"""
|
|
|
+
|
|
|
+import statsmodels.api as sm
|
|
|
+import os
|
|
|
+import pandas as pd
|
|
|
+from matplotlib import pyplot as plt
|
|
|
+import seaborn as sns
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+#%
|
|
|
## Locate the data csv file. The folder that contains it is also the folder
## the results are saved to.
## The known locations are tried in order; when none of them exists the last
## one is kept, exactly as if every check had fallen through.
csv_path_candidates = [
    ## path in desktop
    r"C:\Users\Cecilia\OneDrive\Python\TWF_behaviour\behaviouralRecalculated\TWF_ITD_dataset_withoutCorrectionTrials.csv",
    r"C:\jan\document\science\cityu\TWF_ITD_ILD_NormalHearingRats\figures\TWF_ITD_dataset_withoutCorrectionTrials.csv",
    r"/home/jan/jan/document/science/cityu/TWF_ITD_ILD_NormalHearingRats/figures/TWF_ITD_dataset_withoutCorrectionTrials.csv",
]
## path in laptop
#csv_path = r"C:\Users\USER\OneDrive\Python\TWF_behaviour\behaviouralRecalculated\TWF_ITD_dataset_withoutCorrectionTrials.csv"
csv_path = next(
    (candidate for candidate in csv_path_candidates if os.path.isfile(candidate)),
    csv_path_candidates[-1],
)

folder_path, csv_filename = os.path.split(csv_path)
|
|
|
+
|
|
|
+
|
|
|
## Load data file (without correction trials)

## The separator is a regular expression that also swallows the white space
## around each comma. It is written as a raw string so that "\s" reaches the
## regex engine instead of being read as an (invalid) string escape.
dataset = pd.read_csv(csv_path, encoding='utf-8-sig', sep=r'\s*,\s*', engine='python')

## Get rid of the white spaces in front or at the end of the column names and
## of the animal IDs. If not doing so, will raise error when calling get_group().
## The column names are cleaned first so that the 'animal' lookup cannot fail
## because of padding left in the header.
dataset.columns = dataset.columns.str.strip()
dataset['animal'] = dataset['animal'].astype(str).str.strip()

## Only need the offset (ms) == 0 (probe) trials
probe_data = dataset.loc[dataset['offset (ms)'] == 0]

## Get the unique condition and animal ID
conditions = list(probe_data['condition'].unique())
animals = list(probe_data['animal'].unique())


## Group data by condition and animal
grouped_data = probe_data.groupby(['condition', 'animal'])
|
|
|
+
|
|
|
#%% work out session boundaries and count how many sessions for each animal
## A session is taken to end at every row whose successor has a smaller
## timeStamp; the very last row of the dataset closes the final session.
time_stamps = np.array(dataset.timeStamp)
timestamp_drops = time_stamps[1:] < time_stamps[:-1]
sessEnd = np.append(np.where(timestamp_drops)[0], [dataset.shape[0] - 1])
## Animal ID found on the closing row of each session
sessPerAnimal = np.array(dataset.animal[sessEnd])
for animal in animals:
    n_sessions = np.sum(sessPerAnimal == animal)
    print('{} sessions for rat {}'.format(n_sessions, animal))
|
|
|
+
|
|
|
#%%
## Fit one probit model per (condition, animal) group, collect the
## coefficients in result_df and append each model summary to a text file.


def _significance_mark(p):
    """Return the (sig, label) pair for one p value.

    sig is 1 when p < 0.05 and 0 otherwise. label is "**" for p < 0.01,
    "*" for 0.01 <= p < 0.05 and " " otherwise. A NaN p value falls into
    "otherwise", so every coefficient always receives exactly one entry.
    """
    if p < 0.01:
        return 1, "**"
    if p < 0.05:
        return 1, "*"
    return 0, " "


## Loop through condition and animal
result_frames = []
for (condition, animal), group in grouped_data:
    print('==========')
    print('Now analyzing {} {}'.format(condition, animal))

    probit_data = group

    ## Extract the ITD_values (X) and side (y) for analysis
    ITD_values, side = probit_data.filter(regex="ITD"), probit_data['respRight']

    # Probit Model
    ITD_val_with_const = sm.tools.add_constant(ITD_values)
    probit_model = sm.Probit(side, ITD_val_with_const)

    probit_result = probit_model.fit()

    ## Get the coefficients, standard errors and p values. Entry 0 is skipped
    ## in all three (the constant term added above comes first).
    ## standard_error and p_value stay pandas Series: their index labels (the
    ## ITD column names) become the row labels of df below.
    coefficient = probit_result.params.values[1:]
    standard_error = probit_result.bse[1:]
    p_value = probit_result.pvalues[1:]

    ## Look at the confidence interval of each coefficient
    #print(probit_result.conf_int())

    ## Values for the new columns ['Sig'] and ['Label']
    marks = [_significance_mark(value) for value in p_value]
    sigs = [sig for sig, _ in marks]
    labels = [label for _, label in marks]

    ## Collect the coefficients, standard errors, p values and sig of this group
    df = pd.DataFrame({"Condition": condition,
                       "Animal": animal,
                       "Coefficient": coefficient,
                       "Standard error": standard_error,
                       "P value": p_value,
                       "Sig": sigs,
                       "Label": labels})
    ## Number the clicks 1..n, following the number of fitted coefficients
    df['Click number'] = range(1, len(df) + 1)
    result_frames.append(df)

    ## Get the probit summary table
    probit_result_table = probit_result.summary()

    ## Save the results to txt file (append mode: summaries accumulate across
    ## groups and across runs)
    with open(folder_path + '/TWF_ITD_probit_results_summary.txt', "a") as output_file:
        print("\n ", file=output_file)
        print("Rat {} at {}".format(animal, condition), file=output_file)
        print("= = = = = = = = =", file=output_file)
        print(probit_result_table, file=output_file)

    print("Done")

## Combine all groups once and save result_df to csv file
if result_frames:
    result_df = pd.concat(result_frames, axis=0)
    result_df['Number index'] = range(result_df.shape[0])
    result_df.to_csv(folder_path + '/TWF_ITD_probit_results.csv')
else:
    result_df = pd.DataFrame([])

print("All DONE")
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+#%% Plot the results
|
|
|
+
|
|
|
+## Line plot for each rat in each condition
|
|
|
+
|
|
|
+#%%
|
|
|
## Reload the results from the csv file so that the plotting cells below work
## from what is stored on disk (this also allows plotting without re-fitting).
## The separator regex is a raw string so that "\s" is not read as an
## (invalid) string escape.
result_file = folder_path+r'/TWF_ITD_probit_results.csv'
result_df = pd.read_csv(result_file, encoding='utf-8-sig', sep=r'\s*,\s*', engine='python')
result_path, result_filename = os.path.split(result_file)
|
|
|
+
|
|
|
+
|
|
|
+#%%
|
|
|
+
|
|
|
## figure 2 in TWF paper
## One panel per condition; every line shows the coefficients of one rat and
## a star marks each coefficient whose p value is below 0.05.

plt.close("all")
grouped_result = result_df.groupby(['Condition', 'Animal'])
conditions = ['20Hz', '50Hz', '300Hz', '900Hz']

# Font settings for the panel titles, the axis labels and the remaining text
title_font = dict(size='13',
                  verticalalignment='bottom')  # bottom alignment for more space
axis_font = dict(size='12')
text_font = dict(size='12')

fig, axs = plt.subplots(1, 4, sharex=True, sharey=True, squeeze=False)
axs = axs.flatten()

for (condition, animal), group in grouped_result:
    # panel that belongs to this condition
    panel = axs[conditions.index(condition)]
    clicks = group['Click number']
    weights = group['Coefficient']

    panel.plot(clicks, weights, '.-', label=animal, linewidth=2)
    panel.set_ylim([-2, 15])
    # the title separates the last two characters of the condition name
    panel.set_title('{} {}'.format(condition[:-2], condition[-2:]), **title_font)

    ## dashed black reference line at y=0
    panel.axhline(y=0, color='black', linestyle='dashed')
    panel.set_xticks(list(range(1, 9)))

    ## positions (within this group) of the significant coefficients
    significant = [k for k, p in enumerate(group["P value"]) if p < 0.05]

    ## overlay stars on the significant points only
    panel.plot(clicks, weights, markevery=significant, ls="", marker="*",
               markersize=8, c="k")
    ## the legend entries are read from the first panel
    handles, labels = axs[0].get_legend_handles_labels()


# Place the legend outside of the subplots
leg = fig.legend(handles, labels, loc='center right', ncol=1,
                 bbox_to_anchor=(0.96, 0.7), title='Rat')
legend_frame = leg.get_frame()
legend_frame.set_facecolor('w')
legend_frame.set_edgecolor('k')

# A frameless axis spanning the whole figure carries the shared axis labels;
# its own ticks and tick labels are hidden.
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)

plt.rc('font', **text_font)
plt.gcf().set_size_inches(8, 4)
plt.tight_layout()
plt.ylabel('Weight β ($ms^{-1}$)', va='bottom', **axis_font)
plt.xlabel('Click #', **axis_font)
plt.show()
|
|
|
+
|
|
|
#%%
# Save the figure
# Saving goes through the figure object: plt.savefig() acts on whatever the
# *current* figure is, which after plt.show() may be a new, empty one.
fig.savefig(folder_path + '/TWF_ITD_lineplot_for_individual_rat_v3.svg')
|
|
|
+
|
|
|
+
|
|
|
+#%%
|
|
|
+
|
|
|
# figure 2 in TWF paper

## Plot mean ± sem line plot for rats in each condition

fig, ax = plt.subplots()

## add a dashed black line for y=0
ax.axhline(y=0, color='black', linestyle='dashed')

## NOTE: sns.pointplot has no `style` argument (that belongs to lineplot);
## passing it makes the call fail. Conditions are told apart by hue; use
## `markers=` / `linestyles=` if distinct symbols per condition are wanted.
sns.pointplot(x="Click number",
              y="Coefficient",
              hue="Condition",
              data=result_df,
              hue_order=['20Hz', '50Hz', '300Hz', '900Hz'],
              ci=68,          # 68% bootstrap interval, roughly mean ± 1 sem
              n_boot=1000,    # number of bootstrap resamples for the interval
              errwidth=1.5,
              capsize=.1,
              dodge=True,     # to avoid overlapping of the mean points
              ax=ax)          # draw on the axes created above


plt.ylabel('Weight ($ms^{-1}$)', va='bottom', **axis_font)
plt.xlabel('click #', **axis_font)
plt.rc('font', **text_font)
plt.tight_layout()
plt.gcf().set_size_inches(5, 5)
plt.show()
|
|
|
+
|
|
|
+
|
|
|
#%%
# Save figure
# Saving goes through the figure object: plt.savefig() acts on whatever the
# *current* figure is, which after plt.show() may be a new, empty one.
fig.savefig(result_path + '/TWF_ITD_pointplot_v3.svg')
|
|
|
+
|
|
|
#%% Compute bootstrap distributions of probit parameters for each condition, pooling across animals
## Group data by condition only (animals are pooled).
## The grouping key is the bare column name, not a one-element list, so that
## each group key is the plain condition value (recent pandas versions return
## 1-tuples for list keys, which would corrupt the column names built below).
poolAnimals_groupConditions = probe_data.groupby('condition')

NbootTrials=1000
#%
## Pre-create one (still empty) column per click and click rate. This fixes
## the column order (click first, then rate) independently of the order in
## which the condition groups are visited.
x={}
for i in range(8):
    for cond in [20,50,300,900]:
        x['click{}_{}Hz'.format(i,cond)]=[]
boot_result_df = pd.DataFrame(x)

## Loop through the conditions
for condition, group in poolAnimals_groupConditions:
    print('==========')
    print('Now analyzing {} '.format(condition))

    probit_data = group

    ## Extract the ITD_values (X) and side (y) for analysis
    ITD_values, side = probit_data.filter(regex="ITD"),probit_data['respRight']
    side=np.array(side)
    # Design matrix for the probit model (constant term in column 0)
    ITD_val_with_const=np.array(sm.tools.add_constant(ITD_values))
    Ntrials=len(side)

    ## A fresh result array per condition, so that the values stored for an
    ## earlier condition can never be touched by a later one.
    boot_results=np.zeros((NbootTrials,8))

    #% now do the bootstrap
    for t in range(NbootTrials):
        ## resample the trials with replacement
        bootSample=np.floor(np.random.uniform(low=0,high=Ntrials,size=Ntrials)).astype(int)
        probit_model = sm.Probit(side[bootSample], ITD_val_with_const[bootSample,:])
        probit_result = probit_model.fit(disp=False)
        ## keep the click weights, drop the constant term
        boot_results[t,:] = probit_result.params[1:]
        if np.mod(t,100)==99:
            print(' Bootstrap step {}'.format(t+1))

    #% save the bootstrap results to the dataframe
    for i in range(8):
        nextCond='click{}_{}'.format(i,condition)
        print('saving boot results for {}'.format(nextCond))
        boot_result_df[nextCond]=boot_results[:,i]

# save bootstrap results to CSV
boot_result_df.to_csv(folder_path + '/BootstrapResult.csv')
|
|
|
#%% read bootstrap results if needed.
boot_result_df = pd.read_csv(folder_path + '/BootstrapResult.csv')
# drop every column whose name starts with "Unnamed" (the saved row index)
is_unnamed = boot_result_df.columns.str.contains('^Unnamed')
boot_result_df = boot_result_df.loc[:, ~is_unnamed]
|
|
|
#%% plot bootstrap distributions
## One violin per column of boot_result_df, taken in column order: four
## violins (one per click rate) for each of the eight clicks, with an empty
## slot on the x axis between consecutive clicks.
plt.figure(2, figsize=(8, 5))
plt.clf()
ax = plt.subplot(111)

n_clicks = 8
n_rates = 4
# the four rate colours, repeated once per click
colors = ['r', 'g', 'b', 'm'] * n_clicks
clickRates = [' 20', ' 50', 300, 900]
columns = boot_result_df.columns
xTickLabels = []
xTicks = []

for clkIdx in range(n_clicks):
    for HzIdx in range(n_rates):
        ii = clkIdx * n_rates + HzIdx              # running column index
        pos = clkIdx * (n_rates + 1) + HzIdx + 1   # x position of this violin
        violin_color = colors[ii]

        violin_parts = ax.violinplot(boot_result_df[columns[ii]],
                                     positions=[pos],
                                     showextrema=True,
                                     showmedians=True)
        xTicks.append(pos)
        xTickLabels.append('click {}:{} Hz'.format(clkIdx + 1, clickRates[HzIdx]))

        # bars, extrema and median lines take the colour of their click rate
        for partname in ('cbars', 'cmins', 'cmaxes', 'cmedians'):
            line_part = violin_parts[partname]
            line_part.set_edgecolor(violin_color)
            line_part.set_linewidth(1)

        # violin bodies: face and border in the same colour, half transparent
        for body in violin_parts['bodies']:
            body.set_facecolor(violin_color)
            body.set_edgecolor(violin_color)
            body.set_linewidth(1)
            body.set_alpha(0.5)

# first free x position to the right of the last violin group
pos = n_clicks * (n_rates + 1) + 1

# dotted reference line at zero across the whole x range
plt.plot([0, pos], [0, 0], 'k:')
ax.set_xlim([0, pos])
plt.ylabel('Probit Coefficients')
ax.set_xticks(xTicks)
plt.xticks(rotation=90)
ax.set_xticklabels(xTickLabels)
#sns.violinplot(data=boot_result_df)
plt.title('Bootstrapped confidence intervals for probit coefficients')
plt.tight_layout()
|
|
|
#%%
# Save figure (vector version, plus a 300 dpi TIFF)
bootstrap_svg = result_path + '/TWF_ITD_bootstrapWeights.svg'
bootstrap_tif = result_path + '/Figure3.tif'
plt.savefig(bootstrap_svg)
plt.savefig(bootstrap_tif, dpi=300)
|
|
|
+
|
|
|
+
|
|
|
#%% output some descriptive statistics for the bootstrap distributions
# Print the median bootstrap weight of columns click3 .. click7 for every
# click rate.
# NOTE(review): the click number in these column names counts from 0, while
# the figure labels count from 1 - confirm which clicks are meant here.
clickRates = [20, 50, 300, 900]
for clkIdx in range(3, 8):
    for rate in clickRates:
        column_name = 'click{}_{}Hz'.format(clkIdx, rate)
        median_weight = np.median(boot_result_df[column_name])
        print(' click {} at {} Hz - median weight: {}'.format(clkIdx, rate, median_weight))
|