train_MNIST.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-05-03 10:33:30 (ywatanabe)"
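"""Train a PerceptronOrINN model on MNIST.

The model's activation function is switched between a standard sigmoid and
an "intestine_simulated" variant via the --activation_function flag; the
script trains for 10 epochs and reports test-set accuracy.
"""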
import sys

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

sys.path.append(".")
from scripts.PerceptronOrINN import PerceptronOrINN


def set_random_seeds(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    # Disabling cuDNN autotuning and forcing deterministic kernels trades
    # some speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def main(act_str):
    # Initialize config
    model_config = {
        "act_str": act_str,
        "do_resample_act_funcs": False,
        "bs": 64,
        "n_fc_in": 784,
        "n_fc_1": 1000,
        "n_fc_2": 1000,
        "d_ratio_1": 0.5,
        "sigmoid_beta_0_mean": 1,
        "sigmoid_beta_0_var": 0,
        "sigmoid_beta_1_mean": 0,
        "sigmoid_beta_1_var": 0,
        "intestine_simulated_beta_0_mean": 3.06,
        "intestine_simulated_beta_0_var": 1.38,
        "intestine_simulated_beta_1_mean": 0,
        "intestine_simulated_beta_1_var": 3.23,
        "LABELS": list(range(10)),
    }
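    # NOTE: the *_beta_*_mean / *_beta_*_var entries appear to be the
    # mean/variance pairs from which PerceptronOrINN draws per-unit
    # activation-function parameters; this reading is inferred from the key
    # names and is not documented in this script.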

    # Initialize the model, optimizer, and loss function
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = PerceptronOrINN(model_config).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    # Load MNIST Data
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    train_dataset = datasets.MNIST(
        "./data", train=True, download=True, transform=transform
    )
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_dataset = datasets.MNIST(
        "./data", train=False, download=True, transform=transform
    )
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
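    # Each MNIST image is 28 x 28 = 784 pixels; the loops below flatten each
    # batch to (batch_size, 784) to match n_fc_in above.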

    # Training Loop
    for epoch in range(10):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            data = data.view(data.size(0), -1)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f"Epoch: {epoch} Batch: {batch_idx} Loss: {loss.item()}")

    # Test Loop
    model.eval()  # switch off dropout (d_ratio_1) for evaluation
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data = data.view(data.size(0), -1)
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f"Accuracy: {(100 * correct / total):.2f}%")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Train either an Intestine-derived Neural Network (INN) or a perceptron on MNIST"
    )
    parser.add_argument(
        "--activation_function",
        type=str,
        default="intestine_simulated",
        choices=["intestine_simulated", "sigmoid"],
        help="The type of activation function to use in the model",
    )
    args = parser.parse_args()

    # Set the random seed for reproducibility
    set_random_seeds()
    main(args.activation_function)
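
# Example invocation (script location assumed; run from the repository root
# so that the "scripts.PerceptronOrINN" import resolves):
#   python train_MNIST.py --activation_function sigmoid
#
# The string literal below is the recorded output of one training run.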
  94. """
  95. Epoch: 0 Batch: 0 Loss: 2.337707996368408
  96. Epoch: 0 Batch: 100 Loss: 2.3034439086914062
  97. Epoch: 0 Batch: 200 Loss: 2.308676242828369
  98. Epoch: 0 Batch: 300 Loss: 2.291069269180298
  99. Epoch: 0 Batch: 400 Loss: 2.2898669242858887
  100. Epoch: 0 Batch: 500 Loss: 2.259193181991577
  101. Epoch: 0 Batch: 600 Loss: 2.2262508869171143
  102. Epoch: 0 Batch: 700 Loss: 2.2591052055358887
  103. Epoch: 0 Batch: 800 Loss: 2.227748155593872
  104. Epoch: 0 Batch: 900 Loss: 2.214931011199951
  105. Epoch: 1 Batch: 0 Loss: 2.205514669418335
  106. Epoch: 1 Batch: 100 Loss: 2.1726245880126953
  107. Epoch: 1 Batch: 200 Loss: 2.179299831390381
  108. Epoch: 1 Batch: 300 Loss: 2.1493451595306396
  109. Epoch: 1 Batch: 400 Loss: 2.181206703186035
  110. Epoch: 1 Batch: 500 Loss: 2.165060043334961
  111. Epoch: 1 Batch: 600 Loss: 2.1518664360046387
  112. Epoch: 1 Batch: 700 Loss: 2.107090950012207
  113. Epoch: 1 Batch: 800 Loss: 2.0922796726226807
  114. Epoch: 1 Batch: 900 Loss: 2.0883922576904297
  115. Epoch: 2 Batch: 0 Loss: 2.067126750946045
  116. Epoch: 2 Batch: 100 Loss: 2.0309972763061523
  117. Epoch: 2 Batch: 200 Loss: 2.0337514877319336
  118. Epoch: 2 Batch: 300 Loss: 2.038759231567383
  119. Epoch: 2 Batch: 400 Loss: 1.9987781047821045
  120. Epoch: 2 Batch: 500 Loss: 2.0336506366729736
  121. Epoch: 2 Batch: 600 Loss: 1.9891915321350098
  122. Epoch: 2 Batch: 700 Loss: 1.9620522260665894
  123. Epoch: 2 Batch: 800 Loss: 1.9142768383026123
  124. Epoch: 2 Batch: 900 Loss: 1.885387659072876
  125. Epoch: 3 Batch: 0 Loss: 1.8634549379348755
  126. Epoch: 3 Batch: 100 Loss: 1.922452449798584
  127. Epoch: 3 Batch: 200 Loss: 1.843092918395996
  128. Epoch: 3 Batch: 300 Loss: 1.8004738092422485
  129. Epoch: 3 Batch: 400 Loss: 1.781233549118042
  130. Epoch: 3 Batch: 500 Loss: 1.7627004384994507
  131. Epoch: 3 Batch: 600 Loss: 1.6683433055877686
  132. Epoch: 3 Batch: 700 Loss: 1.7323194742202759
  133. Epoch: 3 Batch: 800 Loss: 1.638037919998169
  134. Epoch: 3 Batch: 900 Loss: 1.6239663362503052
  135. Epoch: 4 Batch: 0 Loss: 1.5985149145126343
  136. Epoch: 4 Batch: 100 Loss: 1.5736466646194458
  137. Epoch: 4 Batch: 200 Loss: 1.5068103075027466
  138. Epoch: 4 Batch: 300 Loss: 1.3710484504699707
  139. Epoch: 4 Batch: 400 Loss: 1.3616474866867065
  140. Epoch: 4 Batch: 500 Loss: 1.401007890701294
  141. Epoch: 4 Batch: 600 Loss: 1.426381230354309
  142. Epoch: 4 Batch: 700 Loss: 1.313161849975586
  143. Epoch: 4 Batch: 800 Loss: 1.3480956554412842
  144. Epoch: 4 Batch: 900 Loss: 1.3751581907272339
  145. Epoch: 5 Batch: 0 Loss: 1.3140125274658203
  146. Epoch: 5 Batch: 100 Loss: 1.1920424699783325
  147. Epoch: 5 Batch: 200 Loss: 1.2809444665908813
  148. Epoch: 5 Batch: 300 Loss: 1.317400574684143
  149. Epoch: 5 Batch: 400 Loss: 1.1676445007324219
  150. Epoch: 5 Batch: 500 Loss: 1.2212748527526855
  151. Epoch: 5 Batch: 600 Loss: 1.1691396236419678
  152. Epoch: 5 Batch: 700 Loss: 1.1782811880111694
  153. Epoch: 5 Batch: 800 Loss: 1.243850827217102
  154. Epoch: 5 Batch: 900 Loss: 1.0820790529251099
  155. Epoch: 6 Batch: 0 Loss: 1.176945686340332
  156. Epoch: 6 Batch: 100 Loss: 1.1030327081680298
  157. Epoch: 6 Batch: 200 Loss: 1.183580756187439
  158. Epoch: 6 Batch: 300 Loss: 1.0883508920669556
  159. Epoch: 6 Batch: 400 Loss: 0.9526631832122803
  160. Epoch: 6 Batch: 500 Loss: 0.922423243522644
  161. Epoch: 6 Batch: 600 Loss: 1.0819437503814697
  162. Epoch: 6 Batch: 700 Loss: 0.939717173576355
  163. Epoch: 6 Batch: 800 Loss: 1.0133917331695557
  164. Epoch: 6 Batch: 900 Loss: 0.9692772030830383
  165. Epoch: 7 Batch: 0 Loss: 0.9215019941329956
  166. Epoch: 7 Batch: 100 Loss: 0.963954746723175
  167. Epoch: 7 Batch: 200 Loss: 0.9186135530471802
  168. Epoch: 7 Batch: 300 Loss: 0.8597159385681152
  169. Epoch: 7 Batch: 400 Loss: 1.0357908010482788
  170. Epoch: 7 Batch: 500 Loss: 0.9571436047554016
  171. Epoch: 7 Batch: 600 Loss: 0.9383936524391174
  172. Epoch: 7 Batch: 700 Loss: 0.8021243810653687
  173. Epoch: 7 Batch: 800 Loss: 0.8582736849784851
  174. Epoch: 7 Batch: 900 Loss: 0.8480632901191711
  175. Epoch: 8 Batch: 0 Loss: 0.752300500869751
  176. Epoch: 8 Batch: 100 Loss: 0.9244712591171265
  177. Epoch: 8 Batch: 200 Loss: 0.8200180530548096
  178. Epoch: 8 Batch: 300 Loss: 0.8038215041160583
  179. Epoch: 8 Batch: 400 Loss: 0.8257690668106079
  180. Epoch: 8 Batch: 500 Loss: 0.7846906185150146
  181. Epoch: 8 Batch: 600 Loss: 0.6740760207176208
  182. Epoch: 8 Batch: 700 Loss: 0.6744505763053894
  183. Epoch: 8 Batch: 800 Loss: 0.6676566004753113
  184. Epoch: 8 Batch: 900 Loss: 0.6961811184883118
  185. Epoch: 9 Batch: 0 Loss: 0.8140008449554443
  186. Epoch: 9 Batch: 100 Loss: 0.6493793725967407
  187. Epoch: 9 Batch: 200 Loss: 0.6754528880119324
  188. Epoch: 9 Batch: 300 Loss: 0.7368165254592896
  189. Epoch: 9 Batch: 400 Loss: 0.7969047427177429
  190. Epoch: 9 Batch: 500 Loss: 0.6278544664382935
  191. Epoch: 9 Batch: 600 Loss: 0.6742461323738098
  192. Epoch: 9 Batch: 700 Loss: 0.6166834235191345
  193. Epoch: 9 Batch: 800 Loss: 0.7250018119812012
  194. Epoch: 9 Batch: 900 Loss: 0.717006266117096
  195. Accuracy: 83.51%
  196. """