{ "cells": [ { "cell_type": "markdown", "id": "e8248294", "metadata": {}, "source": [ "# KAN" ] },
{ "cell_type": "markdown", "id": "91c5d5ed", "metadata": {}, "source": [ "## Kolmogorov-Arnold Representation Theorem (KART)\n", "\n", "KART states that any continuous multivariate function $f(x_1, \\ldots, x_N)$ on the unit hypercube $[0,1]^N$ can be written exactly as a finite composition of continuous univariate functions and addition:\n", "$$\n", " f(\\boldsymbol{x}) = \\sum_{q=1}^{2N+1} \\Phi_q \\left( \\sum_{p=1}^{N} \\phi_{q,p}(x_p) \\right),\n", "$$\n", "where $\\phi_{q,p} : [0,1] \\rightarrow \\mathbb{R}$ and $\\Phi_q : \\mathbb{R} \\rightarrow \\mathbb{R}$ are continuous functions.\n", "Although KART guarantees an exact representation, the inner functions $\\phi_{q,p}$ and outer functions $\\Phi_q$ can be highly non-smooth, which makes them hard to learn in practice.\n", "We therefore approximate them with smooth, learnable functions instead." ] },
{ "cell_type": "markdown", "id": "3d784dc4", "metadata": {}, "source": [ "## Kolmogorov-Arnold Network (KAN)\n", "\n", "Liu et al. introduced KANs as a practical realization of KART, generalizing it to deep and wide architectures.\n", "Each variational activation function in a KAN is a learnable univariate function parameterized by B-splines, which are piecewise polynomials that can approximate any continuous univariate function on a bounded interval to arbitrary precision as the spline grid is refined.\n", "\n", "Formally, a KAN layer maps the output of the $\\ell$-th layer to the $(\\ell+1)$-th layer via:\n", "$$\n", " x_{\\ell+1,j} = \\sum_{i=1}^{n_\\ell} \\phi_{\\ell,j,i}(x_{\\ell,i}),\n", "$$\n", "where $n_\\ell$ is the number of nodes in layer $\\ell$ and $\\phi_{\\ell,j,i}$ is the learnable univariate variational activation function connecting input node $i$ to output node $j$.\n", "In matrix notation:\n", "$$\n", "\\begin{aligned}\n", " \\boldsymbol{x}_{\\ell+1} & = \\Phi_\\ell(\\boldsymbol{x}_\\ell), \\\\\n", " \\Phi_\\ell & = \\begin{pmatrix}\n", " \\phi_{\\ell,1,1}(\\cdot) & \\phi_{\\ell,1,2}(\\cdot) & \\cdots & \\phi_{\\ell,1,n_\\ell}(\\cdot) \\\\\n", " \\phi_{\\ell,2,1}(\\cdot) & \\phi_{\\ell,2,2}(\\cdot) & \\cdots & \\phi_{\\ell,2,n_\\ell}(\\cdot) \\\\\n", " \\vdots & \\vdots & \\ddots & \\vdots \\\\\n", " \\phi_{\\ell,n_{\\ell+1},1}(\\cdot) & \\phi_{\\ell,n_{\\ell+1},2}(\\cdot) & \\cdots & \\phi_{\\ell,n_{\\ell+1},n_\\ell}(\\cdot)\n", " \\end{pmatrix}.\n", "\\end{aligned}\n", "$$\n", "\n", "A KAN is a composition of $L$ KAN layers, so for an input $\\boldsymbol{x}$ its output is\n", "$$\n", "\\text{KAN}(\\boldsymbol{x}) = (\\Phi_{L-1}\\circ\\Phi_{L-2}\\circ\\cdots\\circ\\Phi_1\\circ\\Phi_0)(\\boldsymbol{x}).\n", "$$\n", "By contrast, a multilayer perceptron (MLP) interleaves learnable linear layers $W_\\ell$ with a fixed nonlinear activation function $\\sigma$:\n", "$$\n", "\\text{MLP}(\\boldsymbol{x}) = (W_{L-1}\\circ\\sigma\\circ W_{L-2}\\circ\\sigma\\circ\\cdots\\circ W_1\\circ\\sigma\\circ W_0)(\\boldsymbol{x}).\n", "$$\n", "\n", "The following figure illustrates the KAN architecture, in which each variational activation function $\\phi_{\\ell,j,i}$ connects input node $i$ of layer $\\ell$ to output node $j$ of layer $\\ell+1$.\n", "![KAN](../fig/kan.png)" ] },
{ "cell_type": "markdown", "id": "9d05257e", "metadata": {}, "source": [ "## Training KAN\n", "\n", "Here is a simple example of training a KAN with the QKAN package:" ] },
{ "cell_type": "code", "execution_count": 1, "id": "32b4b09a", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from tqdm import tqdm\n", "\n", "from qkan import KAN, create_dataset\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "# Target: f(x) = sin(20x) / (20x), i.e. the spherical Bessel function j_0(20x)\n", "f = lambda x: torch.sin(20 * x) / x / 20\n", "dataset = create_dataset(\n", "    f, n_var=1, ranges=[0, 1], device=device, train_num=1000, test_num=1000, seed=0\n", ")\n", "\n", "# A [1, 1] KAN: one input node and one output node, i.e. a single learnable activation\n", "model = KAN(\n", "    [1, 1],\n", "    grid_size=10,\n", "    device=device,\n", ")\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)\n", "loss_fn = torch.nn.MSELoss()" ] },
{ "cell_type": "code", "execution_count": 2, "id": "79d00abf", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "loss: 0.0031: 100%|██████████████████████████████████████████████| 100/100 [00:00<00:00, 735.03it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Test loss: 0.0032\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "steps = 100\n", "pbar = tqdm(range(steps), ncols=100)\n", "\n", "# Full-batch training: one gradient step on all training points per iteration\n", "model.train()\n", "for _ in pbar:\n", "    optimizer.zero_grad()\n", "    pred = model(dataset[\"train_input\"])\n", "    loss = loss_fn(pred, dataset[\"train_label\"])\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "    pbar.set_description(f\"loss: {loss.item():.4f}\")\n", "\n", "# Evaluate on the held-out test set without tracking gradients\n", "model.eval()\n", "with torch.no_grad():\n", "    pred = model(dataset[\"test_input\"])\n", "    loss = loss_fn(pred, dataset[\"test_label\"])\n", "    print(f\"Test loss: {loss.item():.4f}\")" ] },
{ "cell_type": "markdown", "id": "63eb61ec", "metadata": {}, "source": [ "###### Further Reading\n", "\n", "To understand KANs in more detail, see the original paper: \"[KAN: Kolmogorov-Arnold Networks](https://arxiv.org/abs/2404.19756)\".\n", "\n", "See also the `pykan` documentation: https://kindxiaoming.github.io/pykan/index.html" ] } ],
"metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }