"""
Quantum-inspired Kolmogorov Arnold Networks (QKANs) implementation in PyTorch.
Paper: Quantum Variational Activation Functions Empower Kolmogorov-Arnold Networks: https://arxiv.org/abs/2509.14026
Supported solvers:
- PennyLane
- Exact solver implemented in PyTorch (faster)
- Custom solvers api
Code author: Jiun-Cheng Jiang (Jim137@GitHub)
Contact: [jcjiang@phys.ntu.edu.tw](mailto:jcjiang@phys.ntu.edu.tw)
"""
import os
import random
import warnings
from copy import deepcopy
from glob import glob
from typing import Callable, Literal, Optional, Union
import matplotlib.pyplot as plt # type: ignore
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm # type: ignore
from .solver import qml_solver, torch_exact_solver
[docs]
class QKANLayer(nn.Module):
    """
    A single QKAN layer mapping ``in_dim`` inputs to ``out_dim`` outputs.

    Every (output, input) edge carries a learnable activation evaluated by the
    configured solver, plus a weighted base activation. ``forward`` sums the
    edge contributions over the input axis; ``forward_no_sum`` returns them
    per edge.

    Attributes
    ----------
    in_dim : int
        Input dimension
    out_dim : int
        Output dimension
    reps : int
        Repetitions of quantum layers
    group : tuple
        Leading shape of the per-edge parameters. ``group=-1`` is expanded to
        ``(out_dim, in_dim)``; any other int ``g`` becomes ``(g,)``.
    device :
        Device to use
    solver : Union[Literal["qml", "exact"], Callable]
        Solver to use
    ansatz : Union[str, Callable]
        Ansatz to use, "pz_encoding", "px_encoding", "rpz_encoding" or custom
    qml_device : str
        PennyLane device to use
    theta : nn.Parameter
        Learnable parameter of quantum circuit
    base_weight : nn.Parameter
        Learnable parameter of base activation
    preact_trainable : bool
        Whether preact weights are trainable
    preacts_weight : nn.Parameter
        Learnable parameter of preact weights
    preacts_bias : nn.Parameter
        Learnable parameter of preact bias
    postact_weight_trainable : bool
        Whether postact weights are trainable
    postact_weights : nn.Parameter
        Learnable parameter of postact weights
    postact_bias_trainable : bool
        Whether postact bias is trainable
    postact_bias : nn.Parameter
        Learnable parameter of postact bias
    mask : nn.Parameter
        Mask for pruning (never trained)
    is_batchnorm : bool
        Whether to apply batch normalization to the input in ``forward``
    fast_measure : bool
        Passed to the exact solver. When True the exact solver uses its fast
        (quantum-inspired) measurement; when False it simulates the exact
        measurement process of the quantum circuit.
    _x0 : Optional[torch.Tensor]
        Reserved for ResQKANLayer
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        reps: int = 3,
        group: Union[int, tuple] = -1,
        device="cpu",
        solver: Union[Literal["qml", "exact"], Callable] = "exact",
        qml_device="default.qubit",
        ansatz: Union[str, Callable] = "pz_encoding",
        theta_size: Optional[list[int]] = None,
        preact_trainable: bool = False,
        preact_init: bool = False,
        postact_weight_trainable: bool = False,
        postact_bias_trainable: bool = False,
        base_activation=torch.nn.SiLU(),
        ba_trainable: bool = True,
        is_batchnorm: bool = False,
        fast_measure: bool = True,
        seed=0,
    ):
        """
        Build the layer and initialise its parameters.

        Raises
        ------
        ValueError
            If ``ansatz`` is callable and ``theta_size`` is not given.
        NotImplementedError
            If ``ansatz`` is an unknown string and no custom ``theta_size``
            applies.
        """
        super().__init__()
        # NOTE: this reseeds the *global* torch / numpy / random generators.
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        if isinstance(group, int):
            group = (out_dim, in_dim) if group == -1 else (group,)

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.reps = reps
        self.group = group
        self.device = device
        self.solver: Union[Literal["qml", "exact"], Callable] = solver
        self.qml_device = qml_device
        self.ansatz = ansatz
        self.theta_size = theta_size
        self.base_activation = base_activation
        self.ba_trainable = ba_trainable
        self.is_batchnorm = is_batchnorm
        self.fast_measure = fast_measure
        self.seed = seed

        # The original check was `callable("solver") or callable("ansatz")`,
        # which tests string literals and is therefore always False.
        # A custom ansatz has no built-in theta shape, so theta_size is
        # mandatory there. A custom solver may still rely on a built-in ansatz
        # shape, so theta_size stays optional for it (backward compatible).
        if callable(ansatz) and not theta_size:
            raise ValueError("theta_size is required for custom ansatz")
        if theta_size and (callable(solver) or callable(ansatz)):
            theta_shape = tuple(theta_size)
        elif ansatz == "pz_encoding":
            theta_shape = (*group, reps + 1, 2)
        elif ansatz == "rpz_encoding":
            if not preact_trainable:
                warnings.warn(
                    "Reduced pz encoding requires preact_trainable=True, set automatically."
                )
                preact_trainable = True
            theta_shape = (*group, reps + 1, 1)
        elif ansatz == "px_encoding":
            theta_shape = (*group, reps + 1, 1)
        else:
            raise NotImplementedError()
        self.theta = nn.Parameter(
            nn.init.xavier_normal_(torch.empty(*theta_shape, device=device))
        )

        # Base branch weight: 0.5 when trainable, frozen zeros otherwise
        # (which switches the base branch off).
        if ba_trainable:
            base_init = 0.5 * torch.ones(out_dim, in_dim, device=device)
        else:
            base_init = torch.zeros(out_dim, in_dim, device=device)
        self.base_weight = torch.nn.Parameter(base_init, requires_grad=ba_trainable)

        self.preact_trainable = preact_trainable
        if preact_init:
            # Weight is drawn before bias; the order matters for RNG reproducibility.
            preact_w = nn.init.xavier_normal_(torch.empty(*group, reps, device=device))
            preact_b = nn.init.xavier_normal_(torch.empty(*group, reps, device=device))
        else:
            preact_w = torch.ones(*group, reps, device=device)
            preact_b = torch.zeros(*group, reps, device=device)
        self.preacts_weight = nn.Parameter(preact_w, requires_grad=preact_trainable)
        self.preacts_bias = nn.Parameter(preact_b, requires_grad=preact_trainable)
        self.preact_init = preact_init

        self.postact_weight_trainable = postact_weight_trainable
        self.postact_weights = nn.Parameter(
            torch.ones(out_dim, in_dim, device=device),
            requires_grad=postact_weight_trainable,
        )
        self.postact_bias_trainable = postact_bias_trainable
        self.postact_bias = nn.Parameter(
            torch.zeros(out_dim, in_dim, device=device),
            requires_grad=postact_bias_trainable,
        )
        self.mask = nn.Parameter(
            torch.ones(out_dim, in_dim, device=device), requires_grad=False
        )
        if is_batchnorm:
            self.bn = nn.BatchNorm1d(in_dim, device=device)
        self._x0: Optional[torch.Tensor] = None

    def to(self, *args, **kwargs):
        """
        Move the layer to the specified device.

        Besides delegating to ``nn.Module.to``, this records the target device
        in ``self.device`` because the solvers receive it explicitly.

        Args
        ----
            device : str | torch.device
                Device to move the layer to, given positionally or as
                ``device=...``
        """
        device = next(
            (arg for arg in args if isinstance(arg, (str, torch.device))), None
        )
        if "device" in kwargs:
            device = kwargs["device"]
        if device:
            self.device = device
        # nn.Module.to already moves every parameter and buffer.
        return super().to(*args, **kwargs)

    @property
    def param_size(self):
        """
        Number of trainable scalars (parameters with ``requires_grad=True``).

        The value is cached on first access, so later changes to
        ``requires_grad`` flags are not reflected.
        """
        if not hasattr(self, "_param_size"):
            self._param_size = sum(
                p.numel() for p in self.parameters() if p.requires_grad
            )
        return self._param_size

    @property
    def x0(self):
        return self._x0

    @x0.setter
    def x0(self, x: torch.Tensor):
        # NOTE(review): the assigned value is discarded and `_x0` is reset to
        # None. This looks like a deliberate no-op placeholder for
        # ResQKANLayer -- confirm against that subclass before changing it.
        self._x0 = None

    def _compute_postacts(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the per-edge activations with the configured solver.

        The result is expanded along dim 1 to ``out_dim`` when the solver
        returns a different size there. Shared by ``forward`` and
        ``forward_no_sum`` so both agree on solver handling.
        """
        if self.solver == "qml":
            postacts = torch.zeros(x.shape[0], self.out_dim, self.in_dim).to(
                self.device
            )
            for j in range(self.out_dim):
                for i in range(self.in_dim):
                    # theta is laid out (out_dim, in_dim, ...): index [j, i].
                    postacts[:, j, i] = qml_solver(
                        x=x[:, i],
                        theta=self.theta[j, i],
                        reps=self.reps,
                        device=self.device,
                        qml_device=self.qml_device,
                    )
        elif self.solver == "exact":
            postacts = torch_exact_solver(
                x,
                self.theta,
                self.preacts_weight,
                self.preacts_bias,
                self.reps,
                device=self.device,
                ansatz=self.ansatz,
                group=self.group,
                preacts_trainable=self.preact_trainable,
                fast_measure=self.fast_measure,
            )
        elif callable(self.solver):
            postacts = self.solver(
                x,
                self.theta,
                self.preacts_weight,
                self.preacts_bias,
                self.reps,
                device=self.device,
                ansatz=self.ansatz,
            )
        else:
            raise NotImplementedError()
        if postacts.shape[1] != self.out_dim:
            postacts = postacts.expand(-1, self.out_dim, -1)
        return postacts

    def _edge_outputs(self, x: torch.Tensor) -> torch.Tensor:
        """Return the masked per-edge outputs, shape (batch, out_dim, in_dim)."""
        base_output = torch.einsum(
            "oi,bi->boi", self.base_weight, self.base_activation(x)
        )
        postacts = self._compute_postacts(x)
        return (
            (postacts + self.postact_bias) * self.postact_weights[None, :, :]
            + base_output
        ) * self.mask[None, :, :]

    def forward(self, x: torch.Tensor):
        """
        Apply the layer.

        Args
        ----
            x : torch.Tensor
                Input of shape (batch, in_dim)

        Returns
        -------
            torch.Tensor
                Edge outputs summed over the input axis, shape (batch, out_dim)
        """
        assert x.shape[1] == self.in_dim, "Invalid input dimension"
        if self.is_batchnorm:
            x = self.bn(x)
        return torch.sum(self._edge_outputs(x), dim=2)

    def reset_parameters(self):
        """Set every entry of ``theta`` to zero."""
        self.theta.data.zero_()

    @torch.no_grad()
    def forward_no_sum(self, x: torch.Tensor):
        """
        Return the per-edge outputs without summing over inputs.

        Runs under ``torch.no_grad``. Unlike ``forward``, batch normalization
        is NOT applied here (unchanged from the original behaviour).

        Returns
        -------
            torch.Tensor
                Shape (batch, out_dim, in_dim)
        """
        assert x.shape[1] == self.in_dim, "Invalid input dimension"
        return self._edge_outputs(x)

    def get_subset(self, in_id, out_id):
        """
        Get a smaller QKANLayer from a larger QKANLayer (used for pruning).

        Parameters are sliced as ``[out_id][:, in_id]``, i.e. this assumes the
        default ``group=(out_dim, in_dim)`` layout.

        Args
        ----
            in_id : list
                id of selected input neurons
            out_id : list
                id of selected output neurons

        Returns
        -------
            QKANLayer
                New QKANLayer with selected neurons
        """
        theta_sub = self.theta.data[out_id][:, in_id]
        spb = QKANLayer(
            in_dim=len(in_id),
            out_dim=len(out_id),
            reps=self.reps,
            device=self.device,
            solver=self.solver,
            qml_device=self.qml_device,
            ansatz=self.ansatz,
            # A custom theta shape must be forwarded, otherwise a custom
            # ansatz cannot be rebuilt. Assumes its two leading dims are
            # (out, in) -- TODO confirm for custom ansatz.
            theta_size=list(theta_sub.shape) if self.theta_size else None,
            preact_trainable=self.preact_trainable,
            preact_init=self.preact_init,
            postact_weight_trainable=self.postact_weight_trainable,
            postact_bias_trainable=self.postact_bias_trainable,
            base_activation=self.base_activation,
            ba_trainable=self.ba_trainable,
            is_batchnorm=self.is_batchnorm,
            fast_measure=self.fast_measure,
            seed=self.seed,
        )
        spb.theta.data = theta_sub
        spb.base_weight.data = self.base_weight.data[out_id][:, in_id]
        spb.preacts_weight.data = self.preacts_weight.data[out_id][:, in_id]
        spb.preacts_bias.data = self.preacts_bias.data[out_id][:, in_id]
        spb.postact_weights.data = self.postact_weights.data[out_id][:, in_id]
        spb.postact_bias.data = self.postact_bias.data[out_id][:, in_id]
        spb.mask.data = self.mask.data[out_id][:, in_id]
        if self.is_batchnorm:
            # Carry the learned normalisation of the kept input features over.
            spb.bn.weight.data = self.bn.weight.data[in_id]
            spb.bn.bias.data = self.bn.bias.data[in_id]
            spb.bn.running_mean.data = self.bn.running_mean.data[in_id]
            spb.bn.running_var.data = self.bn.running_var.data[in_id]
            spb.bn.num_batches_tracked.data = (
                self.bn.num_batches_tracked.data.clone()
            )
        return spb
