{ "cells": [ { "cell_type": "markdown", "id": "77b52b46", "metadata": {}, "source": [ "# Hybrid QKAN on CIFAR-100\n", "\n", "In this example, we demonstrate the performance and scalability of the **Hybrid QKAN (HQKAN)** architecture when applied to a large-scale model using the **CIFAR-100** dataset.\n", "\n", "## Motivation\n", "\n", "Due to the architectural design of Kolmogorov–Arnold Networks (KAN), the number of parameters in a layer scales with the **product** of its input and output channel counts, so it grows **quadratically** when both widths are scaled together. This presents a scalability challenge for large models.\n", "\n", "Here is how the size increases when the output width goes from 10 to 100:" ] }, { "cell_type": "code", "execution_count": 1, "id": "0043e4d6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model 1 with shape [10, 10] has # of parameters: 800\n", "Model 2 with shape [10, 100] has # of parameters: 8000\n" ] } ], "source": [ "from qkan import QKAN\n", "\n", "model_1 = QKAN([10, 10])\n", "model_2 = QKAN([10, 100])\n", "\n", "print(\"Model 1 with shape [10, 10] has # of parameters:\", model_1.param_size)\n", "print(\"Model 2 with shape [10, 100] has # of parameters:\", model_2.param_size)\n", "del model_1, model_2" ] }, { "cell_type": "markdown", "id": "2bf2a7b5", "metadata": {}, "source": [ "To address this, we propose a **hybrid architecture** that integrates two fully connected networks (FCNs) to **compress** and **reconstruct** the feature representation. 
This forms an **autoencoder-like bottleneck structure** around the QKAN core.\n", "\n", "We refer to this approach as **Hybrid QKAN (HQKAN)**.\n", "\n", "## Architecture Overview\n", "\n", "- **Encoder FCN** compresses the high-dimensional input features.\n", "- **QKAN Core** processes the reduced latent representation.\n", "- **Decoder FCN** expands the QKAN output to the target output dimensionality (here, the 100 class logits).\n", "- This structure allows us to benefit from QKAN's expressiveness while controlling parameter growth.\n", "\n", "## Experiment: CIFAR-100\n", "\n", "We apply HQKAN to the CIFAR-100 dataset to evaluate its performance and efficiency." ] }, { "cell_type": "code", "execution_count": 2, "id": "b5e34c31", "metadata": {}, "outputs": [], "source": [ "import random\n", "import sys\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "from torch.utils.data import DataLoader\n", "from tqdm import tqdm\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "26b3be04", "metadata": {}, "outputs": [], "source": [ "class CNet(nn.Module):\n", "    def __init__(self, device):\n", "        super(CNet, self).__init__()\n", "\n", "        self.device = device\n", "\n", "        self.cnn1 = nn.Conv2d(\n", "            in_channels=3, out_channels=32, kernel_size=3, device=device\n", "        )\n", "        self.relu = nn.ReLU()\n", "        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n", "        self.cnn2 = nn.Conv2d(\n", "            in_channels=32, out_channels=64, kernel_size=3, device=device\n", "        )\n", "        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)\n", "        self.cnn3 = nn.Conv2d(\n", "            in_channels=64, out_channels=64, kernel_size=3, device=device\n", "        )\n", "        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)\n", "\n", "    def forward(self, x):\n", "
        x = x.to(self.device)\n", "        x = self.cnn1(x)\n", "        x = self.relu(x)\n", "        x = self.maxpool1(x)\n", "        x = self.cnn2(x)\n", "        x = self.relu(x)\n", "        x = self.maxpool2(x)\n", "        x = self.cnn3(x)\n", "        x = self.relu(x)\n", "        x = self.maxpool3(x)\n", "        x = x.flatten(start_dim=1)\n", "        return x" ] }, { "cell_type": "code", "execution_count": 4, "id": "d687577b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Files already downloaded and verified\n" ] } ], "source": [ "# Load CIFAR100 dataset\n", "transform = transforms.Compose(\n", "    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]\n", ")\n", "trainset = torchvision.datasets.CIFAR100(\n", "    root=\"./data\", train=True, download=True, transform=transform\n", ")\n", "testset = torchvision.datasets.CIFAR100(\n", "    root=\"./data\", train=False, download=True, transform=transform\n", ")\n", "trainloader = DataLoader(trainset, batch_size=1000, shuffle=True)\n", "testloader = DataLoader(testset, batch_size=1000, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 5, "id": "394d6636", "metadata": {}, "outputs": [], "source": [ "in_feat = 2 * 2 * 64\n", "out_feat = 100\n", "in_resize = np.ceil(np.log2(in_feat)).astype(int)\n", "out_resize = np.ceil(np.log2(out_feat)).astype(int)\n", "model = nn.Sequential(\n", "    CNet(device=device),\n", "    nn.Linear(in_feat, in_resize, device=device),\n", "    QKAN([in_resize, out_resize], device=device),\n", "    nn.Linear(out_resize, out_feat, device=device),\n", ")\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "accs = []\n", "losses = []\n", "test_accs = []\n", "test_losses = []\n", "test_top5_accs = []\n", "\n", "optimizer = optim.Adam(model.parameters(), lr=1e-3)" ] }, { "cell_type": "code", "execution_count": 6, "id": "0c9a4c57", "metadata": {}, "outputs": [], "source": [ "def train(model, optimizer):\n", "    for epoch in range(50):\n", "        # Train\n", "        model.train()\n", "        with tqdm(trainloader) 
as pbar:\n", "            for i, (images, labels) in enumerate(pbar):\n", "                optimizer.zero_grad()\n", "                output = model(images)\n", "                loss = criterion(output, labels.to(device))\n", "                loss.backward()\n", "                optimizer.step()\n", "                accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()\n", "                accs.append(accuracy.item())\n", "                losses.append(loss.item())\n", "                pbar.set_postfix(\n", "                    loss=loss.item(),\n", "                    accuracy=accuracy.item(),\n", "                    lr=optimizer.param_groups[0][\"lr\"],\n", "                )\n", "\n", "        # test\n", "        model.eval()\n", "        test_loss = 0\n", "        test_accuracy = 0\n", "        test_top5_accuracy = 0\n", "        with torch.no_grad():\n", "            for images, labels in testloader:\n", "                output = model(images)\n", "                test_loss += criterion(output, labels.to(device)).item()\n", "                test_accuracy += (\n", "                    (output.argmax(dim=1) == labels.to(device)).float().mean().item()\n", "                )\n", "                test_top5_accuracy += (\n", "                    (output.topk(5, dim=1).indices == labels.to(device).unsqueeze(1)).any(dim=1).float().mean().item()\n", "                )\n", "        test_loss /= len(testloader)\n", "        test_accuracy /= len(testloader)\n", "        test_top5_accuracy /= len(testloader)\n", "\n", "        print(f\"Epoch {epoch + 1}, test Loss: {test_loss}, test Accuracy: {test_accuracy}\")\n", "        test_accs.append(test_accuracy)\n", "        test_losses.append(test_loss)\n", "        test_top5_accs.append(test_top5_accuracy)" ] }, { "cell_type": "code", "execution_count": 7, "id": "876c4fd1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.03it/s, accuracy=0.014, loss=4.57, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1, test Loss: 4.5694972515106205, test Accuracy: 0.01120000034570694\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.58it/s, accuracy=0.032, loss=4.47, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2, test Loss: 4.473529720306397, test Accuracy: 
0.02420000098645687\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.93it/s, accuracy=0.04, loss=4.35, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3, test Loss: 4.348425626754761, test Accuracy: 0.035800001583993435\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 24.91it/s, accuracy=0.051, loss=4.18, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4, test Loss: 4.219951534271241, test Accuracy: 0.04830000214278698\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.60it/s, accuracy=0.06, loss=4.1, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5, test Loss: 4.10424919128418, test Accuracy: 0.05960000306367874\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.57it/s, accuracy=0.062, loss=4.01, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6, test Loss: 4.006009817123413, test Accuracy: 0.06800000295042992\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.60it/s, accuracy=0.073, loss=3.88, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7, test Loss: 3.920482850074768, test Accuracy: 0.08030000403523445\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 24.94it/s, accuracy=0.089, loss=3.81, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8, test Loss: 3.8535669088363647, test Accuracy: 0.08670000433921814\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.093, loss=3.77, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9, test Loss: 3.811629557609558, test 
Accuracy: 0.09450000450015068\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.62it/s, accuracy=0.109, loss=3.75, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10, test Loss: 3.7596314907073975, test Accuracy: 0.10250000432133674\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 24.91it/s, accuracy=0.108, loss=3.71, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 11, test Loss: 3.7226871490478515, test Accuracy: 0.10690000578761101\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.58it/s, accuracy=0.122, loss=3.65, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 12, test Loss: 3.6761600494384767, test Accuracy: 0.11960000693798065\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.126, loss=3.59, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 13, test Loss: 3.6474289178848265, test Accuracy: 0.12440000697970391\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.59it/s, accuracy=0.139, loss=3.53, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 14, test Loss: 3.6189677715301514, test Accuracy: 0.1307000070810318\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 24.91it/s, accuracy=0.126, loss=3.55, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 15, test Loss: 3.5796807527542116, test Accuracy: 0.13510000705718994\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.62it/s, accuracy=0.158, loss=3.47, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 16, test Loss: 
3.5481101989746096, test Accuracy: 0.1361000046133995\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.60it/s, accuracy=0.182, loss=3.39, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 17, test Loss: 3.531752014160156, test Accuracy: 0.1390000082552433\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 24.91it/s, accuracy=0.168, loss=3.4, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 18, test Loss: 3.5055885791778563, test Accuracy: 0.1440000057220459\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.62it/s, accuracy=0.157, loss=3.36, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 19, test Loss: 3.4795714139938356, test Accuracy: 0.14710000902414322\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.55it/s, accuracy=0.164, loss=3.34, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 20, test Loss: 3.441032814979553, test Accuracy: 0.15520000606775283\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.59it/s, accuracy=0.185, loss=3.31, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 21, test Loss: 3.4175992250442504, test Accuracy: 0.16120000928640366\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.20it/s, accuracy=0.192, loss=3.24, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 22, test Loss: 3.3915664196014403, test Accuracy: 0.16930000633001327\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.91it/s, accuracy=0.183, loss=3.25, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"Epoch 23, test Loss: 3.4428197860717775, test Accuracy: 0.15350000709295272\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.95it/s, accuracy=0.213, loss=3.24, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 24, test Loss: 3.362662839889526, test Accuracy: 0.16850000768899917\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.18it/s, accuracy=0.195, loss=3.23, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 25, test Loss: 3.355527567863464, test Accuracy: 0.1708000048995018\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.54it/s, accuracy=0.196, loss=3.18, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 26, test Loss: 3.358353543281555, test Accuracy: 0.1723000094294548\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.51it/s, accuracy=0.194, loss=3.22, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 27, test Loss: 3.3171217918395994, test Accuracy: 0.1827000081539154\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.53it/s, accuracy=0.205, loss=3.12, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 28, test Loss: 3.350891280174255, test Accuracy: 0.17330000698566436\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 24.87it/s, accuracy=0.213, loss=3.14, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 29, test Loss: 3.2914602756500244, test Accuracy: 0.18220000863075256\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.57it/s, accuracy=0.196, loss=3.16, lr=0.001]\n" ] }, { "name": "stdout", "output_type": 
"stream", "text": [ "Epoch 30, test Loss: 3.2760613441467283, test Accuracy: 0.18880000859498977\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.55it/s, accuracy=0.212, loss=3.11, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 31, test Loss: 3.2763137340545656, test Accuracy: 0.18600000739097594\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 24.93it/s, accuracy=0.237, loss=3.05, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 32, test Loss: 3.269388508796692, test Accuracy: 0.1886000081896782\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.228, loss=3, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 33, test Loss: 3.280174207687378, test Accuracy: 0.18370001018047333\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.234, loss=3.08, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 34, test Loss: 3.247648596763611, test Accuracy: 0.19470000714063646\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.58it/s, accuracy=0.206, loss=3.11, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 35, test Loss: 3.236310625076294, test Accuracy: 0.2018000066280365\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 24.92it/s, accuracy=0.219, loss=3.04, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 36, test Loss: 3.2433222055435182, test Accuracy: 0.19420000910758972\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.223, loss=3.05, lr=0.001]\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "Epoch 37, test Loss: 3.2257045030593874, test Accuracy: 0.19530000984668733\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.63it/s, accuracy=0.263, loss=2.96, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 38, test Loss: 3.221918821334839, test Accuracy: 0.20050000846385957\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 24.86it/s, accuracy=0.254, loss=2.91, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 39, test Loss: 3.230287194252014, test Accuracy: 0.20160000771284103\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.57it/s, accuracy=0.245, loss=2.96, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 40, test Loss: 3.236723041534424, test Accuracy: 0.19820000976324081\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.237, loss=2.97, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 41, test Loss: 3.231258010864258, test Accuracy: 0.20250001102685927\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.60it/s, accuracy=0.228, loss=2.95, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 42, test Loss: 3.220815086364746, test Accuracy: 0.20580001026391984\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 24.80it/s, accuracy=0.246, loss=2.87, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 43, test Loss: 3.2338695764541625, test Accuracy: 0.2041000097990036\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.59it/s, accuracy=0.262, loss=2.88, lr=0.001]\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "Epoch 44, test Loss: 3.2135489225387572, test Accuracy: 0.20290001034736632\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.59it/s, accuracy=0.274, loss=2.84, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 45, test Loss: 3.204585647583008, test Accuracy: 0.20940001159906388\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 24.89it/s, accuracy=0.269, loss=2.83, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 46, test Loss: 3.1972398281097414, test Accuracy: 0.21000000983476638\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.63it/s, accuracy=0.236, loss=2.92, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 47, test Loss: 3.2239073276519776, test Accuracy: 0.20540000945329667\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.55it/s, accuracy=0.243, loss=2.93, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 48, test Loss: 3.1972729444503782, test Accuracy: 0.20940001010894777\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.62it/s, accuracy=0.255, loss=2.86, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 49, test Loss: 3.210345673561096, test Accuracy: 0.20560001134872435\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 24.92it/s, accuracy=0.262, loss=2.77, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 50, test Loss: 3.1845874547958375, test Accuracy: 0.21320001184940338\n" ] } ], "source": [ "train(model, optimizer)" ] }, { "cell_type": "code", "execution_count": 8, "id": "2ee6c1e0", "metadata": {}, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Top-1 Accuracy: 0.21320001184940338\n", "Top-5 Accuracy: 0.5048000246286393\n", "CNN+HQKAN Model size: 59624\n", "CNN part size: 56320\n", "HQKAN part size: 3304\n" ] } ], "source": [ "print(\"Top-1 Accuracy: \", max(test_accs))\n", "print(\"Top-5 Accuracy: \", max(test_top5_accs))\n", "\n", "print(\"CNN+HQKAN Model size:\", param:=len(torch.concat([param.cpu().detach().flatten() for param in model.parameters() if param.requires_grad])))\n", "cnn = CNet(device=\"cpu\")\n", "cnn_param = len(torch.concat([param.cpu().detach().flatten() for param in cnn.parameters() if param.requires_grad]))\n", "del cnn\n", "print(\"CNN part size:\", cnn_param)\n", "print(\"HQKAN part size:\", param - cnn_param)" ] }, { "cell_type": "markdown", "id": "5538c572", "metadata": {}, "source": [ "To enhance model capacity, we widen the QKAN core by multiplying both its input and output channel counts by 4. This lets the core learn more complex feature representations, while the hybrid bottleneck keeps the parameter count well below that of a QKAN applied directly to the full-width features." 
] }, { "cell_type": "code", "execution_count": 9, "id": "25bf35f2", "metadata": {}, "outputs": [], "source": [ "in_resize = np.ceil(np.log2(in_feat)).astype(int) * 4\n", "out_resize = np.ceil(np.log2(out_feat)).astype(int) * 4\n", "model = nn.Sequential(\n", " CNet(device=device),\n", " nn.Linear(in_feat, in_resize, device=device),\n", " QKAN([in_resize, out_resize], device=device),\n", " nn.Linear(out_resize, out_feat, device=device),\n", ")\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "accs = []\n", "losses = []\n", "test_accs = []\n", "test_losses = []\n", "test_top5_accs = []\n", "\n", "optimizer = optim.Adam(model.parameters(), lr=1e-3)" ] }, { "cell_type": "code", "execution_count": 10, "id": "36e3c889", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.65it/s, accuracy=0.04, loss=4.37, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1, test Loss: 4.359431600570678, test Accuracy: 0.04140000194311142\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.53it/s, accuracy=0.055, loss=4.03, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2, test Loss: 4.007006192207337, test Accuracy: 0.08220000267028808\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.57it/s, accuracy=0.107, loss=3.87, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3, test Loss: 3.803478169441223, test Accuracy: 0.11150000542402268\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.61it/s, accuracy=0.149, loss=3.56, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4, test Loss: 3.689195442199707, test Accuracy: 0.13470000326633452\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 
23.03it/s, accuracy=0.145, loss=3.6, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5, test Loss: 3.5530738115310667, test Accuracy: 0.15630000829696655\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.73it/s, accuracy=0.159, loss=3.45, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6, test Loss: 3.47371768951416, test Accuracy: 0.1645000085234642\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.71it/s, accuracy=0.194, loss=3.31, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7, test Loss: 3.369604206085205, test Accuracy: 0.18320000916719437\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.06it/s, accuracy=0.198, loss=3.27, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8, test Loss: 3.2939977645874023, test Accuracy: 0.1992000088095665\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.67it/s, accuracy=0.209, loss=3.19, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9, test Loss: 3.2363681554794312, test Accuracy: 0.20680001080036164\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.72it/s, accuracy=0.225, loss=3.19, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10, test Loss: 3.1699442625045777, test Accuracy: 0.21730001121759415\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.65it/s, accuracy=0.239, loss=3.1, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 11, test Loss: 3.124575066566467, test Accuracy: 0.22790001183748246\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 
[00:02<00:00, 23.10it/s, accuracy=0.254, loss=3.03, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 12, test Loss: 3.0657344341278074, test Accuracy: 0.2417000100016594\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.64it/s, accuracy=0.235, loss=2.99, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 13, test Loss: 3.0330076456069945, test Accuracy: 0.24390001147985457\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.69it/s, accuracy=0.267, loss=2.93, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 14, test Loss: 2.970594549179077, test Accuracy: 0.25330001413822173\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 22.92it/s, accuracy=0.267, loss=2.87, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 15, test Loss: 2.9353652477264403, test Accuracy: 0.26220000833272933\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.65it/s, accuracy=0.297, loss=2.79, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 16, test Loss: 2.929064393043518, test Accuracy: 0.26770001500844953\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.63it/s, accuracy=0.332, loss=2.71, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 17, test Loss: 2.878641438484192, test Accuracy: 0.2755000114440918\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.68it/s, accuracy=0.315, loss=2.7, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 18, test Loss: 2.865756464004517, test Accuracy: 0.2796000182628632\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ 
"100%|██████████| 50/50 [00:02<00:00, 23.00it/s, accuracy=0.292, loss=2.73, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 19, test Loss: 2.8631765127182005, test Accuracy: 0.2791000097990036\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.68it/s, accuracy=0.299, loss=2.71, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 20, test Loss: 2.8106075525283813, test Accuracy: 0.28650001883506776\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.81it/s, accuracy=0.319, loss=2.65, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 21, test Loss: 2.806682085990906, test Accuracy: 0.28940001130104065\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.12it/s, accuracy=0.335, loss=2.59, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 22, test Loss: 2.7785612821578978, test Accuracy: 0.2944000124931335\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.69it/s, accuracy=0.33, loss=2.55, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 23, test Loss: 2.745818758010864, test Accuracy: 0.3020000159740448\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.60it/s, accuracy=0.331, loss=2.55, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 24, test Loss: 2.7335641622543334, test Accuracy: 0.30610001385211943\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.52it/s, accuracy=0.327, loss=2.57, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 25, test Loss: 2.7024224758148194, test Accuracy: 0.31030001640319826\n" ] }, { "name": "stderr", "output_type": 
"stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.07it/s, accuracy=0.359, loss=2.47, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 26, test Loss: 2.6971447467803955, test Accuracy: 0.3104000151157379\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.47it/s, accuracy=0.351, loss=2.46, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 27, test Loss: 2.6772998094558718, test Accuracy: 0.31400001645088194\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.52it/s, accuracy=0.387, loss=2.35, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 28, test Loss: 2.66512246131897, test Accuracy: 0.32040001153945924\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.01it/s, accuracy=0.382, loss=2.42, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 29, test Loss: 2.667982268333435, test Accuracy: 0.3170000106096268\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.75it/s, accuracy=0.369, loss=2.41, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 30, test Loss: 2.629500079154968, test Accuracy: 0.325500014424324\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.74it/s, accuracy=0.394, loss=2.38, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 31, test Loss: 2.6221821546554565, test Accuracy: 0.3290000170469284\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.73it/s, accuracy=0.412, loss=2.31, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 32, test Loss: 2.6159011363983153, test Accuracy: 0.32920001745224\n" ] }, { "name": "stderr", 
"output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.05it/s, accuracy=0.403, loss=2.38, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 33, test Loss: 2.601720857620239, test Accuracy: 0.33770001232624053\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.67it/s, accuracy=0.392, loss=2.32, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 34, test Loss: 2.603303813934326, test Accuracy: 0.33710001707077025\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.66it/s, accuracy=0.398, loss=2.38, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 35, test Loss: 2.6415931463241575, test Accuracy: 0.3316000163555145\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.13it/s, accuracy=0.37, loss=2.33, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 36, test Loss: 2.5769686698913574, test Accuracy: 0.3406000167131424\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.67it/s, accuracy=0.384, loss=2.36, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 37, test Loss: 2.5785749912261964, test Accuracy: 0.3406000167131424\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.70it/s, accuracy=0.405, loss=2.31, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 38, test Loss: 2.596335530281067, test Accuracy: 0.33610001504421233\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.64it/s, accuracy=0.434, loss=2.14, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 39, test Loss: 2.573742628097534, test Accuracy: 0.34020001590251925\n" ] }, { 
"name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.05it/s, accuracy=0.419, loss=2.22, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 40, test Loss: 2.577555251121521, test Accuracy: 0.3434000164270401\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.61it/s, accuracy=0.408, loss=2.27, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 41, test Loss: 2.5758074045181276, test Accuracy: 0.33860001564025877\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.69it/s, accuracy=0.413, loss=2.2, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 42, test Loss: 2.564303970336914, test Accuracy: 0.34670001566410064\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.11it/s, accuracy=0.401, loss=2.24, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 43, test Loss: 2.563422918319702, test Accuracy: 0.3433000147342682\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.69it/s, accuracy=0.449, loss=2.14, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 44, test Loss: 2.57224817276001, test Accuracy: 0.3441000163555145\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.69it/s, accuracy=0.44, loss=2.08, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 45, test Loss: 2.5427053689956667, test Accuracy: 0.35400002002716063\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.76it/s, accuracy=0.437, loss=2.13, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 46, test Loss: 2.564718508720398, test Accuracy: 
0.3476000130176544\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 22.99it/s, accuracy=0.395, loss=2.21, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 47, test Loss: 2.57286651134491, test Accuracy: 0.35210002064704893\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.53it/s, accuracy=0.423, loss=2.16, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 48, test Loss: 2.555943727493286, test Accuracy: 0.34780001938343047\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 23.55it/s, accuracy=0.454, loss=2.12, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 49, test Loss: 2.524978685379028, test Accuracy: 0.36210001111030576\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:02<00:00, 22.87it/s, accuracy=0.437, loss=2.07, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 50, test Loss: 2.5304693937301637, test Accuracy: 0.35640001893043516\n" ] } ], "source": [ "train(model, optimizer)" ] }, { "cell_type": "code", "execution_count": 11, "id": "757920c9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Top-1 Accuracy: 0.36210001111030576\n", "Top-5 Accuracy: 0.6717000305652618\n", "CNN+HQKAN-44 Model size: 74612\n", "CNN part size: 56320\n", "HQKAN-44 part size: 18292\n" ] } ], "source": [ "print(\"Top-1 Accuracy: \", max(test_accs))\n", "print(\"Top-5 Accuracy: \", max(test_top5_accs))\n", "\n", "print(\"CNN+HQKAN-44 Model size:\", param:=len(torch.concat([param.cpu().detach().flatten() for param in model.parameters() if param.requires_grad])))\n", "print(\"CNN part size:\", cnn_param)\n", "print(\"HQKAN-44 part size:\", param - cnn_param)" ] }, { "cell_type": "markdown", "id": "f04fb1b9", "metadata": {}, "source": 
[ "For comparison, we train the same CNN front end with a three-layer MLP head in place of HQKAN." ] }, { "cell_type": "code", "execution_count": 12, "id": "348f892e", "metadata": {}, "outputs": [], "source": [ "class CNN(nn.Module):\n", " def __init__(self, device):\n", " super(CNN, self).__init__()\n", " self.cnet = CNet(device)\n", " self.fc1 = nn.Linear(256, 192).to(device)\n", " self.fc2 = nn.Linear(192, 128).to(device)\n", " self.fc3 = nn.Linear(128, 100).to(device)\n", " self.device = device\n", " \n", " def forward(self, x):\n", " # Tensor.to() is not in-place, so the result must be reassigned\n", " x = x.to(self.device)\n", " x = self.cnet(x)\n", " x = x.view(x.size(0), -1)\n", " x = self.fc1(x)\n", " x = F.relu(x)\n", " x = self.fc2(x)\n", " x = F.relu(x)\n", " x = self.fc3(x)\n", " return x\n", "\n", "model = CNN(device=device)\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "accs = []\n", "losses = []\n", "test_accs = []\n", "test_losses = []\n", "test_top5_accs = []\n", "\n", "optimizer = optim.Adam(model.parameters(), lr=1e-3)" ] }, { "cell_type": "code", "execution_count": 13, "id": "bbc9384a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.49it/s, accuracy=0.073, loss=4.02, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1, test Loss: 4.0597692966461185, test Accuracy: 0.07460000440478325\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.60it/s, accuracy=0.085, loss=3.85, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2, test Loss: 3.8250439167022705, test Accuracy: 0.10800000503659249\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.46it/s, accuracy=0.127, loss=3.7, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3, test Loss: 3.6398918867111205, test Accuracy: 0.14570000618696213\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.45it/s, accuracy=0.171, loss=3.47, 
lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4, test Loss: 3.501902627944946, test Accuracy: 0.174700003862381\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.18it/s, accuracy=0.217, loss=3.27, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5, test Loss: 3.39032986164093, test Accuracy: 0.18960001021623613\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.66it/s, accuracy=0.205, loss=3.23, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6, test Loss: 3.329215979576111, test Accuracy: 0.20160000920295715\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.52it/s, accuracy=0.208, loss=3.22, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7, test Loss: 3.2377456426620483, test Accuracy: 0.21970001012086868\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.28it/s, accuracy=0.256, loss=3.09, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8, test Loss: 3.1636680603027343, test Accuracy: 0.22900001108646392\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.42it/s, accuracy=0.243, loss=3.04, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9, test Loss: 3.1480681180953978, test Accuracy: 0.2313000112771988\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.44it/s, accuracy=0.244, loss=3.1, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10, test Loss: 3.0826433897018433, test Accuracy: 0.2523000121116638\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.44it/s, accuracy=0.268, 
loss=3.03, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 11, test Loss: 3.0320362567901613, test Accuracy: 0.25540001392364503\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.39it/s, accuracy=0.271, loss=2.9, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 12, test Loss: 3.0019490718841553, test Accuracy: 0.26240001022815707\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.59it/s, accuracy=0.274, loss=2.93, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 13, test Loss: 2.9862070083618164, test Accuracy: 0.2681000143289566\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.41it/s, accuracy=0.296, loss=2.82, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 14, test Loss: 2.9344271421432495, test Accuracy: 0.28090001046657564\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.28it/s, accuracy=0.297, loss=2.84, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 15, test Loss: 2.9287622928619386, test Accuracy: 0.27840001285076144\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.18it/s, accuracy=0.304, loss=2.72, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 16, test Loss: 2.8981085062026977, test Accuracy: 0.2864000201225281\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.42it/s, accuracy=0.292, loss=2.8, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 17, test Loss: 2.864175868034363, test Accuracy: 0.2898000180721283\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 
26.48it/s, accuracy=0.305, loss=2.71, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 18, test Loss: 2.830798387527466, test Accuracy: 0.2998000144958496\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.43it/s, accuracy=0.321, loss=2.6, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 19, test Loss: 2.8063244581222535, test Accuracy: 0.30420001447200773\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.61it/s, accuracy=0.319, loss=2.68, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 20, test Loss: 2.8006964921951294, test Accuracy: 0.3048000156879425\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.47it/s, accuracy=0.328, loss=2.66, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 21, test Loss: 2.769916844367981, test Accuracy: 0.3151000142097473\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.33it/s, accuracy=0.312, loss=2.62, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 22, test Loss: 2.7571829080581667, test Accuracy: 0.3123000174760818\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.52it/s, accuracy=0.36, loss=2.53, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 23, test Loss: 2.733232855796814, test Accuracy: 0.31750001311302184\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.35it/s, accuracy=0.382, loss=2.41, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 24, test Loss: 2.7194632291793823, test Accuracy: 0.3200000137090683\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 
50/50 [00:01<00:00, 26.36it/s, accuracy=0.377, loss=2.47, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 25, test Loss: 2.756288003921509, test Accuracy: 0.3152000159025192\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.49it/s, accuracy=0.361, loss=2.44, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 26, test Loss: 2.691506338119507, test Accuracy: 0.3308000206947327\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.68it/s, accuracy=0.369, loss=2.46, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 27, test Loss: 2.689697813987732, test Accuracy: 0.3318000167608261\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.42it/s, accuracy=0.371, loss=2.39, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 28, test Loss: 2.6991366863250734, test Accuracy: 0.32980001270771026\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.22it/s, accuracy=0.383, loss=2.44, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 29, test Loss: 2.647859072685242, test Accuracy: 0.33790001571178435\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.53it/s, accuracy=0.376, loss=2.41, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 30, test Loss: 2.6324229001998902, test Accuracy: 0.3419000118970871\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.43it/s, accuracy=0.413, loss=2.25, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 31, test Loss: 2.6560914516448975, test Accuracy: 0.3348000138998032\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ 
"100%|██████████| 50/50 [00:01<00:00, 26.35it/s, accuracy=0.409, loss=2.32, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 32, test Loss: 2.624714732170105, test Accuracy: 0.3481000155210495\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.44it/s, accuracy=0.358, loss=2.41, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 33, test Loss: 2.610011410713196, test Accuracy: 0.344700014591217\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.78it/s, accuracy=0.412, loss=2.26, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 34, test Loss: 2.5931995153427123, test Accuracy: 0.3505000174045563\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.40it/s, accuracy=0.406, loss=2.29, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 35, test Loss: 2.5763473510742188, test Accuracy: 0.3550000131130219\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.44it/s, accuracy=0.416, loss=2.22, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 36, test Loss: 2.576857900619507, test Accuracy: 0.3531000167131424\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.52it/s, accuracy=0.42, loss=2.17, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 37, test Loss: 2.5906506299972536, test Accuracy: 0.35740001797676085\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.44it/s, accuracy=0.416, loss=2.22, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 38, test Loss: 2.569392275810242, test Accuracy: 0.3538000226020813\n" ] }, { "name": "stderr", "output_type": 
"stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.45it/s, accuracy=0.438, loss=2.08, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 39, test Loss: 2.5566878080368043, test Accuracy: 0.35710001587867735\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.54it/s, accuracy=0.416, loss=2.19, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 40, test Loss: 2.556428050994873, test Accuracy: 0.3639000207185745\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.68it/s, accuracy=0.434, loss=2.13, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 41, test Loss: 2.542978119850159, test Accuracy: 0.36690001785755155\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.45it/s, accuracy=0.445, loss=2.12, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 42, test Loss: 2.5442334413528442, test Accuracy: 0.362700018286705\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.43it/s, accuracy=0.431, loss=2.17, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 43, test Loss: 2.5535474538803102, test Accuracy: 0.3637000173330307\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.48it/s, accuracy=0.445, loss=2.1, lr=0.001] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 44, test Loss: 2.530148911476135, test Accuracy: 0.3738000154495239\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.53it/s, accuracy=0.466, loss=2.03, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 45, test Loss: 2.527567744255066, test Accuracy: 0.36940001845359804\n" ] }, { "name": "stderr", 
"output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.55it/s, accuracy=0.465, loss=2.01, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 46, test Loss: 2.5381455183029176, test Accuracy: 0.37100001573562624\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.53it/s, accuracy=0.488, loss=1.99, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 47, test Loss: 2.532746434211731, test Accuracy: 0.3675000160932541\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 25.78it/s, accuracy=0.494, loss=1.91, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 48, test Loss: 2.5316006898880006, test Accuracy: 0.37280001640319826\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.48it/s, accuracy=0.473, loss=1.96, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 49, test Loss: 2.561800479888916, test Accuracy: 0.3666000127792358\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:01<00:00, 26.64it/s, accuracy=0.461, loss=1.99, lr=0.001]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 50, test Loss: 2.515561318397522, test Accuracy: 0.3707000195980072\n" ] } ], "source": [ "train(model, optimizer)" ] }, { "cell_type": "code", "execution_count": 14, "id": "a918dff8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Top-1 Accuracy: 0.3738000154495239\n", "Top-5 Accuracy: 0.6787000298500061\n", "CNN+MLP Model size: 143268\n", "CNN part size: 56320\n", "MLP part size: 86948\n" ] } ], "source": [ "print(\"Top-1 Accuracy: \", max(test_accs))\n", "print(\"Top-5 Accuracy: \", max(test_top5_accs))\n", "\n", "print(\"CNN+MLP Model size:\", param:=len(torch.concat([param.cpu().detach().flatten() for 
param in model.parameters() if param.requires_grad])))\n", "print(\"CNN part size:\", cnn_param)\n", "print(\"MLP part size:\", param - cnn_param)" ] }, { "cell_type": "markdown", "id": "b15ed96c", "metadata": {}, "source": [ "## Results\n", "\n", "- **HQKAN-44** achieves **comparable accuracy** to the MLP baseline: the best top-1 test accuracy over 50 epochs is 36.2% vs. 37.4%, and the best top-5 accuracy is 67.2% vs. 67.9%.\n", "- The HQKAN-44 head uses 18,292 parameters, about **21% of the 86,948 parameters** in the MLP head. Both models share the same 56,320-parameter CNN front end, so the full models have 74,612 vs. 143,268 parameters (about 52%).\n", "- Training time is also **comparable to MLP** (about 23 it/s vs. 26 it/s in the runs above), and **significantly faster** than many other quantum-inspired machine learning (QML) models when run on classical simulators.\n", "\n", "## Summary\n", "\n", "Hybrid QKAN offers a scalable and efficient alternative to traditional MLPs and other QML models:\n", "- ✔ Comparable accuracy\n", "- ✔ Roughly 4.8× fewer parameters in the classifier head (about 1.9× fewer for the full model)\n", "- ✔ Faster training than typical QML approaches\n", "\n", "This demonstrates HQKAN's potential as a practical and effective model architecture for modern deep learning tasks." ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }