rand_index.m 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  function ri = rand_index(p1, p2, varargin)
  %RAND_INDEX Computes the rand index between two partitions.
  % RAND_INDEX(p1, p2) computes the rand index between partitions p1 and
  % p2. Both p1 and p2 must be specified as N-by-1 or 1-by-N vectors in
  % which each element is an integer indicating which cluster the point
  % belongs to. Labels are only compared for equality (they are renumbered
  % internally with UNIQUE), so any set of distinct labels may be used.
  %
  % RAND_INDEX(p1, p2, 'adjusted') computes the adjusted rand index
  % between partitions p1 and p2. The adjustment accounts for chance
  % correlation. The option name is matched case-insensitively.
  %
  % Inputs that are not vectors are flattened with a warning. Non-integer
  % labels produce a warning but are still processed. An error is raised
  % if p1 and p2 do not contain the same number of elements.
  %
  % Degenerate cases return 1 (perfect agreement) instead of NaN:
  %   - fewer than two points, so there are no pairs to compare;
  %   - 'adjusted' when both partitions put every point in one cluster, or
  %     both put every point in its own cluster (the formula is 0/0 there).

  %% Parse the input and throw errors
  if nargin == 0
      error('Arguments must be supplied.');
  end
  if nargin == 1
      error('Two partitions must be supplied.');
  end
  if nargin > 3
      error('Too many input arguments');
  end
  adj = false;
  if nargin == 3
      if strcmpi(varargin{1}, 'adjusted')
          adj = true;
      else
          error('%s is an unrecognized argument.', varargin{1});
      end
  end
  % Compare element counts, not LENGTH: LENGTH is the largest dimension,
  % so e.g. a 3-by-3 and a 1-by-3 input would otherwise slip through.
  if numel(p1) ~= numel(p2)
      error('Both partitions must contain the same number of points.');
  end
  % Check if arguments need to be flattened
  if length(p1) ~= numel(p1)
      p1 = p1(:);
      warning('The first partition was flattened to a 1D vector.')
  end
  if length(p2) ~= numel(p2)
      p2 = p2(:);
      warning('The second partition was flattened to a 1D vector.')
  end
  % Check for integers (non-integers are tolerated, but flagged)
  if ~(isreal(p1) && all(rem(p1, 1) == 0))
      warning('The first partition contains non-integers, which may make the results meaningless. Attempting to continue calculations.');
  end
  if ~(isreal(p2) && all(rem(p2, 1) == 0))
      warning('The second partition contains non-integers, which may make the results meaningless. Attempting to continue calculations.');
  end

  %% Preliminary computations and cleansing of the partitions
  N = numel(p1);
  % With fewer than two points there are no pairs, so both indices would
  % be 0/0; by convention the partitions agree perfectly.
  if N < 2
      ri = 1;
      return
  end
  % Relabel clusters as consecutive integers 1..N1 and 1..N2.
  [~, ~, c1] = unique(p1);
  [~, ~, c2] = unique(p2);
  N1 = max(c1);
  N2 = max(c2);

  %% Create the matching (contingency) matrix
  % n(i,j) = number of points in cluster i of p1 and cluster j of p2.
  n = accumarray([c1(:), c2(:)], 1, [N1, N2]);
  rowSums = sum(n, 2);   % cluster sizes in p1
  colSums = sum(n, 1);   % cluster sizes in p2
  % "x choose 2" for non-negative integer counts, vectorised; this is 0 for
  % x = 0 and x = 1, matching the special-cased binomial it replaces.
  npairs = @(x) x .* (x - 1) / 2;
  totalPairs = npairs(N);

  if ~adj
      %% Basic rand index
      ri = (totalPairs + sum(n(:).^2) ...
            - 0.5 * sum(colSums.^2) - 0.5 * sum(rowSums.^2)) / totalPairs;
  else
      %% Adjusted rand index
      % The denominator vanishes only when both partitions are a single
      % cluster or both are all singletons, i.e. identical partitions.
      if (N1 == 1 && N2 == 1) || (N1 == N && N2 == N)
          ri = 1;
          return
      end
      ssm = sum(npairs(n(:)));
      sm1 = sum(npairs(rowSums));
      sm2 = sum(npairs(colSums));
      expected = sm1 * sm2 / totalPairs;
      ri = (ssm - expected) / ((sm1 + sm2) / 2 - expected);
  end
  end