{ "cells": [ { "cell_type": "markdown", "id": "9719678d", "metadata": {}, "source": [ "# Transferring from QKAN to KAN\n", "\n", "In the paper, we also introduce a method for transferring the learned parameters from a QKAN to a KAN. This is useful when you want to use the learned parameters in a different environment.\n", "\n", "The transfer also lets us probe whether QKAN has a better learning landscape for parameter optimization than KAN. If the transferred KAN achieves better performance than a KAN trained from scratch, that is strong evidence that QKAN has the better learning landscape.\n", "\n", "Here is a simple example of how to transfer the learned parameters from QKAN to KAN:" ] }, { "cell_type": "code", "execution_count": 1, "id": "df6aed0c", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from tqdm import tqdm\n", "\n", "from qkan import KAN, QKAN, create_dataset" ] }, { "cell_type": "code", "execution_count": 2, "id": "e4928d84", "metadata": {}, "outputs": [], "source": [ "# Create a dataset for the target function f(x, y) = sin(exp(x) + y^2)\n", "f = lambda x: torch.sin(torch.exp(x[:, [0]]) + x[:, [1]] ** 2)\n", "dataset = create_dataset(\n", " f,\n", " n_var=2,\n", " train_num=10000,\n", " test_num=1000,\n", ")\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "f80d95d9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████| 500/500 [00:05<00:00, 86.70it/s, train loss=0.0025606856, test loss=0.0022567876]\n", "train loss: 8.97e-08 | test loss: 1.56e-08: 100%|██████████████| 1000/1000 [00:02<00:00, 488.57it/s]\n" ] } ], "source": [ "qkan_model = QKAN(\n", " [2, 1, 1],\n", " num_qlayers=3,\n", " solver=\"exact\",\n", " device=device,\n", " preact_trainable=True,\n", " postact_bias_trainable=False,\n", " postact_weight_trainable=False,\n", " 
ba_trainable=True,\n", " seed=0,\n", ")\n", "optimizer = optim.Adam(qkan_model.parameters(), lr=1e-2)\n", "result = qkan_model.train_(dataset, steps=500, optimizer=optimizer)\n", "\n", "kan_model = KAN([2, 1, 1], grid_size=5, seed=0).to(device)\n", "x0 = torch.stack(\n", " [torch.linspace(-1, 1, steps=100, device=device) for _ in range(2)]\n", ").permute(1, 0) # x.shape = (sampling, in_dim)\n", "kan_model.initialize_from_qkan(qkan_model, x0=x0, sampling=100)\n", "\n", "optimizer = optim.Adam(kan_model.parameters(), lr=1e-2)\n", "pbar = tqdm(range(1000), desc=\"KAN\", ncols=100)\n", "batch_size = dataset[\"train_input\"].shape[0]\n", "batch_size_test = dataset[\"test_input\"].shape[0]\n", "criterion = loss_fn_eval = nn.MSELoss()\n", "for _ in pbar:\n", " train_id = np.random.choice(\n", " dataset[\"train_input\"].shape[0], batch_size, replace=False\n", " )\n", " test_id = np.random.choice(\n", " dataset[\"test_input\"].shape[0], batch_size_test, replace=False\n", " )\n", " pred = kan_model.forward(dataset[\"train_input\"][train_id].to(device))\n", " loss = criterion(pred, dataset[\"train_label\"][train_id].to(device))\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " test_loss = loss_fn_eval(\n", " kan_model.forward(dataset[\"test_input\"][test_id].to(device)),\n", " dataset[\"test_label\"][test_id].to(device),\n", " )\n", " result[\"train_loss\"].append(loss.cpu().detach().numpy())\n", " result[\"test_loss\"].append(test_loss.cpu().detach().numpy())\n", " pbar.set_description(\n", " \"train loss: %.2e | test loss: %.2e\"\n", " % (\n", " loss.cpu().detach().numpy(),\n", " test_loss.cpu().detach().numpy(),\n", " )\n", " )" ] }, { "cell_type": "code", "execution_count": 4, "id": "63f0f432", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "train loss: 2.00e-02 | test loss: 2.07e-02: 100%|██████████████| 1500/1500 [00:03<00:00, 496.32it/s]\n" ] } ], "source": [ "kan_model2 = KAN([2, 1, 1], 
grid_size=5, seed=0).to(device)\n", "\n", "# Define loss\n", "criterion = loss_fn_eval = nn.MSELoss()\n", "result2 = {}\n", "result2[\"train_loss\"] = []\n", "result2[\"test_loss\"] = []\n", "\n", "steps = 1500\n", "pbar = tqdm(range(steps), desc=\"description\", ncols=100)\n", "batch_size = dataset[\"train_input\"].shape[0]\n", "batch_size_test = dataset[\"test_input\"].shape[0]\n", "\n", "optimizer = optim.Adam(kan_model2.parameters(), lr=1e-2)\n", "for _ in pbar:\n", " train_id = np.random.choice(\n", " dataset[\"train_input\"].shape[0], batch_size, replace=False\n", " )\n", " test_id = np.random.choice(\n", " dataset[\"test_input\"].shape[0], batch_size_test, replace=False\n", " )\n", " pred = kan_model2.forward(dataset[\"train_input\"][train_id].to(device))\n", " loss = criterion(pred, dataset[\"train_label\"][train_id].to(device))\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " test_loss = loss_fn_eval(\n", " kan_model2.forward(dataset[\"test_input\"][test_id].to(device)),\n", " dataset[\"test_label\"][test_id].to(device),\n", " )\n", " result2[\"train_loss\"].append(loss.cpu().detach().numpy())\n", " result2[\"test_loss\"].append(test_loss.cpu().detach().numpy())\n", " pbar.set_description(\n", " \"train loss: %.2e | test loss: %.2e\"\n", " % (\n", " loss.cpu().detach().numpy(),\n", " test_loss.cpu().detach().numpy(),\n", " )\n", " )" ] }, { "cell_type": "code", "execution_count": 5, "id": "89cd15d8", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "ax = plt.subplot(111)\n", "\n", "ax.plot(result[\"test_loss\"], label=r\"QKAN $\\rightarrow$ KAN\")\n", "ax.plot(result2[\"test_loss\"], label=\"KAN\")\n", "ax.legend()\n", "ax.set_xlabel(\"Epoch\")\n", "ax.set_ylabel(\"MSE loss\")\n", "ax.set_title(r\"Fitting $f(x,y) = \\sin{(e^x+y^2)}$\"+ \"\\nTest loss of transfer learning from QKAN to KAN\")\n", "ax.set_yscale(\"log\")" ] }, { "cell_type": "markdown", "id": "d0b61c61", "metadata": {}, "source": [ "From the results, we can see that the transferred KAN achieves better performance than the pure KAN trained from scratch. In this run, the transferred KAN (500 QKAN steps followed by 1000 KAN steps) ends at a test loss of about 1.56e-08, while the pure KAN ends at about 2.07e-02 after 1500 steps.\n", "\n", "This provides strong support for the claim that QKAN/DARUAN, or QVAF, has a better learning landscape than the classical VAF." ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }