MNIST
Since QKAN (a KAN-family model) can be viewed as a universal approximator, it can serve as a drop-in replacement for an MLP while using fewer parameters.
Here we demonstrate how to use QKAN for a classification task on the MNIST dataset.
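As a minimal sketch of the drop-in idea (the QKAN(widths, device=...) constructor is the same one used in the cells below; the MLP side is plain PyTorch):

import torch
import torch.nn as nn
from qkan import QKAN

# Two heads with the same 10 -> 10 interface: a stack of linear layers
# versus a single KAN-style layer.
mlp_head = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
qkan_head = QKAN([10, 10], device=torch.device("cpu"))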
[1]:
import random
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from qkan import QKAN
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
[2]:
class CNet(nn.Module):
    """Small convolutional feature extractor: (N, 1, 28, 28) -> (N, 10)."""

    def __init__(self, device):
        super().__init__()
        self.device = device
        # 28x28 -> 24x24
        self.cnn1 = nn.Conv2d(
            in_channels=1, out_channels=3, kernel_size=5, device=device
        )
        self.relu = nn.ReLU()
        # 24x24 -> 12x12
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 12x12 -> 8x8
        self.cnn2 = nn.Conv2d(
            in_channels=3, out_channels=6, kernel_size=5, device=device
        )
        # 8x8 -> 4x4
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 4x4 -> 2x2
        self.cnn3 = nn.Conv2d(
            in_channels=6, out_channels=10, kernel_size=3, device=device
        )
        # 2x2 -> 1x1
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = x.to(self.device)
        x = self.cnn1(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.cnn2(x)
        x = self.relu(x)
        x = self.maxpool2(x)
        x = self.cnn3(x)
        x = self.relu(x)
        x = self.maxpool3(x)
        # (N, 10, 1, 1) -> (N, 10)
        return x.flatten(start_dim=1)
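Tracing the spatial sizes, a 28x28 input shrinks to 24x24, 12x12, 8x8, 4x4, 2x2, and finally 1x1 with 10 channels, so CNet emits a 10-dimensional feature vector per image. A quick sanity check of that arithmetic:

feat = CNet(device=torch.device("cpu"))(torch.randn(2, 1, 28, 28))
print(feat.shape)  # torch.Size([2, 10])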
[3]:
# Load MNIST
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
trainset = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=transform
)
testset = torchvision.datasets.MNIST(
root="./data", train=False, download=True, transform=transform
)
trainloader = DataLoader(trainset, batch_size=1000, shuffle=True)
testloader = DataLoader(testset, batch_size=1000, shuffle=False)
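Normalize((0.5,), (0.5,)) rescales pixel values from [0, 1] to [-1, 1] via (x - 0.5) / 0.5; for example:

x = torch.tensor([[[0.0, 0.5, 1.0]]])           # a toy 1-channel "image"
print(transforms.Normalize((0.5,), (0.5,))(x))  # tensor([[[-1., 0., 1.]]])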
[4]:
model = nn.Sequential(
CNet(device=device),
QKAN([10, 10], device=device),
)
criterion = nn.CrossEntropyLoss()
accs = []
losses = []
test_accs = []
test_losses = []
test_top3_accs = []
optimizer = optim.Adam(model.parameters(), lr=1e-3)
[5]:
for epoch in range(20):
# Train
model.train()
with tqdm(trainloader) as pbar:
for i, (images, labels) in enumerate(pbar):
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels.to(device))
loss.backward()
optimizer.step()
accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
accs.append(accuracy.item())
losses.append(loss.item())
pbar.set_postfix(
loss=loss.item(),
accuracy=accuracy.item(),
lr=optimizer.param_groups[0]["lr"],
)
    # Test
model.eval()
test_loss = 0
test_accuracy = 0
test_top3_accuracy = 0
with torch.no_grad():
for images, labels in testloader:
output = model(images)
test_loss += criterion(output, labels.to(device)).item()
test_accuracy += (
(output.argmax(dim=1) == labels.to(device)).float().mean().item()
)
test_top3_accuracy += (
(output.topk(3, dim=1).indices == labels.to(device).unsqueeze(1)).any(dim=1).float().mean().item()
)
test_loss /= len(testloader)
test_accuracy /= len(testloader)
test_top3_accuracy /= len(testloader)
print(f"Epoch {epoch + 1}, test Loss: {test_loss}, test Accuracy: {test_accuracy}")
test_accs.append(test_accuracy)
test_losses.append(test_loss)
test_top3_accs.append(test_top3_accuracy)
100%|██████████| 60/60 [00:01<00:00, 31.79it/s, accuracy=0.506, loss=1.96, lr=0.001]
Epoch 1, test Loss: 1.9495182633399963, test Accuracy: 0.5352000236511231
100%|██████████| 60/60 [00:01<00:00, 36.10it/s, accuracy=0.702, loss=1.45, lr=0.001]
Epoch 2, test Loss: 1.4130778551101684, test Accuracy: 0.7213000357151031
100%|██████████| 60/60 [00:01<00:00, 35.75it/s, accuracy=0.804, loss=1.06, lr=0.001]
Epoch 3, test Loss: 1.0548779726028443, test Accuracy: 0.7998000502586364
100%|██████████| 60/60 [00:01<00:00, 34.35it/s, accuracy=0.842, loss=0.831, lr=0.001]
Epoch 4, test Loss: 0.8044245541095734, test Accuracy: 0.8453000366687775
100%|██████████| 60/60 [00:01<00:00, 35.49it/s, accuracy=0.86, loss=0.666, lr=0.001]
Epoch 5, test Loss: 0.6300603926181794, test Accuracy: 0.8797000527381897
100%|██████████| 60/60 [00:01<00:00, 35.79it/s, accuracy=0.908, loss=0.518, lr=0.001]
Epoch 6, test Loss: 0.5075116276741027, test Accuracy: 0.903400045633316
100%|██████████| 60/60 [00:01<00:00, 34.23it/s, accuracy=0.895, loss=0.466, lr=0.001]
Epoch 7, test Loss: 0.42405287027359007, test Accuracy: 0.9184000372886658
100%|██████████| 60/60 [00:01<00:00, 35.56it/s, accuracy=0.916, loss=0.376, lr=0.001]
Epoch 8, test Loss: 0.3646392196416855, test Accuracy: 0.9274000406265259
100%|██████████| 60/60 [00:01<00:00, 35.60it/s, accuracy=0.929, loss=0.353, lr=0.001]
Epoch 9, test Loss: 0.32189558148384095, test Accuracy: 0.9332000494003296
100%|██████████| 60/60 [00:01<00:00, 34.16it/s, accuracy=0.942, loss=0.284, lr=0.001]
Epoch 10, test Loss: 0.2885587066411972, test Accuracy: 0.9390000462532043
100%|██████████| 60/60 [00:01<00:00, 35.40it/s, accuracy=0.946, loss=0.274, lr=0.001]
Epoch 11, test Loss: 0.265482497215271, test Accuracy: 0.9436000525951386
100%|██████████| 60/60 [00:01<00:00, 35.32it/s, accuracy=0.926, loss=0.291, lr=0.001]
Epoch 12, test Loss: 0.24529014378786088, test Accuracy: 0.9458000421524048
100%|██████████| 60/60 [00:01<00:00, 34.31it/s, accuracy=0.947, loss=0.233, lr=0.001]
Epoch 13, test Loss: 0.22940405011177062, test Accuracy: 0.9485000431537628
100%|██████████| 60/60 [00:01<00:00, 35.25it/s, accuracy=0.953, loss=0.215, lr=0.001]
Epoch 14, test Loss: 0.21600559800863267, test Accuracy: 0.9513000428676606
100%|██████████| 60/60 [00:01<00:00, 35.53it/s, accuracy=0.948, loss=0.226, lr=0.001]
Epoch 15, test Loss: 0.2052382320165634, test Accuracy: 0.952800041437149
100%|██████████| 60/60 [00:01<00:00, 34.30it/s, accuracy=0.95, loss=0.205, lr=0.001]
Epoch 16, test Loss: 0.1961878292262554, test Accuracy: 0.9533000528812409
100%|██████████| 60/60 [00:01<00:00, 35.69it/s, accuracy=0.96, loss=0.176, lr=0.001]
Epoch 17, test Loss: 0.1896478660404682, test Accuracy: 0.9549000442028046
100%|██████████| 60/60 [00:01<00:00, 36.09it/s, accuracy=0.953, loss=0.199, lr=0.001]
Epoch 18, test Loss: 0.18260392621159555, test Accuracy: 0.9564000487327575
100%|██████████| 60/60 [00:01<00:00, 34.97it/s, accuracy=0.968, loss=0.177, lr=0.001]
Epoch 19, test Loss: 0.17513826712965966, test Accuracy: 0.9577000498771667
100%|██████████| 60/60 [00:01<00:00, 36.06it/s, accuracy=0.954, loss=0.17, lr=0.001]
Epoch 20, test Loss: 0.16960925087332726, test Accuracy: 0.9585000455379487
[6]:
print("Top-1 Accuracy: ", max(test_accs))
print("Top-3 Accuracy: ", max(test_top3_accs))
print("CNN+QKAN Model size:", len(torch.concat([param.cpu().detach().flatten() for param in model.parameters() if param.requires_grad])))
Top-1 Accuracy: 0.9585000455379487
Top-3 Accuracy: 0.9934000551700592
CNN+QKAN Model size: 1884
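The top-3 accuracy counts a sample as correct whenever its true label appears among the three largest logits. The topk test from the evaluation loop, in miniature (a top-2 toy case with three classes):

logits = torch.tensor([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])
labels = torch.tensor([2, 0])
hits = (logits.topk(2, dim=1).indices == labels.unsqueeze(1)).any(dim=1)
print(hits.float().mean().item())  # 1.0 -- both labels fall inside the top-2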
Comparison with MLP
[7]:
class CNN(nn.Module):
    """The same CNet feature extractor followed by a three-layer MLP head."""

    def __init__(self, device):
        super().__init__()
        self.cnet = CNet(device)
        self.fc1 = nn.Linear(10, 20).to(device)
        self.fc2 = nn.Linear(20, 20).to(device)
        self.fc3 = nn.Linear(20, 10).to(device)
        self.device = device

    def forward(self, x):
        x = x.to(self.device)  # .to() is not in-place; the result must be reassigned
        x = self.cnet(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x
model = CNN(device=device)
criterion = nn.CrossEntropyLoss()
accs = []
losses = []
test_accs = []
test_losses = []
test_top3_accs = []
optimizer = optim.Adam(model.parameters(), lr=1e-3)
[8]:
for epoch in range(20):
# Train
model.train()
with tqdm(trainloader) as pbar:
for i, (images, labels) in enumerate(pbar):
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels.to(device))
loss.backward()
optimizer.step()
accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
accs.append(accuracy.item())
losses.append(loss.item())
pbar.set_postfix(
loss=loss.item(),
accuracy=accuracy.item(),
lr=optimizer.param_groups[0]["lr"],
)
    # Test
model.eval()
test_loss = 0
test_accuracy = 0
test_top3_accuracy = 0
with torch.no_grad():
for images, labels in testloader:
output = model(images)
test_loss += criterion(output, labels.to(device)).item()
test_accuracy += (
(output.argmax(dim=1) == labels.to(device)).float().mean().item()
)
test_top3_accuracy += (
(output.topk(3, dim=1).indices == labels.to(device).unsqueeze(1)).any(dim=1).float().mean().item()
)
test_loss /= len(testloader)
test_accuracy /= len(testloader)
test_top3_accuracy /= len(testloader)
print(f"Epoch {epoch + 1}, test Loss: {test_loss}, test Accuracy: {test_accuracy}")
test_accs.append(test_accuracy)
test_losses.append(test_loss)
test_top3_accs.append(test_top3_accuracy)
100%|██████████| 60/60 [00:01<00:00, 38.35it/s, accuracy=0.254, loss=2.07, lr=0.001]
Epoch 1, test Loss: 2.06544942855835, test Accuracy: 0.24110001176595688
100%|██████████| 60/60 [00:01<00:00, 37.28it/s, accuracy=0.652, loss=1.04, lr=0.001]
Epoch 2, test Loss: 1.0108458638191222, test Accuracy: 0.6691000401973725
100%|██████████| 60/60 [00:01<00:00, 38.37it/s, accuracy=0.829, loss=0.552, lr=0.001]
Epoch 3, test Loss: 0.5260033041238785, test Accuracy: 0.8397000432014465
100%|██████████| 60/60 [00:01<00:00, 38.19it/s, accuracy=0.88, loss=0.436, lr=0.001]
Epoch 4, test Loss: 0.39268899857997897, test Accuracy: 0.8824000537395478
100%|██████████| 60/60 [00:01<00:00, 37.21it/s, accuracy=0.9, loss=0.317, lr=0.001]
Epoch 5, test Loss: 0.3284891590476036, test Accuracy: 0.9017000377178193
100%|██████████| 60/60 [00:01<00:00, 38.31it/s, accuracy=0.914, loss=0.309, lr=0.001]
Epoch 6, test Loss: 0.2903182566165924, test Accuracy: 0.9122000455856323
100%|██████████| 60/60 [00:01<00:00, 38.54it/s, accuracy=0.912, loss=0.275, lr=0.001]
Epoch 7, test Loss: 0.2596504405140877, test Accuracy: 0.9244000494480134
100%|██████████| 60/60 [00:01<00:00, 37.31it/s, accuracy=0.907, loss=0.293, lr=0.001]
Epoch 8, test Loss: 0.24268167465925217, test Accuracy: 0.9282000482082366
100%|██████████| 60/60 [00:01<00:00, 36.97it/s, accuracy=0.93, loss=0.251, lr=0.001]
Epoch 9, test Loss: 0.22453297376632692, test Accuracy: 0.9329000413417816
100%|██████████| 60/60 [00:01<00:00, 37.49it/s, accuracy=0.93, loss=0.258, lr=0.001]
Epoch 10, test Loss: 0.2074310526251793, test Accuracy: 0.9391000390052795
100%|██████████| 60/60 [00:01<00:00, 37.43it/s, accuracy=0.945, loss=0.213, lr=0.001]
Epoch 11, test Loss: 0.1980581246316433, test Accuracy: 0.9422000467777252
100%|██████████| 60/60 [00:01<00:00, 38.23it/s, accuracy=0.943, loss=0.217, lr=0.001]
Epoch 12, test Loss: 0.19006126448512078, test Accuracy: 0.9431000411510467
100%|██████████| 60/60 [00:01<00:00, 37.67it/s, accuracy=0.954, loss=0.172, lr=0.001]
Epoch 13, test Loss: 0.1777050383388996, test Accuracy: 0.9458000421524048
100%|██████████| 60/60 [00:01<00:00, 36.71it/s, accuracy=0.948, loss=0.175, lr=0.001]
Epoch 14, test Loss: 0.17280858755111694, test Accuracy: 0.9464000403881073
100%|██████████| 60/60 [00:01<00:00, 38.01it/s, accuracy=0.946, loss=0.186, lr=0.001]
Epoch 15, test Loss: 0.16691454201936723, test Accuracy: 0.949200040102005
100%|██████████| 60/60 [00:01<00:00, 37.97it/s, accuracy=0.952, loss=0.165, lr=0.001]
Epoch 16, test Loss: 0.16387526392936708, test Accuracy: 0.9485000371932983
100%|██████████| 60/60 [00:01<00:00, 36.39it/s, accuracy=0.952, loss=0.164, lr=0.001]
Epoch 17, test Loss: 0.15366466902196407, test Accuracy: 0.9516000390052796
100%|██████████| 60/60 [00:01<00:00, 37.86it/s, accuracy=0.945, loss=0.187, lr=0.001]
Epoch 18, test Loss: 0.1512805711477995, test Accuracy: 0.9532000422477722
100%|██████████| 60/60 [00:01<00:00, 37.61it/s, accuracy=0.953, loss=0.158, lr=0.001]
Epoch 19, test Loss: 0.15157018043100834, test Accuracy: 0.9514000475406647
100%|██████████| 60/60 [00:01<00:00, 36.75it/s, accuracy=0.962, loss=0.127, lr=0.001]
Epoch 20, test Loss: 0.1417049340903759, test Accuracy: 0.9557000458240509
[9]:
print("Top-1 Accuracy: ", max(test_accs))
print("Top-3 Accuracy: ", max(test_top3_accs))
print("CNN+MLP Model size:", len(torch.concat([param.cpu().detach().flatten() for param in model.parameters() if param.requires_grad])))
Top-1 Accuracy: 0.9557000458240509
Top-3 Accuracy: 0.995000034570694
CNN+MLP Model size: 1934
The results show that QKAN reaches slightly higher top-1 accuracy than the MLP (95.85% vs. 95.57%) while using fewer parameters (1884 vs. 1934).
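For reference, the shared CNet extractor accounts for 1084 of the parameters in both models (78 + 456 + 550 over its three conv layers), so the heads themselves compare as 800 (QKAN) versus 850 (MLP). This can be checked directly from the models above:

cnet_params = sum(p.numel() for p in CNet(device=device).parameters())
print(cnet_params)         # 1084
print(1884 - cnet_params)  # 800 parameters in the QKAN head
print(1934 - cnet_params)  # 850 parameters in the MLP head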