Hybrid QKAN on CIFAR-100
In this example, we demonstrate the performance and scalability of the Hybrid QKAN (HQKAN) architecture by applying it to a larger-scale model trained on the CIFAR-100 dataset.
Motivation
Due to the architectural design of Kolmogorov–Arnold Networks (KANs), the number of parameters in a layer grows with the product of its input and output channels, i.e., quadratically in the layer width. This presents a scalability challenge for large models.
Here is how the parameter count increases:
[1]:
from qkan import QKAN
model_1 = QKAN([10, 10])
model_2 = QKAN([10, 100])
print("Model 1 with shape [10, 10] has # of parameters:", model_1.param_size)
print("Model 2 with shape [10, 100] has # of parameters:", model_2.param_size)
del model_1, model_2
Model 1 with shape [10, 10] has # of parameters: 800
Model 2 with shape [10, 100] has # of parameters: 8000
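In general, a (Q)KAN layer stores a learnable univariate function for every input–output pair, so its size scales with the product of the two widths. From the counts printed above, this QKAN configuration appears to use 8 trainable parameters per pair (an inference from the numbers, not an API guarantee), which makes a direct layer at the feature and class sizes used later in this notebook impractically large:
# Rough rule of thumb inferred from the output above: params ≈ n_in * n_out * 8
print(10 * 10 * 8)    # 800
print(10 * 100 * 8)   # 8000
print(256 * 100 * 8)  # 204800 -- a direct 256 -> 100 layer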
To address this, we propose a hybrid architecture that integrates two fully connected networks (FCNs) to compress and reconstruct the feature representation. This forms an autoencoder-like bottleneck structure around the QKAN core.
We refer to this approach as Hybrid QKAN (HQKAN).
Architecture Overview
Encoder FCN compresses the high-dimensional input features.
QKAN Core processes the reduced latent representation.
Decoder FCN reconstructs the output to the original dimensionality.
This structure allows us to benefit from QKAN’s expressiveness while controlling parameter growth.
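A minimal sketch of this layout (the widths below mirror the experiment later in this notebook; QKAN options are left at their defaults):
import torch.nn as nn
from qkan import QKAN

# encoder FCN -> QKAN core on a small latent space -> decoder FCN
hqkan_head = nn.Sequential(
    nn.Linear(256, 8),   # compress the high-dimensional features
    QKAN([8, 7]),        # QKAN operates only on the bottleneck
    nn.Linear(7, 100),   # expand back to the output dimensionality
)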
Experiment: CIFAR-100
We apply HQKAN to the CIFAR-100 dataset to evaluate its performance and efficiency.
[2]:
import random
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
[3]:
class CNet(nn.Module):
def __init__(self, device):
super(CNet, self).__init__()
self.device = device
self.cnn1 = nn.Conv2d(
in_channels=3, out_channels=32, kernel_size=3, device=device
)
self.relu = nn.ReLU()
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.cnn2 = nn.Conv2d(
in_channels=32, out_channels=64, kernel_size=3, device=device
)
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.cnn3 = nn.Conv2d(
in_channels=64, out_channels=64, kernel_size=3, device=device
)
self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
def forward(self, x):
x = x.to(self.device)
x = self.cnn1(x)
x = self.relu(x)
x = self.maxpool1(x)
x = self.cnn2(x)
x = self.relu(x)
x = self.maxpool2(x)
x = self.cnn3(x)
x = self.relu(x)
x = self.maxpool3(x)
x = x.flatten(start_dim=1)
return x
[4]:
# Load CIFAR100 dataset
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
trainset = torchvision.datasets.CIFAR100(
root="./data", train=True, download=True, transform=transform
)
testset = torchvision.datasets.CIFAR100(
root="./data", train=False, download=True, transform=transform
)
trainloader = DataLoader(trainset, batch_size=1000, shuffle=True)
testloader = DataLoader(testset, batch_size=1000, shuffle=False)
Files already downloaded and verified
Files already downloaded and verified
[5]:
in_feat = 2 * 2 * 64  # flattened CNet output: 2x2 spatial map with 64 channels = 256
out_feat = 100  # number of CIFAR-100 classes
# Compress to log2-sized latents: ceil(log2(256)) = 8, ceil(log2(100)) = 7
in_resize = np.ceil(np.log2(in_feat)).astype(int)
out_resize = np.ceil(np.log2(out_feat)).astype(int)
model = nn.Sequential(
CNet(device=device),
nn.Linear(in_feat, in_resize, device=device),
QKAN([in_resize, out_resize], device=device),
nn.Linear(out_resize, out_feat, device=device),
)
criterion = nn.CrossEntropyLoss()
accs = []
losses = []
test_accs = []
test_losses = []
test_top5_accs = []
optimizer = optim.Adam(model.parameters(), lr=1e-3)
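For reference, the tensor shapes through this model (32×32 CIFAR-100 inputs; the spatial sizes follow from the three unpadded 3×3 convolutions and 2×2 max-pools in CNet):
(batch, 3, 32, 32) → CNet → (batch, 256) → Linear → (batch, 8) → QKAN → (batch, 7) → Linear → (batch, 100) logits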
[6]:
def train(model, optimizer):
for epoch in range(50):
# Train
model.train()
with tqdm(trainloader) as pbar:
for i, (images, labels) in enumerate(pbar):
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels.to(device))
loss.backward()
optimizer.step()
accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
accs.append(accuracy.item())
losses.append(loss.item())
pbar.set_postfix(
loss=loss.item(),
accuracy=accuracy.item(),
lr=optimizer.param_groups[0]["lr"],
)
# test
model.eval()
test_loss = 0
test_accuracy = 0
test_top5_accuracy = 0
with torch.no_grad():
for images, labels in testloader:
output = model(images)
test_loss += criterion(output, labels.to(device)).item()
test_accuracy += (
(output.argmax(dim=1) == labels.to(device)).float().mean().item()
)
test_top5_accuracy += (
(output.topk(5, dim=1).indices == labels.to(device).unsqueeze(1)).any(dim=1).float().mean().item()
)
test_loss /= len(testloader)
test_accuracy /= len(testloader)
test_top5_accuracy /= len(testloader)
print(f"Epoch {epoch + 1}, test Loss: {test_loss}, test Accuracy: {test_accuracy}")
test_accs.append(test_accuracy)
test_losses.append(test_loss)
test_top5_accs.append(test_top5_accuracy)
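The top-5 term above checks whether the true label appears among the five largest logits for each sample. A tiny standalone illustration of that expression, with made-up logits:
import torch
logits = torch.tensor([[0.1, 0.9, 0.3, 0.2, 0.5, 0.05],
                       [0.8, 0.1, 0.05, 0.02, 0.01, 0.02]])
labels = torch.tensor([4, 4])
in_top5 = (logits.topk(5, dim=1).indices == labels.unsqueeze(1)).any(dim=1)
print(in_top5.float().mean().item())  # 0.5: label 4 is in the top 5 of row 0 but has the smallest logit in row 1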
[7]:
train(model, optimizer)
100%|██████████| 50/50 [00:02<00:00, 23.03it/s, accuracy=0.014, loss=4.57, lr=0.001]
Epoch 1, test Loss: 4.5694972515106205, test Accuracy: 0.01120000034570694
100%|██████████| 50/50 [00:01<00:00, 25.58it/s, accuracy=0.032, loss=4.47, lr=0.001]
Epoch 2, test Loss: 4.473529720306397, test Accuracy: 0.02420000098645687
100%|██████████| 50/50 [00:01<00:00, 25.93it/s, accuracy=0.04, loss=4.35, lr=0.001]
Epoch 3, test Loss: 4.348425626754761, test Accuracy: 0.035800001583993435
100%|██████████| 50/50 [00:02<00:00, 24.91it/s, accuracy=0.051, loss=4.18, lr=0.001]
Epoch 4, test Loss: 4.219951534271241, test Accuracy: 0.04830000214278698
100%|██████████| 50/50 [00:01<00:00, 25.60it/s, accuracy=0.06, loss=4.1, lr=0.001]
Epoch 5, test Loss: 4.10424919128418, test Accuracy: 0.05960000306367874
100%|██████████| 50/50 [00:01<00:00, 25.57it/s, accuracy=0.062, loss=4.01, lr=0.001]
Epoch 6, test Loss: 4.006009817123413, test Accuracy: 0.06800000295042992
100%|██████████| 50/50 [00:01<00:00, 25.60it/s, accuracy=0.073, loss=3.88, lr=0.001]
Epoch 7, test Loss: 3.920482850074768, test Accuracy: 0.08030000403523445
100%|██████████| 50/50 [00:02<00:00, 24.94it/s, accuracy=0.089, loss=3.81, lr=0.001]
Epoch 8, test Loss: 3.8535669088363647, test Accuracy: 0.08670000433921814
100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.093, loss=3.77, lr=0.001]
Epoch 9, test Loss: 3.811629557609558, test Accuracy: 0.09450000450015068
100%|██████████| 50/50 [00:01<00:00, 25.62it/s, accuracy=0.109, loss=3.75, lr=0.001]
Epoch 10, test Loss: 3.7596314907073975, test Accuracy: 0.10250000432133674
100%|██████████| 50/50 [00:02<00:00, 24.91it/s, accuracy=0.108, loss=3.71, lr=0.001]
Epoch 11, test Loss: 3.7226871490478515, test Accuracy: 0.10690000578761101
100%|██████████| 50/50 [00:01<00:00, 25.58it/s, accuracy=0.122, loss=3.65, lr=0.001]
Epoch 12, test Loss: 3.6761600494384767, test Accuracy: 0.11960000693798065
100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.126, loss=3.59, lr=0.001]
Epoch 13, test Loss: 3.6474289178848265, test Accuracy: 0.12440000697970391
100%|██████████| 50/50 [00:01<00:00, 25.59it/s, accuracy=0.139, loss=3.53, lr=0.001]
Epoch 14, test Loss: 3.6189677715301514, test Accuracy: 0.1307000070810318
100%|██████████| 50/50 [00:02<00:00, 24.91it/s, accuracy=0.126, loss=3.55, lr=0.001]
Epoch 15, test Loss: 3.5796807527542116, test Accuracy: 0.13510000705718994
100%|██████████| 50/50 [00:01<00:00, 25.62it/s, accuracy=0.158, loss=3.47, lr=0.001]
Epoch 16, test Loss: 3.5481101989746096, test Accuracy: 0.1361000046133995
100%|██████████| 50/50 [00:01<00:00, 25.60it/s, accuracy=0.182, loss=3.39, lr=0.001]
Epoch 17, test Loss: 3.531752014160156, test Accuracy: 0.1390000082552433
100%|██████████| 50/50 [00:02<00:00, 24.91it/s, accuracy=0.168, loss=3.4, lr=0.001]
Epoch 18, test Loss: 3.5055885791778563, test Accuracy: 0.1440000057220459
100%|██████████| 50/50 [00:01<00:00, 25.62it/s, accuracy=0.157, loss=3.36, lr=0.001]
Epoch 19, test Loss: 3.4795714139938356, test Accuracy: 0.14710000902414322
100%|██████████| 50/50 [00:01<00:00, 25.55it/s, accuracy=0.164, loss=3.34, lr=0.001]
Epoch 20, test Loss: 3.441032814979553, test Accuracy: 0.15520000606775283
100%|██████████| 50/50 [00:01<00:00, 25.59it/s, accuracy=0.185, loss=3.31, lr=0.001]
Epoch 21, test Loss: 3.4175992250442504, test Accuracy: 0.16120000928640366
100%|██████████| 50/50 [00:01<00:00, 25.20it/s, accuracy=0.192, loss=3.24, lr=0.001]
Epoch 22, test Loss: 3.3915664196014403, test Accuracy: 0.16930000633001327
100%|██████████| 50/50 [00:01<00:00, 25.91it/s, accuracy=0.183, loss=3.25, lr=0.001]
Epoch 23, test Loss: 3.4428197860717775, test Accuracy: 0.15350000709295272
100%|██████████| 50/50 [00:01<00:00, 25.95it/s, accuracy=0.213, loss=3.24, lr=0.001]
Epoch 24, test Loss: 3.362662839889526, test Accuracy: 0.16850000768899917
100%|██████████| 50/50 [00:01<00:00, 25.18it/s, accuracy=0.195, loss=3.23, lr=0.001]
Epoch 25, test Loss: 3.355527567863464, test Accuracy: 0.1708000048995018
100%|██████████| 50/50 [00:01<00:00, 25.54it/s, accuracy=0.196, loss=3.18, lr=0.001]
Epoch 26, test Loss: 3.358353543281555, test Accuracy: 0.1723000094294548
100%|██████████| 50/50 [00:01<00:00, 25.51it/s, accuracy=0.194, loss=3.22, lr=0.001]
Epoch 27, test Loss: 3.3171217918395994, test Accuracy: 0.1827000081539154
100%|██████████| 50/50 [00:01<00:00, 25.53it/s, accuracy=0.205, loss=3.12, lr=0.001]
Epoch 28, test Loss: 3.350891280174255, test Accuracy: 0.17330000698566436
100%|██████████| 50/50 [00:02<00:00, 24.87it/s, accuracy=0.213, loss=3.14, lr=0.001]
Epoch 29, test Loss: 3.2914602756500244, test Accuracy: 0.18220000863075256
100%|██████████| 50/50 [00:01<00:00, 25.57it/s, accuracy=0.196, loss=3.16, lr=0.001]
Epoch 30, test Loss: 3.2760613441467283, test Accuracy: 0.18880000859498977
100%|██████████| 50/50 [00:01<00:00, 25.55it/s, accuracy=0.212, loss=3.11, lr=0.001]
Epoch 31, test Loss: 3.2763137340545656, test Accuracy: 0.18600000739097594
100%|██████████| 50/50 [00:02<00:00, 24.93it/s, accuracy=0.237, loss=3.05, lr=0.001]
Epoch 32, test Loss: 3.269388508796692, test Accuracy: 0.1886000081896782
100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.228, loss=3, lr=0.001]
Epoch 33, test Loss: 3.280174207687378, test Accuracy: 0.18370001018047333
100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.234, loss=3.08, lr=0.001]
Epoch 34, test Loss: 3.247648596763611, test Accuracy: 0.19470000714063646
100%|██████████| 50/50 [00:01<00:00, 25.58it/s, accuracy=0.206, loss=3.11, lr=0.001]
Epoch 35, test Loss: 3.236310625076294, test Accuracy: 0.2018000066280365
100%|██████████| 50/50 [00:02<00:00, 24.92it/s, accuracy=0.219, loss=3.04, lr=0.001]
Epoch 36, test Loss: 3.2433222055435182, test Accuracy: 0.19420000910758972
100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.223, loss=3.05, lr=0.001]
Epoch 37, test Loss: 3.2257045030593874, test Accuracy: 0.19530000984668733
100%|██████████| 50/50 [00:01<00:00, 25.63it/s, accuracy=0.263, loss=2.96, lr=0.001]
Epoch 38, test Loss: 3.221918821334839, test Accuracy: 0.20050000846385957
100%|██████████| 50/50 [00:02<00:00, 24.86it/s, accuracy=0.254, loss=2.91, lr=0.001]
Epoch 39, test Loss: 3.230287194252014, test Accuracy: 0.20160000771284103
100%|██████████| 50/50 [00:01<00:00, 25.57it/s, accuracy=0.245, loss=2.96, lr=0.001]
Epoch 40, test Loss: 3.236723041534424, test Accuracy: 0.19820000976324081
100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.237, loss=2.97, lr=0.001]
Epoch 41, test Loss: 3.231258010864258, test Accuracy: 0.20250001102685927
100%|██████████| 50/50 [00:01<00:00, 25.60it/s, accuracy=0.228, loss=2.95, lr=0.001]
Epoch 42, test Loss: 3.220815086364746, test Accuracy: 0.20580001026391984
100%|██████████| 50/50 [00:02<00:00, 24.80it/s, accuracy=0.246, loss=2.87, lr=0.001]
Epoch 43, test Loss: 3.2338695764541625, test Accuracy: 0.2041000097990036
100%|██████████| 50/50 [00:01<00:00, 25.59it/s, accuracy=0.262, loss=2.88, lr=0.001]
Epoch 44, test Loss: 3.2135489225387572, test Accuracy: 0.20290001034736632
100%|██████████| 50/50 [00:01<00:00, 25.59it/s, accuracy=0.274, loss=2.84, lr=0.001]
Epoch 45, test Loss: 3.204585647583008, test Accuracy: 0.20940001159906388
100%|██████████| 50/50 [00:02<00:00, 24.89it/s, accuracy=0.269, loss=2.83, lr=0.001]
Epoch 46, test Loss: 3.1972398281097414, test Accuracy: 0.21000000983476638
100%|██████████| 50/50 [00:01<00:00, 25.63it/s, accuracy=0.236, loss=2.92, lr=0.001]
Epoch 47, test Loss: 3.2239073276519776, test Accuracy: 0.20540000945329667
100%|██████████| 50/50 [00:01<00:00, 25.55it/s, accuracy=0.243, loss=2.93, lr=0.001]
Epoch 48, test Loss: 3.1972729444503782, test Accuracy: 0.20940001010894777
100%|██████████| 50/50 [00:01<00:00, 25.62it/s, accuracy=0.255, loss=2.86, lr=0.001]
Epoch 49, test Loss: 3.210345673561096, test Accuracy: 0.20560001134872435
100%|██████████| 50/50 [00:02<00:00, 24.92it/s, accuracy=0.262, loss=2.77, lr=0.001]
Epoch 50, test Loss: 3.1845874547958375, test Accuracy: 0.21320001184940338
[8]:
print("Top-1 Accuracy: ", max(test_accs))
print("Top-5 Accuracy: ", max(test_top5_accs))
print("CNN+HQKAN Model size:", param:=len(torch.concat([param.cpu().detach().flatten() for param in model.parameters() if param.requires_grad])))
cnn = CNet(device="cpu")
cnn_param = len(torch.concat([param.cpu().detach().flatten() for param in cnn.parameters() if param.requires_grad]))
del cnn
print("CNN part size:", cnn_param)
print("HQKAN part size:", param - cnn_param)
Top-1 Accuracy: 0.21320001184940338
Top-5 Accuracy: 0.5048000246286393
CNN+HQKAN Model size: 59624
CNN part size: 56320
HQKAN part size: 3304
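The HQKAN head size can be reproduced by hand from the layer shapes (assuming, as inferred earlier, 8 coefficients per QKAN input–output pair):
# Linear(256 -> 8):  256*8 + 8   = 2056
# QKAN([8, 7]):      8*7*8       = 448
# Linear(7 -> 100):  7*100 + 100 = 800
print(2056 + 448 + 800)  # 3304, matching the "HQKAN part size" printed above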
To enhance model capacity, we quadruple the input and output channels of the QKAN core. This enables learning of more complex feature representations while the hybrid bottleneck design keeps the parameter growth modest.
[9]:
# Quadruple the latent widths: 4 * ceil(log2(256)) = 32 and 4 * ceil(log2(100)) = 28
in_resize = np.ceil(np.log2(in_feat)).astype(int) * 4
out_resize = np.ceil(np.log2(out_feat)).astype(int) * 4
model = nn.Sequential(
CNet(device=device),
nn.Linear(in_feat, in_resize, device=device),
QKAN([in_resize, out_resize], device=device),
nn.Linear(out_resize, out_feat, device=device),
)
criterion = nn.CrossEntropyLoss()
accs = []
losses = []
test_accs = []
test_losses = []
test_top5_accs = []
optimizer = optim.Adam(model.parameters(), lr=1e-3)
[10]:
train(model, optimizer)
100%|██████████| 50/50 [00:02<00:00, 23.65it/s, accuracy=0.04, loss=4.37, lr=0.001]
Epoch 1, test Loss: 4.359431600570678, test Accuracy: 0.04140000194311142
100%|██████████| 50/50 [00:02<00:00, 23.53it/s, accuracy=0.055, loss=4.03, lr=0.001]
Epoch 2, test Loss: 4.007006192207337, test Accuracy: 0.08220000267028808
100%|██████████| 50/50 [00:02<00:00, 23.57it/s, accuracy=0.107, loss=3.87, lr=0.001]
Epoch 3, test Loss: 3.803478169441223, test Accuracy: 0.11150000542402268
100%|██████████| 50/50 [00:02<00:00, 23.61it/s, accuracy=0.149, loss=3.56, lr=0.001]
Epoch 4, test Loss: 3.689195442199707, test Accuracy: 0.13470000326633452
100%|██████████| 50/50 [00:02<00:00, 23.03it/s, accuracy=0.145, loss=3.6, lr=0.001]
Epoch 5, test Loss: 3.5530738115310667, test Accuracy: 0.15630000829696655
100%|██████████| 50/50 [00:02<00:00, 23.73it/s, accuracy=0.159, loss=3.45, lr=0.001]
Epoch 6, test Loss: 3.47371768951416, test Accuracy: 0.1645000085234642
100%|██████████| 50/50 [00:02<00:00, 23.71it/s, accuracy=0.194, loss=3.31, lr=0.001]
Epoch 7, test Loss: 3.369604206085205, test Accuracy: 0.18320000916719437
100%|██████████| 50/50 [00:02<00:00, 23.06it/s, accuracy=0.198, loss=3.27, lr=0.001]
Epoch 8, test Loss: 3.2939977645874023, test Accuracy: 0.1992000088095665
100%|██████████| 50/50 [00:02<00:00, 23.67it/s, accuracy=0.209, loss=3.19, lr=0.001]
Epoch 9, test Loss: 3.2363681554794312, test Accuracy: 0.20680001080036164
100%|██████████| 50/50 [00:02<00:00, 23.72it/s, accuracy=0.225, loss=3.19, lr=0.001]
Epoch 10, test Loss: 3.1699442625045777, test Accuracy: 0.21730001121759415
100%|██████████| 50/50 [00:02<00:00, 23.65it/s, accuracy=0.239, loss=3.1, lr=0.001]
Epoch 11, test Loss: 3.124575066566467, test Accuracy: 0.22790001183748246
100%|██████████| 50/50 [00:02<00:00, 23.10it/s, accuracy=0.254, loss=3.03, lr=0.001]
Epoch 12, test Loss: 3.0657344341278074, test Accuracy: 0.2417000100016594
100%|██████████| 50/50 [00:02<00:00, 23.64it/s, accuracy=0.235, loss=2.99, lr=0.001]
Epoch 13, test Loss: 3.0330076456069945, test Accuracy: 0.24390001147985457
100%|██████████| 50/50 [00:02<00:00, 23.69it/s, accuracy=0.267, loss=2.93, lr=0.001]
Epoch 14, test Loss: 2.970594549179077, test Accuracy: 0.25330001413822173
100%|██████████| 50/50 [00:02<00:00, 22.92it/s, accuracy=0.267, loss=2.87, lr=0.001]
Epoch 15, test Loss: 2.9353652477264403, test Accuracy: 0.26220000833272933
100%|██████████| 50/50 [00:02<00:00, 23.65it/s, accuracy=0.297, loss=2.79, lr=0.001]
Epoch 16, test Loss: 2.929064393043518, test Accuracy: 0.26770001500844953
100%|██████████| 50/50 [00:02<00:00, 23.63it/s, accuracy=0.332, loss=2.71, lr=0.001]
Epoch 17, test Loss: 2.878641438484192, test Accuracy: 0.2755000114440918
100%|██████████| 50/50 [00:02<00:00, 23.68it/s, accuracy=0.315, loss=2.7, lr=0.001]
Epoch 18, test Loss: 2.865756464004517, test Accuracy: 0.2796000182628632
100%|██████████| 50/50 [00:02<00:00, 23.00it/s, accuracy=0.292, loss=2.73, lr=0.001]
Epoch 19, test Loss: 2.8631765127182005, test Accuracy: 0.2791000097990036
100%|██████████| 50/50 [00:02<00:00, 23.68it/s, accuracy=0.299, loss=2.71, lr=0.001]
Epoch 20, test Loss: 2.8106075525283813, test Accuracy: 0.28650001883506776
100%|██████████| 50/50 [00:02<00:00, 23.81it/s, accuracy=0.319, loss=2.65, lr=0.001]
Epoch 21, test Loss: 2.806682085990906, test Accuracy: 0.28940001130104065
100%|██████████| 50/50 [00:02<00:00, 23.12it/s, accuracy=0.335, loss=2.59, lr=0.001]
Epoch 22, test Loss: 2.7785612821578978, test Accuracy: 0.2944000124931335
100%|██████████| 50/50 [00:02<00:00, 23.69it/s, accuracy=0.33, loss=2.55, lr=0.001]
Epoch 23, test Loss: 2.745818758010864, test Accuracy: 0.3020000159740448
100%|██████████| 50/50 [00:02<00:00, 23.60it/s, accuracy=0.331, loss=2.55, lr=0.001]
Epoch 24, test Loss: 2.7335641622543334, test Accuracy: 0.30610001385211943
100%|██████████| 50/50 [00:02<00:00, 23.52it/s, accuracy=0.327, loss=2.57, lr=0.001]
Epoch 25, test Loss: 2.7024224758148194, test Accuracy: 0.31030001640319826
100%|██████████| 50/50 [00:02<00:00, 23.07it/s, accuracy=0.359, loss=2.47, lr=0.001]
Epoch 26, test Loss: 2.6971447467803955, test Accuracy: 0.3104000151157379
100%|██████████| 50/50 [00:02<00:00, 23.47it/s, accuracy=0.351, loss=2.46, lr=0.001]
Epoch 27, test Loss: 2.6772998094558718, test Accuracy: 0.31400001645088194
100%|██████████| 50/50 [00:02<00:00, 23.52it/s, accuracy=0.387, loss=2.35, lr=0.001]
Epoch 28, test Loss: 2.66512246131897, test Accuracy: 0.32040001153945924
100%|██████████| 50/50 [00:02<00:00, 23.01it/s, accuracy=0.382, loss=2.42, lr=0.001]
Epoch 29, test Loss: 2.667982268333435, test Accuracy: 0.3170000106096268
100%|██████████| 50/50 [00:02<00:00, 23.75it/s, accuracy=0.369, loss=2.41, lr=0.001]
Epoch 30, test Loss: 2.629500079154968, test Accuracy: 0.325500014424324
100%|██████████| 50/50 [00:02<00:00, 23.74it/s, accuracy=0.394, loss=2.38, lr=0.001]
Epoch 31, test Loss: 2.6221821546554565, test Accuracy: 0.3290000170469284
100%|██████████| 50/50 [00:02<00:00, 23.73it/s, accuracy=0.412, loss=2.31, lr=0.001]
Epoch 32, test Loss: 2.6159011363983153, test Accuracy: 0.32920001745224
100%|██████████| 50/50 [00:02<00:00, 23.05it/s, accuracy=0.403, loss=2.38, lr=0.001]
Epoch 33, test Loss: 2.601720857620239, test Accuracy: 0.33770001232624053
100%|██████████| 50/50 [00:02<00:00, 23.67it/s, accuracy=0.392, loss=2.32, lr=0.001]
Epoch 34, test Loss: 2.603303813934326, test Accuracy: 0.33710001707077025
100%|██████████| 50/50 [00:02<00:00, 23.66it/s, accuracy=0.398, loss=2.38, lr=0.001]
Epoch 35, test Loss: 2.6415931463241575, test Accuracy: 0.3316000163555145
100%|██████████| 50/50 [00:02<00:00, 23.13it/s, accuracy=0.37, loss=2.33, lr=0.001]
Epoch 36, test Loss: 2.5769686698913574, test Accuracy: 0.3406000167131424
100%|██████████| 50/50 [00:02<00:00, 23.67it/s, accuracy=0.384, loss=2.36, lr=0.001]
Epoch 37, test Loss: 2.5785749912261964, test Accuracy: 0.3406000167131424
100%|██████████| 50/50 [00:02<00:00, 23.70it/s, accuracy=0.405, loss=2.31, lr=0.001]
Epoch 38, test Loss: 2.596335530281067, test Accuracy: 0.33610001504421233
100%|██████████| 50/50 [00:02<00:00, 23.64it/s, accuracy=0.434, loss=2.14, lr=0.001]
Epoch 39, test Loss: 2.573742628097534, test Accuracy: 0.34020001590251925
100%|██████████| 50/50 [00:02<00:00, 23.05it/s, accuracy=0.419, loss=2.22, lr=0.001]
Epoch 40, test Loss: 2.577555251121521, test Accuracy: 0.3434000164270401
100%|██████████| 50/50 [00:02<00:00, 23.61it/s, accuracy=0.408, loss=2.27, lr=0.001]
Epoch 41, test Loss: 2.5758074045181276, test Accuracy: 0.33860001564025877
100%|██████████| 50/50 [00:02<00:00, 23.69it/s, accuracy=0.413, loss=2.2, lr=0.001]
Epoch 42, test Loss: 2.564303970336914, test Accuracy: 0.34670001566410064
100%|██████████| 50/50 [00:02<00:00, 23.11it/s, accuracy=0.401, loss=2.24, lr=0.001]
Epoch 43, test Loss: 2.563422918319702, test Accuracy: 0.3433000147342682
100%|██████████| 50/50 [00:02<00:00, 23.69it/s, accuracy=0.449, loss=2.14, lr=0.001]
Epoch 44, test Loss: 2.57224817276001, test Accuracy: 0.3441000163555145
100%|██████████| 50/50 [00:02<00:00, 23.69it/s, accuracy=0.44, loss=2.08, lr=0.001]
Epoch 45, test Loss: 2.5427053689956667, test Accuracy: 0.35400002002716063
100%|██████████| 50/50 [00:02<00:00, 23.76it/s, accuracy=0.437, loss=2.13, lr=0.001]
Epoch 46, test Loss: 2.564718508720398, test Accuracy: 0.3476000130176544
100%|██████████| 50/50 [00:02<00:00, 22.99it/s, accuracy=0.395, loss=2.21, lr=0.001]
Epoch 47, test Loss: 2.57286651134491, test Accuracy: 0.35210002064704893
100%|██████████| 50/50 [00:02<00:00, 23.53it/s, accuracy=0.423, loss=2.16, lr=0.001]
Epoch 48, test Loss: 2.555943727493286, test Accuracy: 0.34780001938343047
100%|██████████| 50/50 [00:02<00:00, 23.55it/s, accuracy=0.454, loss=2.12, lr=0.001]
Epoch 49, test Loss: 2.524978685379028, test Accuracy: 0.36210001111030576
100%|██████████| 50/50 [00:02<00:00, 22.87it/s, accuracy=0.437, loss=2.07, lr=0.001]
Epoch 50, test Loss: 2.5304693937301637, test Accuracy: 0.35640001893043516
[11]:
print("Top-1 Accuracy: ", max(test_accs))
print("Top-5 Accuracy: ", max(test_top5_accs))
print("CNN+HQKAN-44 Model size:", param:=len(torch.concat([param.cpu().detach().flatten() for param in model.parameters() if param.requires_grad])))
print("CNN part size:", cnn_param)
print("HQKAN-44 part size:", param - cnn_param)
Top-1 Accuracy: 0.36210001111030576
Top-5 Accuracy: 0.6717000305652618
CNN+HQKAN-44 Model size: 74612
CNN part size: 56320
HQKAN-44 part size: 18292
Comparison with an MLP
[12]:
class CNN(nn.Module):
def __init__(self, device):
super(CNN, self).__init__()
self.cnet = CNet(device)
self.fc1 = nn.Linear(256, 192).to(device)
self.fc2 = nn.Linear(192, 128).to(device)
self.fc3 = nn.Linear(128, 100).to(device)
self.device = device
def forward(self, x):
x = x.to(self.device)
x = self.cnet(x)
x = x.view(x.size(0), -1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return x
model = CNN(device=device)
criterion = nn.CrossEntropyLoss()
accs = []
losses = []
test_accs = []
test_losses = []
test_top5_accs = []
optimizer = optim.Adam(model.parameters(), lr=1e-3)
[13]:
train(model, optimizer)
100%|██████████| 50/50 [00:01<00:00, 26.49it/s, accuracy=0.073, loss=4.02, lr=0.001]
Epoch 1, test Loss: 4.0597692966461185, test Accuracy: 0.07460000440478325
100%|██████████| 50/50 [00:01<00:00, 26.60it/s, accuracy=0.085, loss=3.85, lr=0.001]
Epoch 2, test Loss: 3.8250439167022705, test Accuracy: 0.10800000503659249
100%|██████████| 50/50 [00:01<00:00, 26.46it/s, accuracy=0.127, loss=3.7, lr=0.001]
Epoch 3, test Loss: 3.6398918867111205, test Accuracy: 0.14570000618696213
100%|██████████| 50/50 [00:01<00:00, 26.45it/s, accuracy=0.171, loss=3.47, lr=0.001]
Epoch 4, test Loss: 3.501902627944946, test Accuracy: 0.174700003862381
100%|██████████| 50/50 [00:01<00:00, 26.18it/s, accuracy=0.217, loss=3.27, lr=0.001]
Epoch 5, test Loss: 3.39032986164093, test Accuracy: 0.18960001021623613
100%|██████████| 50/50 [00:01<00:00, 25.66it/s, accuracy=0.205, loss=3.23, lr=0.001]
Epoch 6, test Loss: 3.329215979576111, test Accuracy: 0.20160000920295715
100%|██████████| 50/50 [00:01<00:00, 26.52it/s, accuracy=0.208, loss=3.22, lr=0.001]
Epoch 7, test Loss: 3.2377456426620483, test Accuracy: 0.21970001012086868
100%|██████████| 50/50 [00:01<00:00, 26.28it/s, accuracy=0.256, loss=3.09, lr=0.001]
Epoch 8, test Loss: 3.1636680603027343, test Accuracy: 0.22900001108646392
100%|██████████| 50/50 [00:01<00:00, 26.42it/s, accuracy=0.243, loss=3.04, lr=0.001]
Epoch 9, test Loss: 3.1480681180953978, test Accuracy: 0.2313000112771988
100%|██████████| 50/50 [00:01<00:00, 26.44it/s, accuracy=0.244, loss=3.1, lr=0.001]
Epoch 10, test Loss: 3.0826433897018433, test Accuracy: 0.2523000121116638
100%|██████████| 50/50 [00:01<00:00, 26.44it/s, accuracy=0.268, loss=3.03, lr=0.001]
Epoch 11, test Loss: 3.0320362567901613, test Accuracy: 0.25540001392364503
100%|██████████| 50/50 [00:01<00:00, 26.39it/s, accuracy=0.271, loss=2.9, lr=0.001]
Epoch 12, test Loss: 3.0019490718841553, test Accuracy: 0.26240001022815707
100%|██████████| 50/50 [00:01<00:00, 25.59it/s, accuracy=0.274, loss=2.93, lr=0.001]
Epoch 13, test Loss: 2.9862070083618164, test Accuracy: 0.2681000143289566
100%|██████████| 50/50 [00:01<00:00, 26.41it/s, accuracy=0.296, loss=2.82, lr=0.001]
Epoch 14, test Loss: 2.9344271421432495, test Accuracy: 0.28090001046657564
100%|██████████| 50/50 [00:01<00:00, 26.28it/s, accuracy=0.297, loss=2.84, lr=0.001]
Epoch 15, test Loss: 2.9287622928619386, test Accuracy: 0.27840001285076144
100%|██████████| 50/50 [00:01<00:00, 26.18it/s, accuracy=0.304, loss=2.72, lr=0.001]
Epoch 16, test Loss: 2.8981085062026977, test Accuracy: 0.2864000201225281
100%|██████████| 50/50 [00:01<00:00, 26.42it/s, accuracy=0.292, loss=2.8, lr=0.001]
Epoch 17, test Loss: 2.864175868034363, test Accuracy: 0.2898000180721283
100%|██████████| 50/50 [00:01<00:00, 26.48it/s, accuracy=0.305, loss=2.71, lr=0.001]
Epoch 18, test Loss: 2.830798387527466, test Accuracy: 0.2998000144958496
100%|██████████| 50/50 [00:01<00:00, 26.43it/s, accuracy=0.321, loss=2.6, lr=0.001]
Epoch 19, test Loss: 2.8063244581222535, test Accuracy: 0.30420001447200773
100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.319, loss=2.68, lr=0.001]
Epoch 20, test Loss: 2.8006964921951294, test Accuracy: 0.3048000156879425
100%|██████████| 50/50 [00:01<00:00, 26.47it/s, accuracy=0.328, loss=2.66, lr=0.001]
Epoch 21, test Loss: 2.769916844367981, test Accuracy: 0.3151000142097473
100%|██████████| 50/50 [00:01<00:00, 26.33it/s, accuracy=0.312, loss=2.62, lr=0.001]
Epoch 22, test Loss: 2.7571829080581667, test Accuracy: 0.3123000174760818
100%|██████████| 50/50 [00:01<00:00, 26.52it/s, accuracy=0.36, loss=2.53, lr=0.001]
Epoch 23, test Loss: 2.733232855796814, test Accuracy: 0.31750001311302184
100%|██████████| 50/50 [00:01<00:00, 26.35it/s, accuracy=0.382, loss=2.41, lr=0.001]
Epoch 24, test Loss: 2.7194632291793823, test Accuracy: 0.3200000137090683
100%|██████████| 50/50 [00:01<00:00, 26.36it/s, accuracy=0.377, loss=2.47, lr=0.001]
Epoch 25, test Loss: 2.756288003921509, test Accuracy: 0.3152000159025192
100%|██████████| 50/50 [00:01<00:00, 26.49it/s, accuracy=0.361, loss=2.44, lr=0.001]
Epoch 26, test Loss: 2.691506338119507, test Accuracy: 0.3308000206947327
100%|██████████| 50/50 [00:01<00:00, 25.68it/s, accuracy=0.369, loss=2.46, lr=0.001]
Epoch 27, test Loss: 2.689697813987732, test Accuracy: 0.3318000167608261
100%|██████████| 50/50 [00:01<00:00, 26.42it/s, accuracy=0.371, loss=2.39, lr=0.001]
Epoch 28, test Loss: 2.6991366863250734, test Accuracy: 0.32980001270771026
100%|██████████| 50/50 [00:01<00:00, 26.22it/s, accuracy=0.383, loss=2.44, lr=0.001]
Epoch 29, test Loss: 2.647859072685242, test Accuracy: 0.33790001571178435
100%|██████████| 50/50 [00:01<00:00, 26.53it/s, accuracy=0.376, loss=2.41, lr=0.001]
Epoch 30, test Loss: 2.6324229001998902, test Accuracy: 0.3419000118970871
100%|██████████| 50/50 [00:01<00:00, 26.43it/s, accuracy=0.413, loss=2.25, lr=0.001]
Epoch 31, test Loss: 2.6560914516448975, test Accuracy: 0.3348000138998032
100%|██████████| 50/50 [00:01<00:00, 26.35it/s, accuracy=0.409, loss=2.32, lr=0.001]
Epoch 32, test Loss: 2.624714732170105, test Accuracy: 0.3481000155210495
100%|██████████| 50/50 [00:01<00:00, 26.44it/s, accuracy=0.358, loss=2.41, lr=0.001]
Epoch 33, test Loss: 2.610011410713196, test Accuracy: 0.344700014591217
100%|██████████| 50/50 [00:01<00:00, 25.78it/s, accuracy=0.412, loss=2.26, lr=0.001]
Epoch 34, test Loss: 2.5931995153427123, test Accuracy: 0.3505000174045563
100%|██████████| 50/50 [00:01<00:00, 26.40it/s, accuracy=0.406, loss=2.29, lr=0.001]
Epoch 35, test Loss: 2.5763473510742188, test Accuracy: 0.3550000131130219
100%|██████████| 50/50 [00:01<00:00, 26.44it/s, accuracy=0.416, loss=2.22, lr=0.001]
Epoch 36, test Loss: 2.576857900619507, test Accuracy: 0.3531000167131424
100%|██████████| 50/50 [00:01<00:00, 26.52it/s, accuracy=0.42, loss=2.17, lr=0.001]
Epoch 37, test Loss: 2.5906506299972536, test Accuracy: 0.35740001797676085
100%|██████████| 50/50 [00:01<00:00, 26.44it/s, accuracy=0.416, loss=2.22, lr=0.001]
Epoch 38, test Loss: 2.569392275810242, test Accuracy: 0.3538000226020813
100%|██████████| 50/50 [00:01<00:00, 26.45it/s, accuracy=0.438, loss=2.08, lr=0.001]
Epoch 39, test Loss: 2.5566878080368043, test Accuracy: 0.35710001587867735
100%|██████████| 50/50 [00:01<00:00, 26.54it/s, accuracy=0.416, loss=2.19, lr=0.001]
Epoch 40, test Loss: 2.556428050994873, test Accuracy: 0.3639000207185745
100%|██████████| 50/50 [00:01<00:00, 25.68it/s, accuracy=0.434, loss=2.13, lr=0.001]
Epoch 41, test Loss: 2.542978119850159, test Accuracy: 0.36690001785755155
100%|██████████| 50/50 [00:01<00:00, 26.45it/s, accuracy=0.445, loss=2.12, lr=0.001]
Epoch 42, test Loss: 2.5442334413528442, test Accuracy: 0.362700018286705
100%|██████████| 50/50 [00:01<00:00, 26.43it/s, accuracy=0.431, loss=2.17, lr=0.001]
Epoch 43, test Loss: 2.5535474538803102, test Accuracy: 0.3637000173330307
100%|██████████| 50/50 [00:01<00:00, 26.48it/s, accuracy=0.445, loss=2.1, lr=0.001]
Epoch 44, test Loss: 2.530148911476135, test Accuracy: 0.3738000154495239
100%|██████████| 50/50 [00:01<00:00, 26.53it/s, accuracy=0.466, loss=2.03, lr=0.001]
Epoch 45, test Loss: 2.527567744255066, test Accuracy: 0.36940001845359804
100%|██████████| 50/50 [00:01<00:00, 26.55it/s, accuracy=0.465, loss=2.01, lr=0.001]
Epoch 46, test Loss: 2.5381455183029176, test Accuracy: 0.37100001573562624
100%|██████████| 50/50 [00:01<00:00, 26.53it/s, accuracy=0.488, loss=1.99, lr=0.001]
Epoch 47, test Loss: 2.532746434211731, test Accuracy: 0.3675000160932541
100%|██████████| 50/50 [00:01<00:00, 25.78it/s, accuracy=0.494, loss=1.91, lr=0.001]
Epoch 48, test Loss: 2.5316006898880006, test Accuracy: 0.37280001640319826
100%|██████████| 50/50 [00:01<00:00, 26.48it/s, accuracy=0.473, loss=1.96, lr=0.001]
Epoch 49, test Loss: 2.561800479888916, test Accuracy: 0.3666000127792358
100%|██████████| 50/50 [00:01<00:00, 26.64it/s, accuracy=0.461, loss=1.99, lr=0.001]
Epoch 50, test Loss: 2.515561318397522, test Accuracy: 0.3707000195980072
[14]:
print("Top-1 Accuracy: ", max(test_accs))
print("Top-5 Accuracy: ", max(test_top5_accs))
print("CNN+MLP Model size:", param:=len(torch.concat([param.cpu().detach().flatten() for param in model.parameters() if param.requires_grad])))
print("CNN part size:", cnn_param)
print("MLP part size:", param - cnn_param)
Top-1 Accuracy: 0.3738000154495239
Top-5 Accuracy: 0.6787000298500061
CNN+MLP Model size: 143268
CNN part size: 56320
MLP part size: 86948
Results
HQKAN-44 reaches accuracy comparable to the MLP baseline (36.2% vs. 37.4% top-1; 67.2% vs. 67.9% top-5).
Its head uses roughly 21% of the parameters of the MLP head (18,292 vs. 86,948), about 5× fewer; see the check below.
Training time is also comparable to the MLP, and significantly faster than typical quantum machine learning (QML) models run on classical simulators.
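As a sanity check on the figures above (same head-size arithmetic as before; the 8 coefficients per QKAN input–output pair remain an inference from the earlier outputs):
hqkan44_head = (256 * 32 + 32) + 32 * 28 * 8 + (28 * 100 + 100)  # encoder + QKAN([32, 28]) + decoder
mlp_head = (256 * 192 + 192) + (192 * 128 + 128) + (128 * 100 + 100)  # the three Linear layers of CNN
print(hqkan44_head, mlp_head, round(mlp_head / hqkan44_head, 2))  # 18292 86948 4.75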
Summary
Hybrid QKAN offers a scalable and efficient alternative to traditional MLPs and other QML models:
✔ Comparable accuracy
✔ ~5× fewer parameters in the classification head
✔ Faster training than typical QML approaches
This demonstrates HQKAN’s potential as a practical and effective model architecture for modern deep learning tasks.