phase_analysis.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import os
  2. import numpy as np
  3. import pandas as pd
  4. from sklearn.cluster import KMeans
  5. def phase_clustering(df):
  6. """Clusters the phases using K-means and shifts the right cluster by 2 pi.
  7. Parameters
  8. ----------
  9. df : pandas.DataFrame
  10. The joined DataFrame of baseline properties and receptive field locations.
  11. Returns
  12. -------
  13. pandas.DataFrame
  14. The dataframe with three more columns, the shifted phase, the cluster label and the absolute delay.
  15. """
  16. rfpositions = df.receptor_pos_absolute.values
  17. phases = df.phase + np.pi
  18. x = np.asarray([rfpositions/(np.max(rfpositions) - np.min(rfpositions)), 0.75 * phases/(2 * np.pi)]).T
  19. kmeans = KMeans(n_clusters=2, n_init=200)
  20. kmeans.fit(x)
  21. cluster_ids = kmeans.labels_
  22. phase_shifted = 1 * phases
  23. if rfpositions[cluster_ids == 0].mean() > rfpositions[cluster_ids == 1].mean():
  24. phase_shifted[cluster_ids == 0] += 2 * np.pi
  25. else:
  26. phase_shifted[cluster_ids == 1] += 2 * np.pi
  27. df["kmeans_label"] = cluster_ids
  28. df["phase_shifted"] = phase_shifted - np.pi
  29. df["phase_time"] = df["phase_shifted"].values / (2 * np.pi) * df.eod_period
  30. # import matplotlib.pyplot as plt
  31. # from scipy.stats import pearsonr
  32. # plt.plot(rfpositions[cluster_ids==1], phases[cluster_ids == 1], c="r", ls="None", marker=".")
  33. # plt.plot(rfpositions[cluster_ids==0], phases[cluster_ids == 0], c="b", ls="None", marker=".")
  34. # plt.show()
  35. # plt.plot(df.receptor_pos_relative, df.phase_shifted, c="k", ls="None", marker=".")
  36. # plt.show()
  37. # print(pearsonr(df.receptor_pos_relative, df.phase_shifted))
  38. # plt.plot(df.receptor_pos_absolute, df.phase_time, c="k", ls="None", marker=".")
  39. # plt.show()
  40. # print(pearsonr(df.receptor_pos_absolute, df.phase_time))
  41. return df
  42. def phase_analysis(data_folder):
  43. baseline_df = pd.read_csv(os.path.join(data_folder, "baseline_properties.csv"), sep=";", index_col=0)
  44. receptivefield_df = pd.read_csv(os.path.join(data_folder, "receptivefield_positions.csv"), sep=";", index_col=0)
  45. joined_df = baseline_df.merge(receptivefield_df)
  46. df = phase_clustering(joined_df)
  47. return df