ROC.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from collections import namedtuple
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  def scale(v, vmin=-180, vmax=180):
      """Map ``v`` linearly so that ``vmin`` goes to 0 and ``vmax`` goes to 1.

      Works element-wise when ``v`` is a NumPy array.  Nothing is clipped:
      values outside ``[vmin, vmax]`` land below 0 or above 1.
      """
      offset = v - vmin
      width = vmax - vmin
      return offset / width
  def by_small(v, frac=0.001, add=True):
      """Return ``v`` nudged by a small relative amount.

      The step is ``abs(v) * frac``.  It is added when ``add`` is truthy and
      subtracted otherwise.  When that relative step is zero (``v == 0``), it
      could not move the value at all.  In that case ``frac`` itself is used
      as an absolute step, so the result still lies on the requested side of
      ``v``.  ``compute`` relies on this to place a threshold strictly above
      the largest score.

      Parameters
      ----------
      v : scalar or array_like
          Value(s) to nudge; handled element-wise.
      frac : float
          Relative step size, and the absolute step used where ``v == 0``.
      add : bool
          Step upwards when true, downwards when false.

      Returns
      -------
      scalar or ndarray
          ``v + step`` or ``v - step``.
      """
      step = np.abs(v) * frac
      # A relative step vanishes at zero; fall back to an absolute step.
      step = np.where(step == 0, frac, step)
      return v + step if add else v - step
  def reported(v):
      """Echo ``v`` to standard output and pass it through untouched.

      Useful for peeking at an intermediate value inline, e.g. wrapped around
      a sub-expression, without restructuring the surrounding code.
      """
      value = v
      print(value)
      return value
  def compute(positive, negative):
      """Build an ROC curve from the scores of positive and negative samples.

      A sample is classified as positive when its score is ``>=`` the
      threshold.  The thresholds used are:

      * every distinct observed score, and
      * one sentinel just above the largest score, where nothing is
        classified as positive.

      Parameters
      ----------
      positive : iterable of numbers
          Scores of the truly positive samples.
      negative : iterable of numbers
          Scores of the truly negative samples.

      Returns
      -------
      ROC
          ``thresholds`` in descending order, with matching ``sensitivity``
          (true-positive rate) and ``nonspecific`` (false-positive rate,
          i.e. 1 - specificity).  The curve starts at (0, 0) for the sentinel
          threshold and ends at (1, 1) for the smallest score.

      Raises
      ------
      ValueError
          If ``positive`` or ``negative`` is empty; the rates are undefined
          without samples of both classes.
      """
      pos = np.sort(np.asarray(list(positive)))
      neg = np.sort(np.asarray(list(negative)))
      if pos.size == 0 or neg.size == 0:
          raise ValueError(
              "compute() needs at least one positive and one negative score")

      # Distinct scores in ascending order.  The smallest one classifies
      # every sample as positive.
      levels = np.unique(np.concatenate((pos, neg)))
      # Samples scoring >= t are those not strictly below t.  On sorted data,
      # searchsorted(side="left") counts exactly the ones below t, in
      # O(log n) per level instead of rescanning the whole array.
      true_positive = pos.size - np.searchsorted(pos, levels, side="left")
      false_positive = neg.size - np.searchsorted(neg, levels, side="left")

      # Sentinel just above the maximum score: nothing is detected there.
      thresholds = np.append(levels, by_small(levels[-1]))
      sensitivity = np.append(true_positive / pos.size, 0.0)
      nonspecific = np.append(false_positive / neg.size, 0.0)
      # Reverse so thresholds run from highest to lowest and the curve goes
      # from (0, 0) to (1, 1).
      return ROC(thresholds[::-1], sensitivity[::-1], nonspecific[::-1])
  class ROC(namedtuple("_ROC", ("thresholds", "sensitivity", "nonspecific"))):
      """ROC curve stored as three parallel sequences.

      Fields, as filled in by ``compute``:

      thresholds
          Score thresholds.  A sample counts as positive when its score is
          ``>=`` the threshold.
      sensitivity
          True-positive rate at each threshold (plotted on the y axis).
      nonspecific
          False-positive rate at each threshold (plotted on the x axis).
      """

      @property
      def AUC(self):
          """Area under the curve, by the trapezoidal rule.

          Integrates ``sensitivity`` against ``nonspecific``.  Segment widths
          are taken in absolute value, so points stored in ascending or
          descending order give the same result.  Plain lists are accepted as
          well as NumPy arrays.  A curve with fewer than two points has zero
          area.
          """
          sens = np.asarray(self.sensitivity, dtype=float)
          nonspec = np.asarray(self.nonspecific, dtype=float)
          widths = np.abs(np.diff(nonspec))
          return np.sum((sens[1:] + sens[:-1]) * widths / 2)

      def plot(self, ax, cmap=None,
               vmin=-180, vmax=180, alpha=1, markersize=10):
          """Draw each ROC point on ``ax`` as a dot coloured by its threshold.

          Parameters
          ----------
          ax : object with a matplotlib-style ``plot`` method
              Presumably a matplotlib ``Axes`` -- confirm against callers.
          cmap : Colormap or str, optional
              Colormap (or the registered name of one) used to colour points.
              Defaults to ``"viridis"``.
          vmin, vmax : float
              Threshold range that ``scale`` maps onto [0, 1] before the
              colormap is applied.
          alpha : float
              Marker transparency.
          markersize : float
              Marker size passed as ``ms``.
          """
          # Resolve here rather than in the signature, so no colormap lookup
          # happens at class-definition time and names are accepted too.
          cmap = plt.get_cmap("viridis" if cmap is None else cmap)
          levels = scale(np.asarray(self.thresholds), vmin=vmin, vmax=vmax)
          for t, n, s in zip(levels, self.nonspecific, self.sensitivity):
              # Keep only RGB from the colormap so ``alpha`` alone sets the
              # transparency.
              ax.plot((n,), (s,), ".", color=cmap(t)[:3], alpha=alpha,
                      mec="none", ms=markersize)