# Copyright (c) 2026, Jiun-Cheng Jiang. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
QKAN solver for real quantum device execution via Qiskit Runtime.
"""
import math
from typing import Optional
import torch
try:
from qiskit import QuantumCircuit # type: ignore
from qiskit.quantum_info import SparsePauliOp # type: ignore
from qiskit.transpiler.preset_passmanagers import ( # type: ignore
generate_preset_pass_manager,
)
_QISKIT_AVAILABLE = True
except ImportError:
_QISKIT_AVAILABLE = False
try:
from qiskit_ibm_runtime import EstimatorV2 as Estimator # type: ignore
_QISKIT_RUNTIME_AVAILABLE = True
except ImportError:
_QISKIT_RUNTIME_AVAILABLE = False
try:
from qiskit.primitives import StatevectorEstimator # type: ignore
_SV_ESTIMATOR_AVAILABLE = True
except ImportError:
_SV_ESTIMATOR_AVAILABLE = False
try:
from qiskit_aer import AerSimulator # type: ignore
_AER_AVAILABLE = True
except ImportError:
_AER_AVAILABLE = False
from ._mitigation import _apply_mitigation
# ---------------------------------------------------------------------------
# Qiskit circuit builders
# ---------------------------------------------------------------------------
def _fold_qiskit_circuit(qc: "QuantumCircuit", scale_factor: int) -> "QuantumCircuit":
"""Apply gate folding to a Qiskit circuit for ZNE.
Produces U . (U_dag . U)^((scale_factor-1)/2) which has the same unitary
as U but with scale_factor x the gate count (and thus noise).
"""
if scale_factor <= 1:
return qc
folded = qc.copy()
for _ in range((scale_factor - 1) // 2):
folded = folded.compose(qc.inverse()).compose(qc)
return folded
def _build_qiskit_pz_circuit(
x_val: float,
theta_vals: list[float],
reps: int,
encoded_x_vals: Optional[list[float]] = None,
) -> "QuantumCircuit":
"""
Build pz_encoding circuit: H -> [RZ(θ₀) RY(θ₁) RZ(x)]^reps -> RZ(θ_f₀) RY(θ_f₁)
theta_vals layout: [θ₀₀, θ₀₁, θ₁₀, θ₁₁, ..., θ_f₀, θ_f₁] (2 params per layer + 2 final)
"""
qc = QuantumCircuit(1)
qc.h(0)
for l in range(reps):
qc.rz(theta_vals[2 * l], 0)
qc.ry(theta_vals[2 * l + 1], 0)
enc = encoded_x_vals[l] if encoded_x_vals is not None else x_val
qc.rz(enc, 0)
qc.rz(theta_vals[2 * reps], 0)
qc.ry(theta_vals[2 * reps + 1], 0)
return qc
def _build_qiskit_rpz_circuit(
encoded_x_vals: list[float],
theta_vals: list[float],
reps: int,
) -> "QuantumCircuit":
"""
Build rpz_encoding circuit: H -> [RY(θ) RZ(encoded_x)]^reps -> RY(θ_final)
theta_vals layout: [θ₀, θ₁, ..., θ_final] (1 param per layer + 1 final)
"""
qc = QuantumCircuit(1)
qc.h(0)
for l in range(reps):
qc.ry(theta_vals[l], 0)
qc.rz(encoded_x_vals[l], 0)
qc.ry(theta_vals[reps], 0)
return qc
def _build_qiskit_real_circuit(
x_val: float,
theta_vals: list[float],
reps: int,
encoded_x_vals: Optional[list[float]] = None,
) -> "QuantumCircuit":
"""
Build real ansatz circuit: H -> [X RY(θ) Z RY(x)]^reps
theta_vals layout: [θ₀, θ₁, ...] (1 param per layer, no final gate)
"""
qc = QuantumCircuit(1)
qc.h(0)
for l in range(reps):
qc.x(0)
qc.ry(theta_vals[l], 0)
qc.z(0)
enc = encoded_x_vals[l] if encoded_x_vals is not None else x_val
qc.ry(enc, 0)
return qc
# ---------------------------------------------------------------------------
# Parallel multi-qubit packing (Qiskit)
# ---------------------------------------------------------------------------
def _build_qiskit_parallel_circuit(
single_circuits: list["QuantumCircuit"],
) -> "QuantumCircuit":
"""
Pack N independent single-qubit circuits into one N-qubit circuit.
Each single-qubit circuit is applied to a separate qubit, enabling
parallel execution on a multi-qubit QPU.
"""
n = len(single_circuits)
qc = QuantumCircuit(n)
for qubit_idx, sc in enumerate(single_circuits):
for instruction in sc.data:
gate = instruction.operation
params = gate.params
name = gate.name
if name == "h":
qc.h(qubit_idx)
elif name == "x":
qc.x(qubit_idx)
elif name == "z":
qc.z(qubit_idx)
elif name == "rx":
qc.rx(params[0], qubit_idx)
elif name == "ry":
qc.ry(params[0], qubit_idx)
elif name == "rz":
qc.rz(params[0], qubit_idx)
else:
raise ValueError(f"Unsupported gate '{name}' in parallel packing")
return qc
def _make_parallel_observables(n_qubits: int) -> list["SparsePauliOp"]:
"""
Create Z observables for each qubit in an N-qubit circuit.
Returns a list of N SparsePauliOp, each measuring Z on one qubit.
Qiskit uses little-endian ordering: qubit 0 is the rightmost character.
E.g. for 3 qubits: [IIZ, IZI, ZII] for qubits 0, 1, 2 respectively.
"""
observables = []
for k in range(n_qubits):
# Qiskit little-endian: qubit k is at string position (n-1-k) from the left
pauli_str = "I" * (n_qubits - 1 - k) + "Z" + "I" * k
observables.append(SparsePauliOp.from_list([(pauli_str, 1.0)]))
return observables
# ---------------------------------------------------------------------------
# Qiskit solver
# ---------------------------------------------------------------------------
class _QiskitParamShift(torch.autograd.Function):
"""Autograd function using parameter-shift rule for Qiskit circuits."""
@staticmethod
def forward(ctx, x, theta, preacts_w, preacts_b, reps, config):
ctx.save_for_backward(x, theta, preacts_w, preacts_b)
ctx.reps = reps
ctx.config = config
return _qiskit_evaluate(x, theta, preacts_w, preacts_b, reps, config)
@staticmethod
def backward(ctx, grad_output):
x, theta, preacts_w, preacts_b = ctx.saved_tensors
reps = ctx.reps
config = ctx.config
shift = math.pi / 2
# Gradient w.r.t. theta via parameter-shift rule
grad_theta = torch.zeros_like(theta)
flat_theta = theta.reshape(-1)
for k in range(flat_theta.numel()):
theta_plus = flat_theta.clone()
theta_plus[k] += shift
theta_minus = flat_theta.clone()
theta_minus[k] -= shift
f_plus = _qiskit_evaluate(
x, theta_plus.reshape(theta.shape), preacts_w, preacts_b, reps, config
)
f_minus = _qiskit_evaluate(
x, theta_minus.reshape(theta.shape), preacts_w, preacts_b, reps, config
)
grad_k = (f_plus - f_minus) / (2 * math.sin(shift))
grad_theta.reshape(-1)[k] = (grad_output * grad_k).sum()
# Gradient w.r.t. preacts_weight
grad_pw = None
if preacts_w.requires_grad:
grad_pw = torch.zeros_like(preacts_w)
flat_pw = preacts_w.reshape(-1)
for k in range(flat_pw.numel()):
pw_plus = flat_pw.clone()
pw_plus[k] += shift
pw_minus = flat_pw.clone()
pw_minus[k] -= shift
f_plus = _qiskit_evaluate(
x, theta, pw_plus.reshape(preacts_w.shape), preacts_b, reps, config
)
f_minus = _qiskit_evaluate(
x, theta, pw_minus.reshape(preacts_w.shape), preacts_b, reps, config
)
grad_pw.reshape(-1)[k] = (
grad_output * (f_plus - f_minus) / (2 * math.sin(shift))
).sum()
# Gradient w.r.t. preacts_bias
grad_pb = None
if preacts_b.requires_grad:
grad_pb = torch.zeros_like(preacts_b)
flat_pb = preacts_b.reshape(-1)
for k in range(flat_pb.numel()):
pb_plus = flat_pb.clone()
pb_plus[k] += shift
pb_minus = flat_pb.clone()
pb_minus[k] -= shift
f_plus = _qiskit_evaluate(
x, theta, preacts_w, pb_plus.reshape(preacts_b.shape), reps, config
)
f_minus = _qiskit_evaluate(
x, theta, preacts_w, pb_minus.reshape(preacts_b.shape), reps, config
)
grad_pb.reshape(-1)[k] = (
grad_output * (f_plus - f_minus) / (2 * math.sin(shift))
).sum()
return None, grad_theta, grad_pw, grad_pb, None, None
def _probe_max_pubs(est, probe_pubs, max_pubs):
"""
Binary-search for the largest PUB batch the QPU accepts.
Submits `probe_pubs[:max_pubs]` synchronously. On memory error (6073),
halves and retries until a working size is found. Returns (result, max_pubs)
where result is the successful job result for the probe batch.
"""
while max_pubs >= 1:
batch = probe_pubs[:max_pubs]
try:
job = est.run(batch)
result = job.result()
return result, max_pubs
except Exception as e:
err_str = str(e)
if "6073" in err_str or "memory" in err_str.lower():
old_max = max_pubs
max_pubs = max(1, max_pubs // 2)
if max_pubs == old_max:
raise # can't go smaller than 1
print(
f" [qsolver] Job memory limit hit at {old_max} PUBs/job, "
f"trying {max_pubs}"
)
else:
raise
raise RuntimeError("Could not find a working PUB batch size")
def _submit_and_collect(est, all_pubs, all_chunk_sizes, max_pubs):
"""
Submit PUBs with the largest batch size the QPU can handle.
1. Probes with max_pubs (all PUBs if 0) synchronously to find the
largest accepted batch size via binary search on memory errors.
2. Submits all remaining batches asynchronously for max throughput.
3. Collects results in order.
Returns (expvals, actual_max_pubs) so callers can cache the working size.
"""
n_total = len(all_pubs)
if max_pubs <= 0:
max_pubs = n_total
expvals = [None] * n_total
# Step 1: Probe with first batch to discover working max_pubs
first_batch_size = min(max_pubs, n_total)
first_batch = all_pubs[:first_batch_size]
probe_result, max_pubs = _probe_max_pubs(est, first_batch, first_batch_size)
# Collect probe results (first max_pubs PUBs)
probed_count = min(max_pubs, n_total)
for i in range(probed_count):
evs = probe_result[i].data.evs
expvals[i] = [float(v) for v in evs]
# Step 2: Submit remaining batches asynchronously
remaining_start = probed_count
if remaining_start < n_total:
jobs = []
job_ranges = []
for batch_start in range(remaining_start, n_total, max_pubs):
batch_end = min(batch_start + max_pubs, n_total)
job_pubs = all_pubs[batch_start:batch_end]
jobs.append(est.run(job_pubs))
job_ranges.append((batch_start, batch_end))
n_jobs = len(jobs)
print(f" [qsolver] Submitting {n_jobs} async job(s), {max_pubs} PUBs/job")
# Collect all async results
for job, (batch_start, batch_end) in zip(jobs, job_ranges):
result = job.result()
for i, global_idx in enumerate(range(batch_start, batch_end)):
evs = result[i].data.evs
expvals[global_idx] = [float(v) for v in evs]
# Flatten
flat = []
for ev_list in expvals:
flat.extend(ev_list)
return flat, max_pubs
# Module-level cache for the discovered max PUBs per backend
_MAX_PUBS_CACHE: dict = {}
def _qiskit_run_parallel(
circuits,
n_qubits,
estimator,
backend,
optimization_level,
shots,
max_pubs_per_job=0,
resilience_level=None,
twirling=None,
):
"""
Pack single-qubit circuits into multi-qubit batches and submit async.
Groups `circuits` into chunks of `n_qubits`, packs each chunk into one
multi-qubit circuit. Jobs are submitted asynchronously with automatic
PUB batch sizing:
- If `max_pubs_per_job` > 0, uses that as the initial batch size.
- If `max_pubs_per_job` == 0 (default), starts with all PUBs in one job.
- On memory error (6073), automatically halves and retries.
- The discovered working batch size is cached per backend.
"""
total = len(circuits)
# Build all PUBs first
all_pubs = []
all_chunk_sizes = []
if estimator is not None:
for start in range(0, total, n_qubits):
batch_circuits = circuits[start : start + n_qubits]
chunk_size = len(batch_circuits)
all_chunk_sizes.append(chunk_size)
packed_qc = _build_qiskit_parallel_circuit(batch_circuits)
chunk_obs = _make_parallel_observables(chunk_size)
all_pubs.append((packed_qc, chunk_obs))
initial_max = max_pubs_per_job if max_pubs_per_job > 0 else len(all_pubs)
expvals, _ = _submit_and_collect(
estimator, all_pubs, all_chunk_sizes, initial_max
)
return expvals
elif backend is not None:
pm = generate_preset_pass_manager(
backend=backend, optimization_level=optimization_level
)
rt_estimator = Estimator(mode=backend)
if shots is not None:
rt_estimator.options.default_shots = shots
if resilience_level is not None:
rt_estimator.options.resilience_level = resilience_level
if twirling is not None:
if twirling.get("enable_gates"):
rt_estimator.options.twirling.enable_gates = True
if twirling.get("enable_measure"):
rt_estimator.options.twirling.enable_measure = True
if twirling.get("num_randomizations") is not None:
rt_estimator.options.twirling.num_randomizations = twirling[
"num_randomizations"
]
for start in range(0, total, n_qubits):
batch_circuits = circuits[start : start + n_qubits]
chunk_size = len(batch_circuits)
all_chunk_sizes.append(chunk_size)
packed_qc = _build_qiskit_parallel_circuit(batch_circuits)
isa_qc = pm.run(packed_qc)
chunk_obs = _make_parallel_observables(chunk_size)
isa_obs = [obs.apply_layout(isa_qc.layout) for obs in chunk_obs]
all_pubs.append((isa_qc, isa_obs))
# Use cached max or start with all PUBs
cache_key = getattr(backend, "name", str(backend))
initial_max = (
max_pubs_per_job
if max_pubs_per_job > 0
else _MAX_PUBS_CACHE.get(cache_key, len(all_pubs))
)
expvals, actual_max = _submit_and_collect(
rt_estimator, all_pubs, all_chunk_sizes, initial_max
)
_MAX_PUBS_CACHE[cache_key] = actual_max
return expvals
return []
def _qiskit_evaluate(
x: torch.Tensor,
theta: torch.Tensor,
preacts_weight: torch.Tensor,
preacts_bias: torch.Tensor,
reps: int,
config: dict,
) -> torch.Tensor:
"""
Evaluate all circuits on the Qiskit backend and return expectation values.
Returns shape: (batch_size, out_dim, in_dim)
"""
batch, in_dim = x.shape
ansatz = config["ansatz"]
preacts_trainable = config["preacts_trainable"]
out_dim = config["out_dim"]
backend = config.get("backend", None)
estimator = config.get("estimator", None)
shots = config["shots"]
optimization_level = config.get("optimization_level", 1)
parallel_qubits = config.get("parallel_qubits", None)
# Broadcast theta/preacts to (out_dim, in_dim, ...)
if len(theta.shape) != 4:
theta = theta.unsqueeze(0)
if theta.shape[1] != in_dim:
repeat_out = out_dim
repeat_in = in_dim // theta.shape[1] + 1
theta = theta.repeat(repeat_out, repeat_in, 1, 1)[:, :in_dim, :, :]
_needs_encoded_x = preacts_trainable or ansatz in ("rpz_encoding", "rpz")
encoded_x = None
if _needs_encoded_x:
if len(preacts_weight.shape) != 3:
preacts_weight = preacts_weight.unsqueeze(0)
preacts_bias = preacts_bias.unsqueeze(0)
if preacts_weight.shape[1] != in_dim:
repeat_out = out_dim
repeat_in = in_dim // preacts_weight.shape[1] + 1
preacts_weight = preacts_weight.repeat(repeat_out, repeat_in, 1)[
:, :in_dim, :
]
preacts_bias = preacts_bias.repeat(repeat_out, repeat_in, 1)[:, :in_dim, :]
encoded_x = torch.einsum("oir,bi->boir", preacts_weight, x).add(preacts_bias)
# Move to CPU for circuit parameter extraction
x_np = x.detach().cpu()
theta_np = theta.detach().cpu()
encoded_x_np = encoded_x.detach().cpu() if encoded_x is not None else None
# Build circuits and observables
circuits = []
observables = []
pauli_z = SparsePauliOp.from_list([("Z", 1.0)])
for b in range(batch):
for o in range(out_dim):
for i in range(in_dim):
if ansatz in ("pz_encoding", "pz"):
t = theta_np[o, i].reshape(-1).tolist()
enc_vals = None
if encoded_x_np is not None:
enc_vals = encoded_x_np[b, o, i].tolist()
qc = _build_qiskit_pz_circuit(float(x_np[b, i]), t, reps, enc_vals)
elif ansatz in ("rpz_encoding", "rpz"):
t = theta_np[o, i].reshape(-1).tolist()
enc_vals = encoded_x_np[b, o, i].tolist() # type: ignore
qc = _build_qiskit_rpz_circuit(enc_vals, t, reps)
elif ansatz == "real":
t = theta_np[o, i].reshape(-1).tolist()
enc_vals = None
if encoded_x_np is not None:
enc_vals = encoded_x_np[b, o, i].tolist()
qc = _build_qiskit_real_circuit(
float(x_np[b, i]), t, reps, enc_vals
)
else:
raise NotImplementedError(
f"Ansatz '{ansatz}' not supported by qiskit_solver"
)
circuits.append(qc)
observables.append(pauli_z)
# Execute via the appropriate Estimator
max_pubs = config.get("max_pubs_per_job", 0)
mitigation = config.get("mitigation", {})
def _run_qiskit(scale_factor=1):
run_circuits = (
[_fold_qiskit_circuit(qc, scale_factor) for qc in circuits]
if scale_factor > 1
else circuits
)
if parallel_qubits and parallel_qubits > 1:
return _qiskit_run_parallel(
run_circuits,
parallel_qubits,
estimator,
backend,
optimization_level,
shots,
max_pubs_per_job=max_pubs,
resilience_level=config.get("resilience_level"),
twirling=config.get("twirling"),
)
elif estimator is not None:
pubs = list(zip(run_circuits, observables))
job = estimator.run(pubs)
result = job.result()
return [float(r.data.evs) for r in result]
elif backend is not None:
pm = generate_preset_pass_manager(
backend=backend, optimization_level=optimization_level
)
isa_circuits = pm.run(run_circuits)
isa_observables = [
obs.apply_layout(qc.layout)
for obs, qc in zip(observables, isa_circuits)
]
rt_estimator = Estimator(mode=backend)
if shots is not None:
rt_estimator.options.default_shots = shots
rl = config.get("resilience_level")
tw = config.get("twirling")
if rl is not None:
rt_estimator.options.resilience_level = rl
if tw is not None:
if tw.get("enable_gates"):
rt_estimator.options.twirling.enable_gates = True
if tw.get("enable_measure"):
rt_estimator.options.twirling.enable_measure = True
pubs = list(zip(isa_circuits, isa_observables))
job = rt_estimator.run(pubs)
result = job.result()
return [float(r.data.evs) for r in result]
else:
raise ValueError("No estimator or backend provided.")
if mitigation:
expvals = _apply_mitigation(_run_qiskit, mitigation)
else:
expvals = _run_qiskit(1)
output = torch.tensor(expvals, dtype=x.dtype, device=x.device)
return output.reshape(batch, out_dim, in_dim)
[docs]
def qiskit_solver(
x: torch.Tensor,
theta: torch.Tensor,
preacts_weight: torch.Tensor,
preacts_bias: torch.Tensor,
reps: int,
**kwargs,
) -> torch.Tensor:
"""
Execute QKAN circuits on IBM Quantum backends via Qiskit Runtime.
Drop-in replacement for torch_exact_solver. Circuits are built to match
the exact gate sequences of each ansatz, then executed on the specified
backend using Qiskit's Estimator primitive.
Supports training via the parameter-shift rule when gradients are needed.
Args
----
x : torch.Tensor
shape: (batch_size, in_dim)
theta : torch.Tensor
shape: (\\*group, reps+1, n_params) or (\\*group, reps, 1) for real
preacts_weight : torch.Tensor
shape: (\\*group, reps)
preacts_bias : torch.Tensor
shape: (\\*group, reps)
reps : int
ansatz : str
"pz_encoding", "pz", "rpz_encoding", "rpz", or "real"
preacts_trainable : bool
out_dim : int
backend : qiskit Backend
Qiskit backend instance (e.g., AerSimulator(), or from QiskitRuntimeService)
shots : int, optional
Number of shots per circuit. None for exact expectation (statevector).
optimization_level : int
Transpiler optimization level (0-3), default: 1
Returns
-------
torch.Tensor
shape: (batch_size, out_dim, in_dim)
"""
if not _QISKIT_AVAILABLE:
raise ImportError(
"Qiskit is required for qiskit_solver. "
"Install with: pip install qiskit qiskit-ibm-runtime"
)
ansatz = kwargs.get("ansatz", "pz_encoding")
preacts_trainable = kwargs.get("preacts_trainable", False)
out_dim = kwargs.get("out_dim", x.shape[1])
shots = kwargs.get("shots", None)
optimization_level = kwargs.get("optimization_level", 1)
parallel_qubits = kwargs.get("parallel_qubits", None)
backend = kwargs.get("backend", None)
estimator = kwargs.get("estimator", None)
# Resolve execution mode: estimator > backend > StatevectorEstimator > AerSimulator
if estimator is None and backend is None:
if _SV_ESTIMATOR_AVAILABLE:
estimator = StatevectorEstimator()
elif _AER_AVAILABLE:
backend = AerSimulator(method="statevector")
else:
raise ValueError(
"No backend or estimator specified. Install qiskit >= 1.0 "
"(for StatevectorEstimator), qiskit-aer, or qiskit-ibm-runtime."
)
# Auto-detect QPU size from backend if parallel_qubits="auto"
if parallel_qubits == "auto" and backend is not None:
parallel_qubits = backend.num_qubits
max_pubs_per_job = kwargs.get("max_pubs_per_job", 0)
config = {
"ansatz": ansatz,
"preacts_trainable": preacts_trainable,
"out_dim": out_dim,
"backend": backend,
"estimator": estimator,
"shots": shots,
"optimization_level": optimization_level,
"parallel_qubits": parallel_qubits,
"max_pubs_per_job": max_pubs_per_job,
"resilience_level": kwargs.get("resilience_level", None),
"twirling": kwargs.get("twirling", None),
"mitigation": kwargs.get("mitigation", {}),
}
needs_grad = theta.requires_grad or x.requires_grad
if preacts_trainable:
needs_grad = (
needs_grad or preacts_weight.requires_grad or preacts_bias.requires_grad
)
if needs_grad:
return _QiskitParamShift.apply(
x, theta, preacts_weight, preacts_bias, reps, config
)
else:
return _qiskit_evaluate(x, theta, preacts_weight, preacts_bias, reps, config)