class QKANModuleList(nn.ModuleList):
    """``nn.ModuleList`` whose item access is type-hinted for QKAN layers."""

    def __init__(self, modules=None):
        # nn.ModuleList.__getitem__ builds `self.__class__(modules)` for slice
        # indices, so the constructor must accept an optional iterable of
        # modules; with no argument it still creates an empty list.
        super().__init__(modules)

    # make type hint for getitem method
    def __getitem__(self, idx) -> Union[QKANLayer, nn.Linear, "QKANModuleList"]:
        return super().__getitem__(idx)
[docs]
class QKAN(nn.Module):
"""
Quantum-inspired Kolmogorov Arnold Network (QKAN) Class
A quantum-inspired neural network that uses DatA Re-Uploading ActivatioN (DARUAN)
as its learnable variational activation function.
References:
Quantum Variational Activation Functions Empower Kolmogorov-Arnold Networks: https://arxiv.org/abs/2509.14026
Attributes
----------
width : list[int]
List of width of each layer
reps : int
Repetitions of quantum layers
group : int
Group of neurons
device : Literal["cpu", "cuda"]
Device to use
solver : Literal["qml", "exact"]
Solver to use
qml_device : str
PennyLane device to use
layers : QKANModuleList
List of layers
is_map : bool
Whether to use map layer
is_batchnorm : bool
Whether to use batch normalization
reps : int
Repetitions of quantum layers
norm_out : int
Normalize output
postact_weight_trainable : bool
Whether postact weights are trainable
postact_bias_trainable : bool
Whether postact bias are trainable
preact_trainable : bool
Whether preact weights are trainable
base_activation : torch.nn.Module or lambda function
Base activation function
ba_trainable : bool
Whether base activation weights are trainable
fast_measure : bool
Enable to use fast measurement in exact solver. Which would be quantum-inspired method.
When False, the exact solver simulates the exact measurement process of quantum circuit.
save_act : bool
Whether to save activations
seed : int
Random seed
"""
def __init__(
    self,
    width: list[int],
    reps: int = 3,
    group: int = -1,
    is_map: bool = False,
    is_batchnorm: bool = False,
    hidden: int = 0,
    device="cpu",
    solver: Union[Literal["qml", "exact"], Callable] = "exact",
    qml_device: str = "default.qubit",
    ansatz: Union[str, Callable] = "pz_encoding",
    norm_out: int = 0,
    preact_trainable: bool = False,
    preact_init: bool = False,
    postact_weight_trainable: bool = False,
    postact_bias_trainable: bool = False,
    base_activation=nn.SiLU(),
    ba_trainable: bool = False,
    fast_measure: bool = True,
    save_act: bool = False,
    seed=0,
    **kwargs,
):
    """
    Initialize QKAN model

    Builds ``depth`` QKANLayers from consecutive entries of ``width``. With
    ``is_map=True`` the last transition (``width[-2] -> width[-1]``) is
    realised by a Linear -> SiLU -> Linear head instead of a QKANLayer.
    Extra keyword arguments are accepted and ignored.

    Args
    ----
        width : list[int]
            List of width of each layer
        reps : int
            Repetitions of quantum layers, default: 3
        group : int
            Group of neurons, default: -1
        is_map : bool
            Whether to use map layer, default: False
        is_batchnorm : bool
            Whether to add a batchnorm layer before QKANLayer, default: False
        hidden : int
            Number of hidden units in map layer, default: 0
        device :
            Device to use, default: "cpu"
        solver : Union[Literal["qml", "exact"], Callable]
            Solver to use, default: "exact"
        qml_device : str
            PennyLane device to use, default: "default.qubit"
        ansatz : str | Callable
            Ansatz to use, "pz_encoding", "px_encoding", "rpz_encoding" or
            custom, default: "pz_encoding"
        norm_out : int
            Normalize output, default: 0
        preact_trainable : bool
            Whether preact weights are trainable, default: False
        preact_init : bool
            Forwarded to every QKANLayer, default: False
        postact_weight_trainable : bool
            Whether postact weights are trainable, default: False
        postact_bias_trainable : bool
            Whether postact bias are trainable, default: False
        base_activation : torch.nn.Module | lambda function
            Base activation function, default: torch.nn.SiLU()
        ba_trainable : bool
            Whether base activation weights are trainable, default: False
        fast_measure : bool
            Passed to the exact solver. When False, the exact solver simulates
            the exact measurement process of quantum circuit.
        save_act : bool
            Whether to save activations, default: False
        seed : int
            Random seed, default: 0
    """
    super(QKAN, self).__init__()
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

    # One QKANLayer per transition; the map head consumes one extra width.
    n_tail = 2 if is_map else 1
    self.depth = len(width) - n_tail
    self.width = width
    self.reps = reps
    self.group = group
    self.device = device
    self.solver: Union[Literal["qml", "exact"], Callable] = solver
    self.ansatz = ansatz
    self.qml_device = qml_device
    self.norm_out = norm_out
    self.postact_weight_trainable = postact_weight_trainable
    self.postact_bias_trainable = postact_bias_trainable
    self.preact_trainable = preact_trainable
    self.preact_init = preact_init
    self.base_activation = base_activation
    self.ba_trainable = ba_trainable
    self.fast_measure = fast_measure
    self.save_act = save_act
    self.seed = seed
    self.is_batchnorm = is_batchnorm
    self.is_map = is_map
    self.hidden = hidden

    # Options shared by every QKANLayer in the stack.
    shared_layer_options = dict(
        reps=reps,
        group=group,
        device=self.device,
        solver=self.solver,
        qml_device=self.qml_device,
        ansatz=self.ansatz,
        preact_trainable=preact_trainable,
        preact_init=preact_init,
        postact_weight_trainable=postact_weight_trainable,
        postact_bias_trainable=postact_bias_trainable,
        base_activation=base_activation,
        ba_trainable=ba_trainable,
        is_batchnorm=is_batchnorm,
        fast_measure=fast_measure,
        seed=seed,
    )
    self.layers = QKANModuleList()
    for level in range(self.depth):
        qkan_layer = QKANLayer(
            in_dim=width[level],
            out_dim=width[level + 1],
            **shared_layer_options,
        )
        self.layers.append(qkan_layer)

    if is_map:
        map_head = (
            nn.Linear(width[-2], hidden, device=self.device),
            nn.SiLU(),
            nn.Linear(hidden, width[-1], device=self.device),
        )
        for module in map_head:
            self.layers.append(module)

    self.input_id: Optional[torch.Tensor] = None
[docs]
def to(self, *args, **kwargs):
"""
Move the model to the specified device.
Args
----
device : str | torch.device
Device to move the model to, default: "cpu"
"""
device = None
for arg in args:
if isinstance(arg, str) or isinstance(arg, torch.device):
device = arg
break
if "device" in kwargs:
device = kwargs["device"]
if device:
self.device = device
for layer in self.layers:
layer.to(*args, **kwargs)
return super(QKAN, self).to(*args, **kwargs)
@property
def param_size(self):
if hasattr(self, "_param_size"):
return self._param_size
count = 0
for layer in self.layers:
if not isinstance(layer, QKANLayer):
count += sum(p.numel() for p in layer.parameters())
continue
count += layer.param_size
self._param_size = count
return self._param_size
[docs]
def forward(self, x: torch.Tensor):
shape_size = len(x.shape)
if shape_size == 3:
B, C, T = x.shape
elif shape_size == 2:
B, T = x.shape
else:
raise NotImplementedError()
x = x.view(-1, T)
if self.input_id is not None:
x = x[:, self.input_id.long()]
if self.save_act:
self.cache_data = x
self.acts = [] # shape ([batch, n0], [batch, n1], ..., [batch, n_L])
self.subnode_actscale = []
self.dr_preacts = []
self.dr_postacts = []
self.acts_scale = []
self.acts_scale_dr = []
self.edge_actscale = []
self.acts.append(x.clone())
for layer in self.layers:
if self.save_act and isinstance(layer, QKANLayer):
self.subnode_actscale.append(torch.std(x, dim=0).detach())
preacts = (x.clone())[:, None, :].expand(B, layer.out_dim, layer.in_dim)
postacts = layer.forward_no_sum(x) # shape: (batch, out_dim, in_dim)
x = layer(x)
if self.save_act and isinstance(layer, QKANLayer):
input_range = torch.std(preacts, dim=0) + 0.1
output_range_dr = torch.std(
postacts, dim=0
) # for training, only penalize the dr part
output_range = torch.std(
postacts, dim=0
) # leave for symbolic (Not implemented yet)
# save edge_scale
self.edge_actscale.append(output_range)
self.acts_scale.append((output_range / input_range).detach())
self.acts_scale_dr.append(output_range_dr / input_range)
self.dr_preacts.append(preacts.detach())
self.dr_postacts.append(postacts.detach())
self.acts.append(x.detach())
if self.norm_out:
x = F.normalize(x, p=self.norm_out, dim=1)
U = x.shape[1]
if shape_size == 3:
x = x.view(B, C, U)
elif shape_size == 2:
assert x.shape == (B, U)
return x
[docs]
def initialize_from_another_model(self, another_model: "QKAN"):
    """
    Initialize from another model.
    Used for layer extension to refine the model.

    For each QKANLayer, ``theta`` is zeroed first and then the donor's
    repetitions (plus the donor's final theta slot) are copied in, together
    with its preact / postact / base weights. Linear layers copy weight and
    bias from the donor's Linear layers, addressed from the end of its
    layer list.

    Args
    ----
        another_model : QKAN
            Another model to initialize from

    Returns
    -------
        QKAN
            ``self``
    """
    assert all(x == y for x, y in zip(self.width, another_model.width)), (
        "Cannot initialize from another model with different width"
    )
    # Starts so that `cursor - 1` is -3 for the first Linear and -1 for the
    # second one.
    cursor = -2
    for position, target in enumerate(self.layers):
        if isinstance(target, QKANLayer):
            target.reset_parameters()
            donor = another_model.layers[position]
            donor_reps = donor.reps
            for rep in range(donor_reps):
                target.theta.data[:, :, rep, :].copy_(donor.theta.data[:, :, rep, :])
                target.preacts_weight.data[:, :, rep].copy_(
                    donor.preacts_weight.data[:, :, rep]
                )
                target.preacts_bias.data[:, :, rep].copy_(
                    donor.preacts_bias.data[:, :, rep]
                )
            # theta has reps + 1 slots; copy the donor's last one too.
            target.theta.data[:, :, donor_reps, :].copy_(
                donor.theta.data[:, :, donor_reps, :]
            )
            target.postact_weights.data.copy_(donor.postact_weights.data)
            target.postact_bias.data.copy_(donor.postact_bias.data)
            target.base_weight.data.copy_(donor.base_weight.data)
        if isinstance(target, nn.Linear):
            linear_donor = another_model.layers[cursor - 1]
            target.weight.data.copy_(linear_donor.weight.data)
            target.bias.data.copy_(linear_donor.bias.data)
            cursor += 2
    return self
