{ "cells": [ { "cell_type": "markdown", "id": "76939333", "metadata": {}, "source": [ "# Layer Extension\n", "\n", "In this example, we demonstrate how to extend the layers (i.e., the number of repetitions `reps`) of data re-uploading circuits, making a more powerful DARUAN in QKAN. We start from a small model and repeatedly grow `reps`, warm-starting each larger model from the previously trained one." ] }, { "cell_type": "code", "execution_count": 1, "id": "987fb07e", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "\n", "from qkan import QKAN, create_dataset\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "# Seed torch's global RNG so that Restart & Run All is repeatable\n", "# (assumes qkan's dataset sampling and parameter init draw from torch's RNG - TODO confirm).\n", "SEED = 0\n", "torch.manual_seed(SEED)\n", "\n", "\n", "def f(x):\n", "    \"\"\"Target f(x, y) = exp(sin(pi * x) + y^2), reading x from column 0 and y from column 1; returns one output column.\"\"\"\n", "    return torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)\n", "\n", "\n", "dataset = create_dataset(f, n_var=2, device=device)\n", "\n", "# initialize QKAN with reps=1\n", "model = QKAN(\n", "    [2, 5, 1],\n", "    reps=1,\n", "    device=device,\n", "    preact_trainable=True,  # enable flexible Fourier frequency\n", "    postact_bias_trainable=True,  # extend output bound\n", "    postact_weight_trainable=True,  # extend output bound\n", "    ba_trainable=True,  # enable residual connection for better convergence\n", ")\n", "optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)" ] }, { "cell_type": "code", "execution_count": 2, "id": "e06a1428", "metadata": {}, "outputs": [], "source": [ "# test_results: test-loss history concatenated across all training stages\n", "# qkans: every model built so far, smallest reps first\n", "test_results = []\n", "qkans = [model]" ] }, { "cell_type": "code", "execution_count": 3, "id": "3b974558", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████| 100/100 [00:00<00:00, 181.69it/s, train loss=1.0124128, test loss=1.1017513]\n" ] } ], "source": [ "result = model.train_(dataset, optimizer=optimizer, steps=100)\n", "test_results += result[\"test_loss\"]" ] }, { "cell_type": "markdown", "id": "abf9a86a", "metadata": {}, "source": [ "Perform layer extension: build QKANs with `reps` = 5, 10, 15 and 20 in turn, initialize each one from the previously trained model, and train it for 100 LBFGS steps to obtain a progressively finer-grained model."
] }, { "cell_type": "code", "execution_count": 4, "id": "36ad25ac", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████| 100/100 [00:09<00:00, 10.89it/s, train loss=8.0940447e-07, test loss=9.647775e-07]\n", "100%|███████████| 100/100 [00:03<00:00, 27.13it/s, train loss=3.3596322e-07, test loss=3.903716e-07]\n", "100%|███████████| 100/100 [00:05<00:00, 19.77it/s, train loss=9.609832e-08, test loss=1.2657244e-07]\n", "100%|████████████| 100/100 [00:03<00:00, 28.59it/s, train loss=9.314607e-08, test loss=1.240494e-07]\n" ] } ], "source": [ "reps = [5 * i for i in range(1, 5)]  # 5, 10, 15, 20\n", "for r in reps:\n", "    # Extend from the most recently trained model (rather than indexing by loop\n", "    # position), so the warm-start source does not depend on how many models\n", "    # `qkans` held before this cell ran.\n", "    prev_model = qkans[-1]\n", "    new_model = QKAN(\n", "        [2, 5, 1],\n", "        reps=r,\n", "        device=device,\n", "        preact_trainable=True,\n", "        postact_bias_trainable=True,\n", "        postact_weight_trainable=True,\n", "        ba_trainable=True,\n", "    )\n", "    new_model.initialize_from_another_model(prev_model)\n", "    qkans.append(new_model)\n", "    optimizer = torch.optim.LBFGS(new_model.parameters(), lr=5e-1)\n", "    result = new_model.train_(dataset, optimizer=optimizer, steps=100)\n", "    test_results += result[\"test_loss\"]" ] }, { "cell_type": "markdown", "id": "24bbc7b2", "metadata": {}, "source": [ "Compare with directly training a single model with a large number of repetitions (`reps=20`, the same as the final extended model) for the same total number of training steps (500)."
] }, { "cell_type": "code", "execution_count": 5, "id": "192e57b9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████| 500/500 [00:49<00:00, 10.07it/s, train loss=4.4905502e-07, test loss=7.657976e-07]\n" ] } ], "source": [ "model = QKAN(\n", " [2, 5, 1],\n", " reps=20,\n", " device=device,\n", " preact_trainable=True,\n", " postact_bias_trainable=True,\n", " postact_weight_trainable=True,\n", " ba_trainable=True,\n", ")\n", "optimizer = torch.optim.LBFGS(model.parameters(), lr=5e-1)\n", "result = model.train_(dataset, optimizer=optimizer, steps=500)" ] }, { "cell_type": "code", "execution_count": 6, "id": "c2cab6c9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(test_results, label=\"layer extension\")\n", "plt.plot(result[\"test_loss\"], label=\"direct train\")\n", "plt.ylabel(\"MSE\")\n", "plt.xlabel(\"step\")\n", "plt.yscale(\"log\")\n", "plt.legend()" ] }, { "cell_type": "markdown", "id": "499a03ed", "metadata": {}, "source": [ "With layer extension, the model achieves better loss performance while requiring less training time, enabling efficient scalability without sacrificing accuracy." ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }