12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- from collections import namedtuple
- import numpy as np
- import matplotlib.pyplot as plt
def scale(v, vmin=-180, vmax=180):
    """Linearly map *v* from the interval [vmin, vmax] onto [0, 1].

    Works element-wise when *v* is a NumPy array.  No clipping is applied,
    so values outside [vmin, vmax] land outside [0, 1].
    """
    offset = v - vmin
    span = vmax - vmin
    return offset / span
def by_small(v, frac=0.001, add=True):
    """Nudge *v* by a fraction of its own magnitude.

    Args:
        v: number (or NumPy array, element-wise) to nudge.
        frac: relative step size; the shift is ``abs(v) * frac``.
        add: when true the step is added, otherwise it is subtracted.

    Returns:
        ``v + abs(v) * frac`` if *add* is true, else ``v - abs(v) * frac``.
        Because the step is proportional to ``abs(v)``, ``v == 0`` comes
        back unchanged.
    """
    step = abs(v) * frac
    return v + step if add else v - step
def reported(v):
    """Echo *v* to stdout, then hand it back unchanged.

    Handy for inspecting an intermediate value inside a larger expression.
    """
    value = v
    print(value)
    return value
def compute(positive, negative):
    """Compute the empirical ROC curve separating two groups of scores.

    A sample is classified positive when its score is >= the threshold.  One
    ROC point is produced for every distinct observed score, plus a sentinel
    threshold strictly above the largest score, at which nothing is
    classified positive (sensitivity = false-positive rate = 0).

    Args:
        positive: iterable of scores for the truly positive samples.
        negative: iterable of scores for the truly negative samples.

    Returns:
        ROC whose ``thresholds`` are in decreasing order, so ``sensitivity``
        (true-positive rate) and ``nonspecific`` (false-positive rate) both
        run from 0 up to 1.

    Raises:
        ValueError: if either group is empty, because one of the two rates
            would then be undefined.
    """
    pos = list(positive)
    neg = list(negative)
    if not pos or not neg:
        raise ValueError(
            "compute() needs at least one positive and one negative score")

    values = np.array(pos + neg)
    is_positive = np.array([True] * len(pos) + [False] * len(neg))

    order = np.argsort(values, kind="stable")
    sorted_values = values[order]
    sorted_flags = is_positive[order]

    # One candidate threshold per distinct score.  first_idx[i] is where that
    # score first appears in sorted order, so sorted_flags[first_idx[i]:] are
    # exactly the samples scoring >= thresholds[i].
    thresholds, first_idx = np.unique(sorted_values, return_index=True)

    # Suffix sums: positives_from[k] == number of positives in sorted_flags[k:].
    # This replaces a per-threshold rescan of the whole array (O(n*k)).
    positives_from = np.cumsum(sorted_flags[::-1])[::-1]
    true_positive = positives_from[first_idx]
    false_positive = (sorted_values.size - first_idx) - true_positive

    sensitivity = true_positive / len(pos)
    nonspecific = false_positive / len(neg)

    # Sentinel strictly above the maximum score.  by_small() leaves 0
    # unchanged, so fall back to the next representable float in that case.
    top = sorted_values[-1]
    ceiling = by_small(top)
    if ceiling <= top:
        ceiling = np.nextafter(top, np.inf)

    thresholds = np.concatenate([thresholds, [ceiling]])
    sensitivity = np.concatenate([sensitivity, [0.0]])
    nonspecific = np.concatenate([nonspecific, [0.0]])
    return ROC(thresholds[::-1], sensitivity[::-1], nonspecific[::-1])
class ROC(namedtuple("_ROC", ("thresholds", "sensitivity", "nonspecific"))):
    """An ROC curve stored as three parallel sequences of equal length.

    Fields (as filled in by ``compute()``):
        thresholds: score cut-offs, one per curve point.
        sensitivity: true-positive rate at each threshold.
        nonspecific: false-positive rate (1 - specificity) at each threshold.

    Plain lists are accepted as well as NumPy arrays.
    """

    @property
    def AUC(self):
        """Area under the curve by the trapezoidal rule over (nonspecific, sensitivity).

        Segment widths are taken in absolute value, so the points may be
        ordered by either increasing or decreasing false-positive rate (as
        long as the order is monotonic).  A curve with fewer than two points
        has area 0.
        """
        sens = np.asarray(self.sensitivity, dtype=float)
        fpr = np.asarray(self.nonspecific, dtype=float)
        widths = np.abs(np.diff(fpr))
        return np.sum((sens[1:] + sens[:-1]) * widths / 2)

    def plot(self, ax, cmap=None,
             vmin=-180, vmax=180, alpha=1, markersize=10):
        """Draw each ROC point on *ax*, coloured by its threshold.

        Args:
            ax: matplotlib Axes to draw on.
            cmap: Colormap instance or colormap name; ``None`` means
                "viridis".  It is resolved at call time, so defining this
                class does not query pyplot.
            vmin, vmax: threshold range mapped onto [0, 1] of the colormap
                via ``scale()``.
            alpha: marker opacity.
            markersize: marker size passed as ``ms``.
        """
        cmap = plt.get_cmap("viridis" if cmap is None else cmap)
        colour_positions = scale(np.asarray(self.thresholds), vmin=vmin, vmax=vmax)
        for t, n, s in zip(colour_positions, self.nonspecific, self.sensitivity):
            # Drop the colormap's alpha channel; opacity comes from ``alpha``.
            ax.plot((n,), (s,), ".", color=cmap(t)[:3], alpha=alpha,
                    mec="none", ms=markersize)
|