# CLLDA_for_ICLabel_test.py
  1. from load_website_data import load_icl_test
  2. import numpy as np
  3. from crowd_labeling.CLLDA import concurrent_cllda, combine_cllda
  4. import cPickle as pkl
  5. from scipy.io import savemat
  6. # load sqlite data
  7. icl_votes = load_icl_test('database.sqlite')
  8. votes_vec = icl_votes['votes']
  9. votes_vec_workers = icl_votes['workers']
  10. votes_vec_instances = icl_votes['instances']
  11. instance_study_numbers = icl_votes['instance_study_numbers']
  12. instance_set_numbers = icl_votes['instance_set_numbers']
  13. instance_ic_numbers = icl_votes['instance_ic_numbers']
  14. T = icl_votes['n_classes']
  15. C = icl_votes['n_responses']
  16. A = icl_votes['n_workers']
  17. # CLLDA settings
  18. all_priors = np.tile(np.maximum(np.hstack((5*np.eye(T), np.zeros((T, 1)))), 0.01), [A, 1, 1])
  19. instance_prior = np.histogram(votes_vec, range(C))[0] / 100. / np.histogram(votes_vec, range(C))[0].sum()
  20. # CLLDA with all transforms
  21. cls = concurrent_cllda(4, votes_vec, votes_vec_workers, votes_vec_instances, nprocs=4,
  22. worker_prior=all_priors, instance_prior=instance_prior,
  23. transform=('none', 'ilr', 'clr', 'alr'), num_epochs=1000, burn_in=200)
  24. # combine models
  25. cl = combine_cllda(cls)
  26. # CLLDA with all transforms weak
  27. all_priors_weak = np.tile(np.maximum(np.hstack((np.eye(T), np.zeros((T, 1)))), 0.01), [A, 1, 1])
  28. cls_weak = concurrent_cllda(4, votes_vec, votes_vec_workers, votes_vec_instances, nprocs=4,
  29. worker_prior=all_priors_weak, instance_prior=instance_prior,
  30. transform=('none', 'ilr', 'clr', 'alr'), num_epochs=1000, burn_in=200)
  31. cl_weak = combine_cllda(cls_weak)
  32. # MV and DS and CLLDA
  33. from crowd_labeling import MV
  34. from crowd_labeling import DS
  35. # ignoring "?"
  36. ind = votes_vec != 7
  37. temp_votes_vec = votes_vec[ind]
  38. temp_votes_vec_workers = votes_vec_workers[ind]
  39. temp_votes_vec_instances = votes_vec_instances[ind]
  40. cls_ignore = concurrent_cllda(4, temp_votes_vec, temp_votes_vec_workers, temp_votes_vec_instances, nprocs=4,
  41. worker_prior=all_priors, instance_prior=instance_prior,
  42. transform=('none', 'ilr', 'clr', 'alr'), num_epochs=1000, burn_in=200)
  43. cl_ignore = combine_cllda(cls_ignore)
  44. _, temp_votes_vec_workers = np.unique(temp_votes_vec_workers, return_inverse=True)
  45. _, temp_votes_vec_instances = np.unique(temp_votes_vec_instances, return_inverse=True)
  46. mv_ignore = MV(temp_votes_vec, temp_votes_vec_workers, temp_votes_vec_instances)
  47. ds_ignore = DS(temp_votes_vec, temp_votes_vec_workers, temp_votes_vec_instances)
  48. # removing labels with "?"
  49. ind = votes_vec == 7
  50. to_remove = np.stack((votes_vec_workers[ind], votes_vec_instances[ind])).T
  51. ind = np.ones_like(votes_vec, dtype=bool)
  52. for it, vote in enumerate(np.stack((votes_vec_workers, votes_vec_instances)).T):
  53. if (vote == to_remove).all(1).any():
  54. ind[it] = False
  55. temp_votes_vec = votes_vec[ind]
  56. temp_votes_vec_workers = votes_vec_workers[ind]
  57. temp_votes_vec_instances = votes_vec_instances[ind]
  58. _, temp_votes_vec_workers = np.unique(temp_votes_vec_workers, return_inverse=True)
  59. _, temp_votes_vec_instances = np.unique(temp_votes_vec_instances, return_inverse=True)
  60. mv_remove = MV(temp_votes_vec, temp_votes_vec_workers, temp_votes_vec_instances)
  61. ds_remove = DS(temp_votes_vec, temp_votes_vec_workers, temp_votes_vec_instances)
  62. cls_remove = concurrent_cllda(4, temp_votes_vec, temp_votes_vec_workers, temp_votes_vec_instances, nprocs=4,
  63. worker_prior=all_priors, instance_prior=instance_prior,
  64. transform=('none', 'ilr', 'clr', 'alr'), num_epochs=1000, burn_in=200)
  65. cl_remove = combine_cllda(cls_remove)
  66. # results to save
  67. save = dict()
  68. save['instance_labels'] = cl.labels[0]
  69. save['instance_labels_ilr'] = cl.labels[1]
  70. save['instance_labels_clr'] = cl.labels[2]
  71. save['instance_labels_alr'] = cl.labels[3]
  72. save['instance_label_cov'] = cl.labels_cov[0]
  73. save['instance_label_cov_ilr'] = cl.labels_cov[1]
  74. save['instance_label_cov_clr'] = cl.labels_cov[2]
  75. save['instance_label_cov_alr'] = cl.labels_cov[3]
  76. save['instance_id'] = cl.instance_ids
  77. save['instance_number'] = votes_vec_instances
  78. save['instance_study_numbers'] = instance_study_numbers
  79. save['instance_set_numbers'] = instance_set_numbers
  80. save['instance_ic_numbers'] = instance_ic_numbers
  81. save['raw_instances'] = votes_vec_instances
  82. save['raw_workers'] = votes_vec_workers
  83. save['raw_votes'] = votes_vec
  84. save['worker_mats'] = cl.worker_mats
  85. save['worker_prior'] = all_priors[0]
  86. save['instance_prior'] = instance_prior
  87. save['num_epoch'] = 1000
  88. save['burn_in'] = 200
  89. # save
  90. with open('ICLabels_test.pkl', 'wb') as f:
  91. pkl.dump(save, f)
  92. savemat('ICLabels_test.mat', save, oned_as='column')