def _reg(
self,
reg_metric: str,
lamb_l1: float,
lamb_entropy: float,
lamb_coef: float,
lamb_coefdiff: float,
):
"""
Get regularization.
Adapted from "pykan".
Args
----
reg_metric : the regularization metric
'edge_forward_dr_n', 'edge_forward_dr_u', 'edge_forward_sum', 'edge_backward', 'node_backward'
lamb_l1 : float
l1 penalty strength
lamb_entropy : float
entropy penalty strength
lamb_coef : float
coefficient penalty strength
lamb_coefdiff : float
coefficient smoothness strength
Returns
-------
torch.Tensor
"""
if reg_metric == "edge_forward_dr_n":
acts_scale = self.acts_scale_dr
elif reg_metric == "edge_forward_sum":
acts_scale = self.acts_scale
elif reg_metric == "edge_forward_dr_u":
acts_scale = self.edge_actscale
elif reg_metric == "edge_backward":
acts_scale = self.edge_scores
elif reg_metric == "node_backward":
acts_scale = self.node_attribute_scores
else:
raise RuntimeError(f"reg_metric = {reg_metric} not recognized!")
reg_: torch.Tensor = torch.tensor(0.0, device=self.device)
for i in range(len(acts_scale)):
vec = acts_scale[i]
l1 = torch.sum(vec)
p_row = vec / (torch.sum(vec, dim=1, keepdim=True) + 1)
p_col = vec / (torch.sum(vec, dim=0, keepdim=True) + 1)
entropy_row = -torch.mean(
torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1)
)
entropy_col = -torch.mean(
torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0)
)
reg_ += lamb_l1 * l1 + lamb_entropy * (
entropy_row + entropy_col
) # both l1 and entropy
# regularize coefficient to encourage activation to be zero
for layer in self.layers:
if not isinstance(layer, QKANLayer):
continue
coeff_l1 = torch.sum(torch.mean(torch.abs(layer.postact_weights), dim=1))
coeff_diff_l1 = torch.sum(
torch.mean(torch.abs(torch.diff(layer.postact_weights)), dim=1)
)
reg_ += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1
return reg_
[docs]
def get_reg(
self,
reg_metric: str,
lamb_l1: float,
lamb_entropy: float,
lamb_coef: float,
lamb_coefdiff: float,
):
"""
Get regularization from the model.
Adapted from "pykan".
Args
----
reg_metric : str
Regularization metric.
'edge_forward_dr_n', 'edge_forward_dr_u', 'edge_forward_sum', 'edge_backward', 'node_backward'
lamb_l1 : float
L1 Regularization parameter
lamb_entropy : float
Entropy Regularization parameter
lamb_coef : float
Coefficient Regularization parameter
lamb_coefdiff : float
Coefficient Smoothness Regularization parameter
Returns
-------
torch.Tensor
"""
return self._reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
[docs]
def attribute(self, l=None, i=None, out_score=None, plot=True):
"""
Get attribution scores
Adapted from "pykan".
Args
----
l : None | int
layer index
i : None | int
neuron index
out_score : None | torch.Tensor
specify output scores
plot : bool
when plot = True, display the bar show
Returns
-------
torch.Tensor
attribution scores
"""
if not self.save_act:
warnings.warn(
"Activations are not saved, cannot get attribution scores",
RuntimeWarning,
)
return None
if l is not None:
self.attribute()
out_score = self.node_scores[l]
node_scores = []
subnode_scores = []
edge_scores = []
l_query = l
if l is None:
l_end = self.depth
else:
l_end = l
# back propagate from the queried layer
out_dim = self.width[l_end]
if out_score is None:
node_score = torch.eye(out_dim).requires_grad_(True)
else:
node_score = torch.diag(out_score).requires_grad_(True)
node_scores.append(node_score)
for l in range(l_end, 0, -1):
subnode_score = node_score[:, : self.width[l]]
subnode_scores.append(subnode_score)
# subnode to edge
edge_score = torch.einsum(
"oi,ko,i->koi",
self.edge_actscale[l - 1],
subnode_score.to(self.device),
1 / (self.subnode_actscale[l - 1] + 1e-4),
)
edge_scores.append(edge_score)
# edge to node
node_score = torch.sum(edge_score, dim=1)
node_scores.append(node_score)
self.node_scores_all = list(reversed(node_scores))
self.edge_scores_all = list(reversed(edge_scores))
self.subnode_scores_all = list(reversed(subnode_scores))
self.node_scores = [torch.mean(l, dim=0) for l in self.node_scores_all]
self.edge_scores = [torch.mean(l, dim=0) for l in self.edge_scores_all]
self.subnode_scores = [torch.mean(l, dim=0) for l in self.subnode_scores_all]
# return: (out_dim, in_dim)
if l_query is not None:
if i is None:
return self.node_scores_all[0]
else:
# plot
if plot:
in_dim = self.width[0]
plt.figure(figsize=(1 * in_dim, 3))
plt.bar(
range(in_dim), self.node_scores_all[0][i].cpu().detach().numpy()
)
plt.xticks(range(in_dim))
return self.node_scores_all[0][i]
[docs]
def node_attribute(self):
"""
Get node attribution scores.
Adapted from "pykan".
"""
self.node_attribute_scores = []
for l in range(1, self.depth + 1):
node_attr = self.attribute(l)
self.node_attribute_scores.append(node_attr)
[docs]
def train_(
    self,
    dataset,
    optimizer=None,
    closure=None,
    scheduler=None,
    steps: int = 10,
    log: int = 1,
    loss_fn=None,
    batch=-1,
    lamb=0.0,
    lamb_l1=1.0,
    lamb_entropy=2.0,
    lamb_coef=0.0,
    lamb_coefdiff=0.0,
    reg_metric="edge_forward_dr_n",
    verbose=False,
):
    """
    Train the model

    Args
    ----
        dataset : dict
            Dictionary containing train_input, train_label, test_input, test_label
        optimizer : torch.optim.Optimizer | None
            Optimizer to use; Adam with lr=5e-4 when None
        closure : Callable | None
            Closure passed to ``optimizer.step``. For LBFGS a default closure
            is built when None.
        scheduler : torch.optim.lr_scheduler | None
            Scheduler to use. ``ReduceLROnPlateau`` is stepped with the test
            loss, any other scheduler is stepped without arguments.
        steps : int
            Number of steps, default: 10
        log : int
            Logging frequency, default: 1
        loss_fn : torch.nn.Module | Callable | None
            Loss function to use; ``torch.nn.MSELoss()`` when None
        batch : int
            batch size, if -1 then full., default: -1
        lamb : float
            Overall regularization strength. If 0, no regularization is added.
            Forced to 0 (with a warning) when activations are not saved.
        lamb_l1 : float
            L1 Regularization parameter
        lamb_entropy : float
            Entropy Regularization parameter
        lamb_coef : float
            Coefficient Regularization parameter
        lamb_coefdiff : float
            Coefficient Smoothness Regularization parameter
        reg_metric : str
            Regularization metric.
            'edge_forward_dr_n', 'edge_forward_dr_u', 'edge_forward_sum', 'edge_backward', 'node_backward'
        verbose : bool
            Verbose mode, default: False (currently unused)

    Returns
    -------
        dict
            Dictionary containing per-step ``train_loss``, ``test_loss`` and ``reg``
    """
    if lamb > 0.0 and not self.save_act:
        lamb = 0.0
        warnings.warn(
            "Regularization is not supported without saving activations",
            RuntimeWarning,
        )

    if loss_fn is None:
        loss_fn = torch.nn.MSELoss()
    loss_fn_eval = loss_fn
    if optimizer is None:
        optimizer = torch.optim.Adam(self.parameters(), lr=5e-4)

    results: dict = {"train_loss": [], "test_loss": [], "reg": []}

    n_train = dataset["train_input"].shape[0]
    n_test = dataset["test_input"].shape[0]
    if batch == -1 or batch > n_train:
        batch_size = n_train
        batch_size_test = n_test
    else:
        batch_size = batch
        # Sampling without replacement cannot draw more than the test set has.
        batch_size_test = min(batch, n_test)

    is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
    steps_on_metric = isinstance(
        scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
    )

    # Bound up front so logging still works when a user-supplied closure
    # (which does not update them) drives an LBFGS step.
    train_loss = reg_ = torch.tensor(0.0, device=self.device)
    train_id = None

    def _objective():
        """Forward pass on the current batch; returns loss + lamb * reg."""
        nonlocal train_loss, reg_
        pred = self.forward(dataset["train_input"][train_id].to(self.device))
        train_loss = loss_fn(pred, dataset["train_label"][train_id].to(self.device))
        if self.save_act:
            if reg_metric == "edge_backward":
                self.attribute()
            if reg_metric == "node_backward":
                self.node_attribute()
            reg_ = self.get_reg(
                reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff
            )
        else:
            reg_ = torch.tensor(0.0, device=self.device)
        return train_loss + lamb * reg_

    def _closure():
        optimizer.zero_grad()
        objective = _objective()
        objective.backward()
        return objective

    if closure is None and is_lbfgs:
        closure = _closure

    pbar = tqdm(range(steps), ncols=100)
    for step in pbar:
        self.train()
        train_id = np.random.choice(n_train, batch_size, replace=False)
        test_id = np.random.choice(n_test, batch_size_test, replace=False)

        if is_lbfgs:
            optimizer.step(closure)
        else:
            optimizer.zero_grad()
            loss = _objective()
            loss.backward()
            optimizer.step(closure)

        self.eval()
        # Evaluation needs no autograd graph.
        with torch.no_grad():
            test_loss = loss_fn_eval(
                self.forward(dataset["test_input"][test_id].to(self.device)),
                dataset["test_label"][test_id].to(self.device),
            )

        if scheduler is not None:
            if steps_on_metric:
                scheduler.step(test_loss)
            else:
                scheduler.step()

        if step % log == 0:
            pbar.set_postfix(
                {
                    "train loss": train_loss.cpu().detach().numpy(),
                    "test loss": test_loss.cpu().detach().numpy(),
                }
            )
        results["train_loss"].append(train_loss.cpu().detach().numpy())
        results["test_loss"].append(test_loss.cpu().detach().numpy())
        results["reg"].append(reg_.cpu().detach().numpy())
    return results
[docs]
def plot(
    self,
    x0=None,
    sampling=1000,
    from_acts=False,
    scale=0.5,
    beta=3,
    metric="forward_n",
    mask=False,
    in_vars=None,
    out_vars=None,
    title=None,
):
    """
    Plot the model.
    Adapted from "pykan".

    Every edge activation is first rendered to ``./figures/dr_{l}_{i}_{j}.png``
    and then embedded into a diagram of the network. Returns None; map and
    batchnorm models are not supported (a warning is issued and nothing is
    drawn).

    Arguments
    ---------
        x0 : torch.Tensor | None
            Input tensor to plot, if None, plot from saved activations
        sampling : int
            Number of sample points per input of the deeper layers
        from_acts : bool
            Plot from saved activations
        scale : float
            Scale of the plot
        beta : float
            Edge transparency is ``tanh(beta * score)``
        metric : str
            Metric to use. 'forward_n', 'forward_u', 'backward'
        mask : bool
            Multiply edge transparency by the node masks stored in
            ``self.mask``
        in_vars : list | None
            Labels for the input nodes
        out_vars : list | None
            Labels for the output nodes
        title : str | None
            Title of the plot
    """
    if self.is_map:
        warnings.warn("Not supported for map layer", RuntimeWarning)
        return None
    if self.is_batchnorm:
        warnings.warn("Not supported for batchnorm layer", RuntimeWarning)
        return None
    if x0 is None and not from_acts:
        warnings.warn(
            "x0 is not provided, try plot from saved activations.", RuntimeWarning
        )
        from_acts = True
    # getattr: `acts` only exists after a forward pass with save_act=True.
    if from_acts and not getattr(self, "acts", None):
        warnings.warn(
            "Activations are not saved, cannot plot from activations",
            RuntimeWarning,
        )
        return None
    if mask and not hasattr(self, "mask"):
        warnings.warn(
            "Make sure to run model.prune_node() first to compute mask. Continue without mask.",
            RuntimeWarning,
        )
        mask = False

    os.makedirs("./figures", exist_ok=True)

    if metric == "backward":
        self.attribute()

    save_act = self.save_act
    self.save_act = False
    try:
        self.eval()

        # ---- render one small scatter image per edge ----
        ynew = None
        for idx, qkan_layer in enumerate(self.layers):
            assert isinstance(qkan_layer, QKANLayer)
            if from_acts:
                x = self.acts[idx]
            elif idx == 0:
                x = x0
            else:
                # Sample each input evenly over the range the previous layer
                # produced.
                ymin = torch.min(ynew.cpu().detach(), dim=0).values
                ymax = torch.max(ynew.cpu().detach(), dim=0).values
                x = torch.stack(
                    [
                        torch.linspace(
                            ymin[i],
                            ymax[i],
                            steps=sampling,
                            device=self.device,
                        )
                        for i in range(qkan_layer.in_dim)
                    ]
                ).permute(1, 0)  # x.shape = (sampling, in_dim)
            # y.shape = (sampling, in_dim, out_dim)
            y = qkan_layer.forward_no_sum(x).transpose(1, 2)
            x_np = x.detach().cpu().numpy()
            y_np = y.detach().cpu().numpy()
            for i in range(self.width[idx]):
                for j in range(self.width[idx + 1]):
                    fig, ax = plt.subplots(figsize=(2, 2))
                    plt.xticks([])
                    plt.yticks([])
                    plt.gca().patch.set_edgecolor("black")
                    plt.gca().patch.set_linewidth(1.5)
                    plt.scatter(x_np[:, i], y_np[:, i, j], color="black", s=40)
                    plt.gca().spines[:].set_color("black")
                    plt.savefig(
                        f"./figures/dr_{idx}_{i}_{j}.png",
                        bbox_inches="tight",
                        dpi=400,
                    )
                    plt.close()
            with torch.no_grad():
                ynew = qkan_layer.forward(x)

        # ---- edge transparency from the chosen metric ----
        def score2alpha(score):
            return np.tanh(beta * score)

        metric_sources = {
            "forward_n": "acts_scale",
            "forward_u": "edge_actscale",
            "backward": "edge_scores",
        }
        alpha = []
        if save_act and metric is not None:
            if metric in metric_sources:
                scores = getattr(self, metric_sources[metric], None) or []
                alpha = [score2alpha(s.cpu().detach().numpy()) for s in scores]
            else:
                warnings.warn(
                    f"metric = '{metric}' cannot be recognized", RuntimeWarning
                )
        if not alpha:
            # No usable scores: draw every edge fully opaque.
            alpha = [
                torch.ones(layer.out_dim, layer.in_dim).detach().numpy()
                for layer in self.layers
            ]

        def edge_alpha(l, i, j):
            """Transparency of edge i -> j in layer l, masked if requested."""
            value = alpha[l][j][i]
            if mask:
                value = value * self.mask[l][i].item() * self.mask[l + 1][j].item()
            return value

        # ---- draw skeleton ----
        width = np.array(self.width)
        A = 1
        y0 = 0.4
        neuron_depth = len(width)
        min_spacing = A / np.maximum(np.max(width), 5)
        max_num_weights = np.max(width[:-1] * width[1:])
        y1 = 0.4 / np.maximum(max_num_weights, 3)

        fig, ax = plt.subplots(
            figsize=(10 * scale, 10 * scale * (neuron_depth - 1) * y0)
        )

        # plot scatters and lines
        for l in range(neuron_depth):
            n = width[l]
            for i in range(n):
                node_x = 1 / (2 * n) + i / n
                plt.scatter(
                    node_x,
                    l * y0,
                    s=min_spacing**2 * 10000 * scale**2,
                    color="black",
                )
                if l < neuron_depth - 1:
                    # plot connections
                    n_next = width[l + 1]
                    N = n * n_next
                    for j in range(n_next):
                        id_ = i * n_next + j
                        box_x = 1 / (2 * N) + id_ / N
                        a = edge_alpha(l, i, j)
                        plt.plot(
                            [node_x, box_x],
                            [l * y0, (l + 1 / 2) * y0 - y1],
                            color="black",
                            lw=2 * scale,
                            alpha=a,
                        )
                        plt.plot(
                            [box_x, 1 / (2 * n_next) + j / n_next],
                            [(l + 1 / 2) * y0 + y1, (l + 1) * y0],
                            color="black",
                            lw=2 * scale,
                            alpha=a,
                        )
        plt.xlim(0, 1)
        plt.ylim(-0.1 * y0, (neuron_depth - 1 + 0.1) * y0)

        # -- Transformation functions
        DC_to_FC = ax.transData.transform
        FC_to_NFC = fig.transFigure.inverted().transform

        # -- Take data coordinates and transform them to normalized figure coordinates
        def DC_to_NFC(point):
            return FC_to_NFC(DC_to_FC(point))

        plt.axis("off")

        # ---- embed the per-edge images ----
        for l in range(neuron_depth - 1):
            n = width[l]
            n_next = width[l + 1]
            N = n * n_next
            for i in range(n):
                for j in range(n_next):
                    id_ = i * n_next + j
                    im = plt.imread(f"./figures/dr_{l}_{i}_{j}.png")
                    center = 1 / (2 * N) + id_ / N
                    left = DC_to_NFC([center - y1, 0])[0]
                    right = DC_to_NFC([center + y1, 0])[0]
                    bottom = DC_to_NFC([0, (l + 1 / 2) * y0 - y1])[1]
                    up = DC_to_NFC([0, (l + 1 / 2) * y0 + y1])[1]
                    newax = fig.add_axes([left, bottom, right - left, up - bottom])
                    newax.imshow(im, alpha=edge_alpha(l, i, j))
                    newax.axis("off")

        # ---- labels ----
        main_ax = plt.gcf().get_axes()[0]
        text_style = dict(
            fontsize=40 * scale,
            horizontalalignment="center",
            verticalalignment="center",
        )
        if in_vars is not None:
            n = self.width[0]
            for i in range(n):
                main_ax.text(1 / (2 * n) + i / n, -0.1, in_vars[i], **text_style)
        if out_vars is not None:
            n = self.width[-1]
            for i in range(n):
                main_ax.text(
                    1 / (2 * n) + i / n,
                    y0 * (len(self.width) - 1) + 0.1,
                    out_vars[i],
                    **text_style,
                )
        if title is not None:
            main_ax.text(0.5, y0 * (len(self.width) - 1) + 0.2, title, **text_style)
    finally:
        # Restore the caller's setting even if plotting fails.
        self.save_act = save_act
[docs]
def prune_node(
    self,
    threshold: float = 1e-2,
    mode: str = "auto",
    active_neurons_id: Optional[list] = None,
):
    """
    Prune hidden nodes and return a new, smaller network.

    Adapted from "pykan".

    Besides building the pruned copy, this method also updates the current
    model: ``self.mask`` is replaced by the per-node keep masks (used for
    plotting) and ``self.remove_node`` is called for every dropped node.

    Args
    ----
        threshold : float
            Used in "auto" mode. A hidden node is kept only when both the
            maximum of ``acts_scale[i]`` along dim 1 and the maximum of
            ``acts_scale[i + 1]`` along dim 0 exceed this value at that node;
            otherwise it is considered dead and removed.
        mode : str
            "auto" or "manual". With "auto", nodes are pruned automatically
            using ``threshold``. With "manual", ``active_neurons_id`` should be
            passed in; if it is missing, a warning is issued and "auto" is used.
        active_neurons_id : Optional[list]
            Only read in "manual" mode. ``active_neurons_id[i + 1]`` indexes the
            nodes to keep in node layer ``i + 1``; entry 0 is not read.

    Returns
    -------
        QKAN or None
            Pruned network, or None (after a RuntimeWarning) when the model has
            no ``acts`` attribute.

    Raises
    ------
        ValueError
            If ``mode`` is neither "auto" nor "manual".
    """
    if not hasattr(self, "acts"):
        warnings.warn("No activations, cannot prune nodes", RuntimeWarning)
        return None
    # Fail early and clearly; an unknown mode would otherwise leave
    # `overall_important` unbound inside the loop below.
    if mode not in ("auto", "manual"):
        raise ValueError(f"mode must be 'auto' or 'manual', got {mode!r}")
    if mode == "manual" and active_neurons_id is None:
        warnings.warn(
            "active_neurons_id is not provided. Continue with auto mode.",
            RuntimeWarning,
        )
        mode = "auto"

    # Input nodes are always kept.
    mask = [torch.ones(self.width[0])]
    active_neurons = [list(range(self.width[0]))]
    for i in range(len(self.acts_scale) - 1):
        if mode == "auto":
            in_important = torch.max(self.acts_scale[i], dim=1)[0] > threshold
            out_important = torch.max(self.acts_scale[i + 1], dim=0)[0] > threshold
            overall_important = in_important * out_important
        else:  # mode == "manual"
            assert active_neurons_id is not None
            overall_important = torch.zeros(self.width[i + 1], dtype=torch.bool)
            overall_important[active_neurons_id[i + 1]] = True
        mask.append(overall_important.float())
        active_neurons.append(torch.where(overall_important)[0].tolist())
    # Output nodes are always kept.
    active_neurons.append(list(range(self.width[-1])))
    mask.append(torch.ones(self.width[-1]))

    self.mask = mask  # for plot

    for layer_idx in range(len(self.acts_scale) - 1):
        kept = set(active_neurons[layer_idx + 1])  # O(1) membership tests
        for node in range(self.width[layer_idx + 1]):
            if node not in kept:
                self.remove_node(layer_idx + 1, node)

    model2 = QKAN(
        deepcopy(self.width),
        reps=self.reps,
        is_map=self.is_map,
        is_batchnorm=self.is_batchnorm,
        hidden=self.hidden,
        device=self.device,
        solver=self.solver,
        qml_device=self.qml_device,
        ansatz=self.ansatz,
        norm_out=self.norm_out,
        preact_trainable=self.preact_trainable,
        postact_weight_trainable=self.postact_weight_trainable,
        postact_bias_trainable=self.postact_bias_trainable,
        base_activation=self.base_activation,
        ba_trainable=self.ba_trainable,
        save_act=self.save_act,
        seed=self.seed,
    )
    # Copy the trained parameters first, then shrink every QKANLayer to the
    # surviving input/output nodes.
    model2.load_state_dict(self.state_dict())
    for i, layer in enumerate(model2.layers):
        if not isinstance(layer, QKANLayer):
            continue
        model2.layers[i] = layer.get_subset(
            active_neurons[i], active_neurons[i + 1]
        )
        model2.width[i] = len(active_neurons[i])
    model2.cache_data = self.cache_data
    return model2
[docs]
def prune_edge(self, threshold: float = 3e-2):
    """
    Prune edges in place.

    Adapted from "pykan".

    Args
    ----
        threshold : float
            An edge whose entry in ``self.edge_scores`` is not above this value
            is considered dead: its mask entry is set to zero. Entries that are
            already zero in the mask stay zero.

    Returns
    -------
        None
            The layer masks are modified in place. A RuntimeWarning is issued
            (and nothing is changed) when the model has no ``acts`` attribute.
    """
    if not hasattr(self, "acts"):
        warnings.warn("No activations, cannot prune edges", RuntimeWarning)
        return None
    n_layers = len(self.width) - 1
    for idx in range(n_layers):
        layer = self.layers[idx]
        previous = layer.mask.data
        keep = self.edge_scores[idx] > threshold
        # Multiplying by the previous mask keeps earlier prunings in effect.
        layer.mask.data = (keep * previous).float()
[docs]
def prune(self, node_th: float = 1e-2, edge_th: float = 3e-2):
    """
    Prune both nodes and edges.

    Adapted from "pykan".

    Args
    ----
        node_th : float
            Threshold forwarded to ``prune_node``; nodes scoring below it are removed.
        edge_th : float
            Threshold forwarded to ``prune_edge``; edges scoring below it are masked out.

    Returns
    -------
        QKAN or None
            The pruned network returned by ``prune_node`` after its edges have
            been pruned, or None (after a RuntimeWarning) when the model has no
            ``acts`` attribute.
    """
    if not hasattr(self, "acts"):
        warnings.warn("No activations, cannot prune.", RuntimeWarning)
        return None
    pruned = self.prune_node(node_th)
    # Run the cached data through the pruned network and call attribute()
    # before edge pruning (presumably to refresh the scores prune_edge reads).
    pruned.forward(pruned.cache_data)
    pruned.attribute()
    pruned.prune_edge(edge_th)
    return pruned
[docs]
def remove_edge(self, layer_idx, in_idx, out_idx):
    """
    Remove activation phi(layer_idx, in_idx, out_idx) by setting its mask entry to zero.

    Args
    ----
        layer_idx : int
            Layer index
        in_idx : int
            Input node index
        out_idx : int
            Output node index

    Returns
    -------
        None
            The mask of ``self.layers[layer_idx]`` is modified in place.
            Entries of ``self.layers`` that are not a ``QKANLayer`` are left
            untouched.
    """
    layer = self.layers[layer_idx]
    # The layers are QKANLayer objects (the same filter prune_node applies);
    # testing against QKAN here made this method a silent no-op.
    if not isinstance(layer, QKANLayer):
        return
    # Write outside autograd so the in-place update also works if the mask
    # happens to require grad.
    with torch.no_grad():
        layer.mask[out_idx, in_idx] = 0.0
[docs]
def remove_node(self, layer_idx, in_idx, mode="all"):
    """
    Remove neuron (layer_idx, in_idx) by zeroing the masks of its activation functions.

    Args
    ----
        layer_idx : int
            Layer index
        in_idx : int
            Input node index
        mode : str
            Mode to remove. "down" zeroes the incoming edges (row ``in_idx`` of
            the mask of ``self.layers[layer_idx - 1]``), "up" zeroes the outgoing
            edges (column ``in_idx`` of the mask of ``self.layers[layer_idx]``),
            any other value does both. Default: "all"

    Returns
    -------
        None
            Masks are modified in place. Entries of ``self.layers`` that are
            not a ``QKANLayer`` are left untouched, and a direction with no
            layer on that side (incoming for layer 0, outgoing past the last
            layer) is skipped.
    """
    if mode == "down":
        # Nodes of layer 0 have no incoming layer; without this guard the
        # index -1 would wrap around to the last layer.
        if layer_idx == 0:
            return
        layer = self.layers[layer_idx - 1]
        # Layers are QKANLayer objects; testing against QKAN made this a no-op.
        if not isinstance(layer, QKANLayer):
            return
        with torch.no_grad():
            layer.mask[in_idx, :] = 0.0
    elif mode == "up":
        # Nodes after the last layer have no outgoing layer.
        if layer_idx >= len(self.layers):
            return
        layer = self.layers[layer_idx]
        if not isinstance(layer, QKANLayer):
            return
        with torch.no_grad():
            layer.mask[:, in_idx] = 0.0
    else:
        self.remove_node(layer_idx, in_idx, mode="up")
        self.remove_node(layer_idx, in_idx, mode="down")
[docs]
@staticmethod
def clear_ckpts(folder="./model_ckpt"):
"""
Clear all checkpoints.
Args
----
folder : str
Folder containing checkpoints, default: "./model_ckpt"
"""
if os.path.exists(folder):
files = glob(folder + "/*")
for f in files:
os.remove(f)
else:
os.makedirs(folder)
[docs]
def save_ckpt(self, name, folder="./model_ckpt"):
    """
    Save the current model as a checkpoint.

    The model's ``state_dict()`` is written with ``torch.save`` to
    ``folder + "/" + name``; ``folder`` is created first if it does not exist.
    The destination path is printed afterwards.

    Args
    ----
        name : str
            Name of the checkpoint
        folder : str
            Folder to save the checkpoint, default: "./model_ckpt"
    """
    if not os.path.exists(folder):
        os.makedirs(folder)
    target = folder + "/" + name
    torch.save(self.state_dict(), target)
    print("save this model to", target)
[docs]
def load_ckpt(self, name, folder="./model_ckpt"):
    """
    Load a checkpoint into the current model.

    The file at ``folder + "/" + name`` is read with ``torch.load`` and the
    result is passed to ``self.load_state_dict``.

    Args
    ----
        name : str
            Name of the checkpoint
        folder : str
            Folder containing the checkpoint, default: "./model_ckpt"
    """
    ckpt_path = folder + "/" + name
    # NOTE(review): torch.load deserializes with pickle; only load checkpoint
    # files that come from a trusted source.
    state = torch.load(ckpt_path)
    self.load_state_dict(state)