|
@@ -0,0 +1,221 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+# Time-stamp: "2024-05-03 10:33:30 (ywatanabe)"
|
|
|
+
|
|
|
+import sys
|
|
|
+
|
|
|
+import torch
|
|
|
+import torch.optim as optim
|
|
|
+from torch.utils.data import DataLoader
|
|
|
+from torchvision import datasets, transforms
|
|
|
+
|
|
|
+sys.path.append(".")
|
|
|
+from scripts.PerceptronOrINN import PerceptronOrINN
|
|
|
+
|
|
|
+
|
|
|
+def set_random_seeds(seed=42):
|
|
|
+ torch.manual_seed(seed)
|
|
|
+ torch.cuda.manual_seed(seed)
|
|
|
+ torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
|
|
|
+ torch.backends.cudnn.benchmark = False
|
|
|
+ torch.backends.cudnn.deterministic = True
|
|
|
+
|
|
|
+
|
|
|
+def main(act_str):
|
|
|
+ # Initialize config
|
|
|
+ model_config = {
|
|
|
+ "act_str": act_str,
|
|
|
+ "do_resample_act_funcs": False,
|
|
|
+ "bs": 64,
|
|
|
+ "n_fc_in": 784,
|
|
|
+ "n_fc_1": 1000,
|
|
|
+ "n_fc_2": 1000,
|
|
|
+ "d_ratio_1": 0.5,
|
|
|
+ "sigmoid_beta_0_mean": 1,
|
|
|
+ "sigmoid_beta_0_var": 0,
|
|
|
+ "sigmoid_beta_1_mean": 0,
|
|
|
+ "sigmoid_beta_1_var": 0,
|
|
|
+ "intestine_simulated_beta_0_mean": 3.06,
|
|
|
+ "intestine_simulated_beta_0_var": 1.38,
|
|
|
+ "intestine_simulated_beta_1_mean": 0,
|
|
|
+ "intestine_simulated_beta_1_var": 3.23,
|
|
|
+ "LABELS": list(range(10)),
|
|
|
+ }
|
|
|
+
|
|
|
+ # Initialize the model, optimizer, and loss function
|
|
|
+ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
+ model = PerceptronOrINN(model_config).to(device)
|
|
|
+ optimizer = optim.SGD(model.parameters(), lr=0.001)
|
|
|
+ criterion = torch.nn.CrossEntropyLoss()
|
|
|
+
|
|
|
+ # Load MNIST Data
|
|
|
+ transform = transforms.Compose(
|
|
|
+ [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
|
|
|
+ )
|
|
|
+
|
|
|
+ train_dataset = datasets.MNIST(
|
|
|
+ "./data", train=True, download=True, transform=transform
|
|
|
+ )
|
|
|
+ train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
|
|
|
+
|
|
|
+ test_dataset = datasets.MNIST(
|
|
|
+ "./data", train=False, download=True, transform=transform
|
|
|
+ )
|
|
|
+ test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
|
|
|
+
|
|
|
+ # Training Loop
|
|
|
+ for epoch in range(10):
|
|
|
+ for batch_idx, (data, target) in enumerate(train_loader):
|
|
|
+ data, target = data.to(device), target.to(device)
|
|
|
+
|
|
|
+ data = data.view(data.size(0), -1)
|
|
|
+
|
|
|
+ optimizer.zero_grad()
|
|
|
+ output = model(data)
|
|
|
+
|
|
|
+ loss = criterion(output, target)
|
|
|
+ loss.backward()
|
|
|
+
|
|
|
+ optimizer.step()
|
|
|
+
|
|
|
+ if batch_idx % 100 == 0:
|
|
|
+ print(f"Epoch: {epoch} Batch: {batch_idx} Loss: {loss.item()}")
|
|
|
+
|
|
|
+ # Test Loop
|
|
|
+ correct = 0
|
|
|
+ total = 0
|
|
|
+ with torch.no_grad():
|
|
|
+ for data, target in test_loader:
|
|
|
+ data, target = data.to(device), target.to(device)
|
|
|
+ data = data.view(data.size(0), -1)
|
|
|
+ output = model(data)
|
|
|
+ _, predicted = torch.max(output.data, 1)
|
|
|
+ total += target.size(0)
|
|
|
+ correct += (predicted == target).sum().item()
|
|
|
+
|
|
|
+ print(f"Accuracy: {(100 * correct / total):.2f}%")
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ import argparse
|
|
|
+
|
|
|
+ parser = argparse.ArgumentParser(
|
|
|
+ description="Switcher for Intestine-derived Neural Network (INN) or a perceptron"
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "--activation_function",
|
|
|
+ type=str,
|
|
|
+ default="intestine_simulated",
|
|
|
+ choices=["intestine_simulated", "sigmoid"],
|
|
|
+ help="The type of activation function to use in the model",
|
|
|
+ )
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ # Set the random seed for reproducibility
|
|
|
+ set_random_seeds()
|
|
|
+
|
|
|
+ main(args.activation_function)
|
|
|
+
|
|
|
+ """
|
|
|
+ Epoch: 0 Batch: 0 Loss: 2.337707996368408
|
|
|
+ Epoch: 0 Batch: 100 Loss: 2.3034439086914062
|
|
|
+ Epoch: 0 Batch: 200 Loss: 2.308676242828369
|
|
|
+ Epoch: 0 Batch: 300 Loss: 2.291069269180298
|
|
|
+ Epoch: 0 Batch: 400 Loss: 2.2898669242858887
|
|
|
+ Epoch: 0 Batch: 500 Loss: 2.259193181991577
|
|
|
+ Epoch: 0 Batch: 600 Loss: 2.2262508869171143
|
|
|
+ Epoch: 0 Batch: 700 Loss: 2.2591052055358887
|
|
|
+ Epoch: 0 Batch: 800 Loss: 2.227748155593872
|
|
|
+ Epoch: 0 Batch: 900 Loss: 2.214931011199951
|
|
|
+ Epoch: 1 Batch: 0 Loss: 2.205514669418335
|
|
|
+ Epoch: 1 Batch: 100 Loss: 2.1726245880126953
|
|
|
+ Epoch: 1 Batch: 200 Loss: 2.179299831390381
|
|
|
+ Epoch: 1 Batch: 300 Loss: 2.1493451595306396
|
|
|
+ Epoch: 1 Batch: 400 Loss: 2.181206703186035
|
|
|
+ Epoch: 1 Batch: 500 Loss: 2.165060043334961
|
|
|
+ Epoch: 1 Batch: 600 Loss: 2.1518664360046387
|
|
|
+ Epoch: 1 Batch: 700 Loss: 2.107090950012207
|
|
|
+ Epoch: 1 Batch: 800 Loss: 2.0922796726226807
|
|
|
+ Epoch: 1 Batch: 900 Loss: 2.0883922576904297
|
|
|
+ Epoch: 2 Batch: 0 Loss: 2.067126750946045
|
|
|
+ Epoch: 2 Batch: 100 Loss: 2.0309972763061523
|
|
|
+ Epoch: 2 Batch: 200 Loss: 2.0337514877319336
|
|
|
+ Epoch: 2 Batch: 300 Loss: 2.038759231567383
|
|
|
+ Epoch: 2 Batch: 400 Loss: 1.9987781047821045
|
|
|
+ Epoch: 2 Batch: 500 Loss: 2.0336506366729736
|
|
|
+ Epoch: 2 Batch: 600 Loss: 1.9891915321350098
|
|
|
+ Epoch: 2 Batch: 700 Loss: 1.9620522260665894
|
|
|
+ Epoch: 2 Batch: 800 Loss: 1.9142768383026123
|
|
|
+ Epoch: 2 Batch: 900 Loss: 1.885387659072876
|
|
|
+ Epoch: 3 Batch: 0 Loss: 1.8634549379348755
|
|
|
+ Epoch: 3 Batch: 100 Loss: 1.922452449798584
|
|
|
+ Epoch: 3 Batch: 200 Loss: 1.843092918395996
|
|
|
+ Epoch: 3 Batch: 300 Loss: 1.8004738092422485
|
|
|
+ Epoch: 3 Batch: 400 Loss: 1.781233549118042
|
|
|
+ Epoch: 3 Batch: 500 Loss: 1.7627004384994507
|
|
|
+ Epoch: 3 Batch: 600 Loss: 1.6683433055877686
|
|
|
+ Epoch: 3 Batch: 700 Loss: 1.7323194742202759
|
|
|
+ Epoch: 3 Batch: 800 Loss: 1.638037919998169
|
|
|
+ Epoch: 3 Batch: 900 Loss: 1.6239663362503052
|
|
|
+ Epoch: 4 Batch: 0 Loss: 1.5985149145126343
|
|
|
+ Epoch: 4 Batch: 100 Loss: 1.5736466646194458
|
|
|
+ Epoch: 4 Batch: 200 Loss: 1.5068103075027466
|
|
|
+ Epoch: 4 Batch: 300 Loss: 1.3710484504699707
|
|
|
+ Epoch: 4 Batch: 400 Loss: 1.3616474866867065
|
|
|
+ Epoch: 4 Batch: 500 Loss: 1.401007890701294
|
|
|
+ Epoch: 4 Batch: 600 Loss: 1.426381230354309
|
|
|
+ Epoch: 4 Batch: 700 Loss: 1.313161849975586
|
|
|
+ Epoch: 4 Batch: 800 Loss: 1.3480956554412842
|
|
|
+ Epoch: 4 Batch: 900 Loss: 1.3751581907272339
|
|
|
+ Epoch: 5 Batch: 0 Loss: 1.3140125274658203
|
|
|
+ Epoch: 5 Batch: 100 Loss: 1.1920424699783325
|
|
|
+ Epoch: 5 Batch: 200 Loss: 1.2809444665908813
|
|
|
+ Epoch: 5 Batch: 300 Loss: 1.317400574684143
|
|
|
+ Epoch: 5 Batch: 400 Loss: 1.1676445007324219
|
|
|
+ Epoch: 5 Batch: 500 Loss: 1.2212748527526855
|
|
|
+ Epoch: 5 Batch: 600 Loss: 1.1691396236419678
|
|
|
+ Epoch: 5 Batch: 700 Loss: 1.1782811880111694
|
|
|
+ Epoch: 5 Batch: 800 Loss: 1.243850827217102
|
|
|
+ Epoch: 5 Batch: 900 Loss: 1.0820790529251099
|
|
|
+ Epoch: 6 Batch: 0 Loss: 1.176945686340332
|
|
|
+ Epoch: 6 Batch: 100 Loss: 1.1030327081680298
|
|
|
+ Epoch: 6 Batch: 200 Loss: 1.183580756187439
|
|
|
+ Epoch: 6 Batch: 300 Loss: 1.0883508920669556
|
|
|
+ Epoch: 6 Batch: 400 Loss: 0.9526631832122803
|
|
|
+ Epoch: 6 Batch: 500 Loss: 0.922423243522644
|
|
|
+ Epoch: 6 Batch: 600 Loss: 1.0819437503814697
|
|
|
+ Epoch: 6 Batch: 700 Loss: 0.939717173576355
|
|
|
+ Epoch: 6 Batch: 800 Loss: 1.0133917331695557
|
|
|
+ Epoch: 6 Batch: 900 Loss: 0.9692772030830383
|
|
|
+ Epoch: 7 Batch: 0 Loss: 0.9215019941329956
|
|
|
+ Epoch: 7 Batch: 100 Loss: 0.963954746723175
|
|
|
+ Epoch: 7 Batch: 200 Loss: 0.9186135530471802
|
|
|
+ Epoch: 7 Batch: 300 Loss: 0.8597159385681152
|
|
|
+ Epoch: 7 Batch: 400 Loss: 1.0357908010482788
|
|
|
+ Epoch: 7 Batch: 500 Loss: 0.9571436047554016
|
|
|
+ Epoch: 7 Batch: 600 Loss: 0.9383936524391174
|
|
|
+ Epoch: 7 Batch: 700 Loss: 0.8021243810653687
|
|
|
+ Epoch: 7 Batch: 800 Loss: 0.8582736849784851
|
|
|
+ Epoch: 7 Batch: 900 Loss: 0.8480632901191711
|
|
|
+ Epoch: 8 Batch: 0 Loss: 0.752300500869751
|
|
|
+ Epoch: 8 Batch: 100 Loss: 0.9244712591171265
|
|
|
+ Epoch: 8 Batch: 200 Loss: 0.8200180530548096
|
|
|
+ Epoch: 8 Batch: 300 Loss: 0.8038215041160583
|
|
|
+ Epoch: 8 Batch: 400 Loss: 0.8257690668106079
|
|
|
+ Epoch: 8 Batch: 500 Loss: 0.7846906185150146
|
|
|
+ Epoch: 8 Batch: 600 Loss: 0.6740760207176208
|
|
|
+ Epoch: 8 Batch: 700 Loss: 0.6744505763053894
|
|
|
+ Epoch: 8 Batch: 800 Loss: 0.6676566004753113
|
|
|
+ Epoch: 8 Batch: 900 Loss: 0.6961811184883118
|
|
|
+ Epoch: 9 Batch: 0 Loss: 0.8140008449554443
|
|
|
+ Epoch: 9 Batch: 100 Loss: 0.6493793725967407
|
|
|
+ Epoch: 9 Batch: 200 Loss: 0.6754528880119324
|
|
|
+ Epoch: 9 Batch: 300 Loss: 0.7368165254592896
|
|
|
+ Epoch: 9 Batch: 400 Loss: 0.7969047427177429
|
|
|
+ Epoch: 9 Batch: 500 Loss: 0.6278544664382935
|
|
|
+ Epoch: 9 Batch: 600 Loss: 0.6742461323738098
|
|
|
+ Epoch: 9 Batch: 700 Loss: 0.6166834235191345
|
|
|
+ Epoch: 9 Batch: 800 Loss: 0.7250018119812012
|
|
|
+ Epoch: 9 Batch: 900 Loss: 0.717006266117096
|
|
|
+ Accuracy: 83.51%
|
|
|
+ """
|