Source code for qkan.solver.cudaq_solver

# mypy: ignore-errors
# Copyright (c) 2026, Jiun-Cheng Jiang. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
QKAN solver for CUDA-Q GPU-accelerated simulation or QPU execution.
"""

import math

import torch

try:
    import cudaq  # type: ignore

    _CUDAQ_AVAILABLE = True
except ImportError:
    _CUDAQ_AVAILABLE = False


from ._mitigation import _apply_mitigation

# ---------------------------------------------------------------------------
# CUDA-Q gate-folded kernel builders (for ZNE)
# ---------------------------------------------------------------------------


def _build_cudaq_pz_folded_kernel(reps: int, scale_factor: int):
    """Build a gate-folded pz_encoding kernel for ZNE: U . (U_dag . U)^n."""
    n_folds = (scale_factor - 1) // 2

    @cudaq.kernel
    def kernel(x_val: float, thetas: list[float]):
        q = cudaq.qubit()
        h(q)
        for l in range(reps):
            rz(thetas[2 * l], q)
            ry(thetas[2 * l + 1], q)
            rz(x_val, q)
        rz(thetas[2 * reps], q)
        ry(thetas[2 * reps + 1], q)
        for _f in range(n_folds):
            ry(-thetas[2 * reps + 1], q)
            rz(-thetas[2 * reps], q)
            for l in range(reps - 1, -1, -1):
                rz(-x_val, q)
                ry(-thetas[2 * l + 1], q)
                rz(-thetas[2 * l], q)
            h(q)
            h(q)
            for l in range(reps):
                rz(thetas[2 * l], q)
                ry(thetas[2 * l + 1], q)
                rz(x_val, q)
            rz(thetas[2 * reps], q)
            ry(thetas[2 * reps + 1], q)

    return kernel
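
# A note on the folding arithmetic (illustrative; not executed anywhere): the
# folded circuit U . (U_dag . U)^n_folds is logically identical to U but runs
# scale_factor = 2 * n_folds + 1 times as many gates, amplifying hardware
# noise by roughly that factor so ZNE can extrapolate back to zero noise.
#
#     scale_factor = 3
#     n_folds = (scale_factor - 1) // 2   # -> 1 inserted U_dag . U pair
#     # gate count: 1x (U) + 2x (U_dag . U) = 3x the unfolded circuit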


def _build_cudaq_pz_preact_folded_kernel(reps: int, scale_factor: int):
    """Build a gate-folded pz_encoding preact kernel for ZNE."""
    n_folds = (scale_factor - 1) // 2

    @cudaq.kernel
    def kernel(encoded_x: list[float], thetas: list[float]):
        q = cudaq.qubit()
        h(q)
        for l in range(reps):
            rz(thetas[2 * l], q)
            ry(thetas[2 * l + 1], q)
            rz(encoded_x[l], q)
        rz(thetas[2 * reps], q)
        ry(thetas[2 * reps + 1], q)
        for _f in range(n_folds):
            ry(-thetas[2 * reps + 1], q)
            rz(-thetas[2 * reps], q)
            for l in range(reps - 1, -1, -1):
                rz(-encoded_x[l], q)
                ry(-thetas[2 * l + 1], q)
                rz(-thetas[2 * l], q)
            h(q)
            h(q)
            for l in range(reps):
                rz(thetas[2 * l], q)
                ry(thetas[2 * l + 1], q)
                rz(encoded_x[l], q)
            rz(thetas[2 * reps], q)
            ry(thetas[2 * reps + 1], q)

    return kernel


def _build_cudaq_rpz_folded_kernel(reps: int, scale_factor: int):
    """Build a gate-folded rpz_encoding kernel for ZNE."""
    n_folds = (scale_factor - 1) // 2

    @cudaq.kernel
    def kernel(encoded_x: list[float], thetas: list[float]):
        q = cudaq.qubit()
        h(q)
        for l in range(reps):
            ry(thetas[l], q)
            rz(encoded_x[l], q)
        ry(thetas[reps], q)
        for _f in range(n_folds):
            ry(-thetas[reps], q)
            for l in range(reps - 1, -1, -1):
                rz(-encoded_x[l], q)
                ry(-thetas[l], q)
            h(q)
            h(q)
            for l in range(reps):
                ry(thetas[l], q)
                rz(encoded_x[l], q)
            ry(thetas[reps], q)

    return kernel


def _build_cudaq_real_folded_kernel(reps: int, scale_factor: int):
    """Build a gate-folded real ansatz kernel for ZNE."""
    n_folds = (scale_factor - 1) // 2

    @cudaq.kernel
    def kernel(x_val: float, thetas: list[float]):
        q = cudaq.qubit()
        h(q)
        for l in range(reps):
            x(q)
            ry(thetas[l], q)
            z(q)
            ry(x_val, q)
        for _f in range(n_folds):
            for l in range(reps - 1, -1, -1):
                ry(-x_val, q)
                z(q)
                ry(-thetas[l], q)
                x(q)
            h(q)
            h(q)
            for l in range(reps):
                x(q)
                ry(thetas[l], q)
                z(q)
                ry(x_val, q)

    return kernel


def _build_cudaq_real_preact_folded_kernel(reps: int, scale_factor: int):
    """Build a gate-folded real preact kernel for ZNE."""
    n_folds = (scale_factor - 1) // 2

    @cudaq.kernel
    def kernel(encoded_x: list[float], thetas: list[float]):
        q = cudaq.qubit()
        h(q)
        for l in range(reps):
            x(q)
            ry(thetas[l], q)
            z(q)
            ry(encoded_x[l], q)
        for _f in range(n_folds):
            for l in range(reps - 1, -1, -1):
                ry(-encoded_x[l], q)
                z(q)
                ry(-thetas[l], q)
                x(q)
            h(q)
            h(q)
            for l in range(reps):
                x(q)
                ry(thetas[l], q)
                z(q)
                ry(encoded_x[l], q)

    return kernel


def _build_cudaq_parallel_pz_folded_kernel(n_qubits: int, reps: int, scale_factor: int):
    """Build a gate-folded parallel pz kernel for ZNE."""
    n_folds = (scale_factor - 1) // 2

    @cudaq.kernel
    def kernel(x_vals: list[float], all_thetas: list[float]):
        qubits = cudaq.qvector(n_qubits)
        params_per = 2 * (reps + 1)
        for q_idx in range(n_qubits):
            offset = q_idx * params_per
            h(qubits[q_idx])
            for l in range(reps):
                rz(all_thetas[offset + 2 * l], qubits[q_idx])
                ry(all_thetas[offset + 2 * l + 1], qubits[q_idx])
                rz(x_vals[q_idx], qubits[q_idx])
            rz(all_thetas[offset + 2 * reps], qubits[q_idx])
            ry(all_thetas[offset + 2 * reps + 1], qubits[q_idx])
            for _f in range(n_folds):
                ry(-all_thetas[offset + 2 * reps + 1], qubits[q_idx])
                rz(-all_thetas[offset + 2 * reps], qubits[q_idx])
                for l in range(reps - 1, -1, -1):
                    rz(-x_vals[q_idx], qubits[q_idx])
                    ry(-all_thetas[offset + 2 * l + 1], qubits[q_idx])
                    rz(-all_thetas[offset + 2 * l], qubits[q_idx])
                h(qubits[q_idx])
                h(qubits[q_idx])
                for l in range(reps):
                    rz(all_thetas[offset + 2 * l], qubits[q_idx])
                    ry(all_thetas[offset + 2 * l + 1], qubits[q_idx])
                    rz(x_vals[q_idx], qubits[q_idx])
                rz(all_thetas[offset + 2 * reps], qubits[q_idx])
                ry(all_thetas[offset + 2 * reps + 1], qubits[q_idx])

    return kernel


def _build_cudaq_parallel_real_folded_kernel(
    n_qubits: int, reps: int, scale_factor: int
):
    """Build a gate-folded parallel real kernel for ZNE."""
    n_folds = (scale_factor - 1) // 2

    @cudaq.kernel
    def kernel(x_vals: list[float], all_thetas: list[float]):
        qubits = cudaq.qvector(n_qubits)
        for q_idx in range(n_qubits):
            offset = q_idx * reps
            h(qubits[q_idx])
            for l in range(reps):
                x(qubits[q_idx])
                ry(all_thetas[offset + l], qubits[q_idx])
                z(qubits[q_idx])
                ry(x_vals[q_idx], qubits[q_idx])
            for _f in range(n_folds):
                for l in range(reps - 1, -1, -1):
                    ry(-x_vals[q_idx], qubits[q_idx])
                    z(qubits[q_idx])
                    ry(-all_thetas[offset + l], qubits[q_idx])
                    x(qubits[q_idx])
                h(qubits[q_idx])
                h(qubits[q_idx])
                for l in range(reps):
                    x(qubits[q_idx])
                    ry(all_thetas[offset + l], qubits[q_idx])
                    z(qubits[q_idx])
                    ry(x_vals[q_idx], qubits[q_idx])

    return kernel


def _build_cudaq_parallel_rpz_folded_kernel(
    n_qubits: int, reps: int, scale_factor: int
):
    """Build a gate-folded parallel rpz kernel for ZNE."""
    n_folds = (scale_factor - 1) // 2

    @cudaq.kernel
    def kernel(encoded_xs: list[float], all_thetas: list[float]):
        qubits = cudaq.qvector(n_qubits)
        t_per = reps + 1
        for q_idx in range(n_qubits):
            t_off = q_idx * t_per
            x_off = q_idx * reps
            h(qubits[q_idx])
            for l in range(reps):
                ry(all_thetas[t_off + l], qubits[q_idx])
                rz(encoded_xs[x_off + l], qubits[q_idx])
            ry(all_thetas[t_off + reps], qubits[q_idx])
            for _f in range(n_folds):
                ry(-all_thetas[t_off + reps], qubits[q_idx])
                for l in range(reps - 1, -1, -1):
                    rz(-encoded_xs[x_off + l], qubits[q_idx])
                    ry(-all_thetas[t_off + l], qubits[q_idx])
                h(qubits[q_idx])
                h(qubits[q_idx])
                for l in range(reps):
                    ry(all_thetas[t_off + l], qubits[q_idx])
                    rz(encoded_xs[x_off + l], qubits[q_idx])
                ry(all_thetas[t_off + reps], qubits[q_idx])

    return kernel


# ---------------------------------------------------------------------------
# CUDA-Q solver
# ---------------------------------------------------------------------------


def _build_cudaq_pz_kernel(reps: int):
    """Build a CUDA-Q kernel for pz_encoding ansatz."""

    @cudaq.kernel
    def kernel(x_val: float, thetas: list[float]):
        q = cudaq.qubit()
        h(q)
        for l in range(reps):
            rz(thetas[2 * l], q)
            ry(thetas[2 * l + 1], q)
            rz(x_val, q)
        rz(thetas[2 * reps], q)
        ry(thetas[2 * reps + 1], q)

    return kernel
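
# Illustrative usage of a built kernel (a sketch; assumes CUDA-Q is installed
# and a simulator target is active; the angle values are arbitrary):
#
#     reps = 2
#     kernel = _build_cudaq_pz_kernel(reps)
#     thetas = [0.1] * (2 * (reps + 1))      # one RZ/RY pair per layer
#     ev = cudaq.observe(kernel, cudaq.spin.z(0), 0.5, thetas).expectation()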


def _build_cudaq_pz_preact_kernel(reps: int):
    """Build a CUDA-Q kernel for pz_encoding with trainable preactivation."""

    @cudaq.kernel
    def kernel(encoded_x: list[float], thetas: list[float]):
        q = cudaq.qubit()
        h(q)
        for l in range(reps):
            rz(thetas[2 * l], q)
            ry(thetas[2 * l + 1], q)
            rz(encoded_x[l], q)
        rz(thetas[2 * reps], q)
        ry(thetas[2 * reps + 1], q)

    return kernel


def _build_cudaq_rpz_kernel(reps: int):
    """Build a CUDA-Q kernel for rpz_encoding ansatz."""

    @cudaq.kernel
    def kernel(encoded_x: list[float], thetas: list[float]):
        q = cudaq.qubit()
        h(q)
        for l in range(reps):
            ry(thetas[l], q)
            rz(encoded_x[l], q)
        ry(thetas[reps], q)

    return kernel


def _build_cudaq_real_kernel(reps: int):
    """Build a CUDA-Q kernel for real ansatz."""

    @cudaq.kernel
    def kernel(x_val: float, thetas: list[float]):
        q = cudaq.qubit()
        h(q)
        for l in range(reps):
            x(q)
            ry(thetas[l], q)
            z(q)
            ry(x_val, q)

    return kernel


def _build_cudaq_real_preact_kernel(reps: int):
    """Build a CUDA-Q kernel for real ansatz with trainable preactivation."""

    @cudaq.kernel
    def kernel(encoded_x: list[float], thetas: list[float]):
        q = cudaq.qubit()
        h(q)
        for l in range(reps):
            x(q)
            ry(thetas[l], q)
            z(q)
            ry(encoded_x[l], q)

    return kernel
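
# Per-circuit argument lengths expected by the single-qubit builders above
# (read off the index arithmetic in each kernel):
#
#     pz          : len(thetas) == 2 * (reps + 1)
#     pz preact   : len(thetas) == 2 * (reps + 1), len(encoded_x) == reps
#     rpz         : len(thetas) == reps + 1,       len(encoded_x) == reps
#     real        : len(thetas) == reps
#     real preact : len(thetas) == reps,           len(encoded_x) == reps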


# Cache: {(ansatz, reps, preacts_trainable, scale_factor): kernel}
_CUDAQ_KERNEL_CACHE: dict = {}


def _get_cudaq_kernel(
    ansatz: str, reps: int, preacts_trainable: bool, scale_factor: int = 1
):
    """Get or build a cached CUDA-Q kernel (optionally gate-folded for ZNE)."""
    key = (ansatz, reps, preacts_trainable, scale_factor)
    if key not in _CUDAQ_KERNEL_CACHE:
        sf = scale_factor
        if ansatz in ("pz_encoding", "pz"):
            if preacts_trainable:
                _CUDAQ_KERNEL_CACHE[key] = (
                    _build_cudaq_pz_preact_folded_kernel(reps, sf)
                    if sf > 1
                    else _build_cudaq_pz_preact_kernel(reps)
                )
            else:
                _CUDAQ_KERNEL_CACHE[key] = (
                    _build_cudaq_pz_folded_kernel(reps, sf)
                    if sf > 1
                    else _build_cudaq_pz_kernel(reps)
                )
        elif ansatz in ("rpz_encoding", "rpz"):
            _CUDAQ_KERNEL_CACHE[key] = (
                _build_cudaq_rpz_folded_kernel(reps, sf)
                if sf > 1
                else _build_cudaq_rpz_kernel(reps)
            )
        elif ansatz == "real":
            if preacts_trainable:
                _CUDAQ_KERNEL_CACHE[key] = (
                    _build_cudaq_real_preact_folded_kernel(reps, sf)
                    if sf > 1
                    else _build_cudaq_real_preact_kernel(reps)
                )
            else:
                _CUDAQ_KERNEL_CACHE[key] = (
                    _build_cudaq_real_folded_kernel(reps, sf)
                    if sf > 1
                    else _build_cudaq_real_kernel(reps)
                )
        else:
            raise NotImplementedError(
                f"Ansatz '{ansatz}' not supported by cudaq_solver"
            )
    return _CUDAQ_KERNEL_CACHE[key]
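
# Illustrative cache behaviour (a sketch): one JIT-compiled kernel is built
# per (ansatz, reps, preacts_trainable, scale_factor) key and reused after.
#
#     k1 = _get_cudaq_kernel("pz_encoding", 3, False)
#     k2 = _get_cudaq_kernel("pz_encoding", 3, False)
#     assert k1 is k2                                     # cache hit
#     k3 = _get_cudaq_kernel("pz_encoding", 3, False, scale_factor=3)  # folded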


def _build_cudaq_parallel_pz_kernel(n_qubits: int, reps: int):
    """Build a CUDA-Q kernel that runs N independent pz circuits in parallel."""

    @cudaq.kernel
    def kernel(x_vals: list[float], all_thetas: list[float]):
        qubits = cudaq.qvector(n_qubits)
        params_per = 2 * (reps + 1)
        for q_idx in range(n_qubits):
            h(qubits[q_idx])
            offset = q_idx * params_per
            for l in range(reps):
                rz(all_thetas[offset + 2 * l], qubits[q_idx])
                ry(all_thetas[offset + 2 * l + 1], qubits[q_idx])
                rz(x_vals[q_idx], qubits[q_idx])
            rz(all_thetas[offset + 2 * reps], qubits[q_idx])
            ry(all_thetas[offset + 2 * reps + 1], qubits[q_idx])

    return kernel


def _build_cudaq_parallel_real_kernel(n_qubits: int, reps: int):
    """Build a CUDA-Q kernel that runs N independent real circuits in parallel."""

    @cudaq.kernel
    def kernel(x_vals: list[float], all_thetas: list[float]):
        qubits = cudaq.qvector(n_qubits)
        for q_idx in range(n_qubits):
            h(qubits[q_idx])
            offset = q_idx * reps
            for l in range(reps):
                x(qubits[q_idx])
                ry(all_thetas[offset + l], qubits[q_idx])
                z(qubits[q_idx])
                ry(x_vals[q_idx], qubits[q_idx])

    return kernel


def _build_cudaq_parallel_rpz_kernel(n_qubits: int, reps: int):
    """Build a CUDA-Q kernel that runs N independent rpz circuits in parallel."""

    @cudaq.kernel
    def kernel(encoded_xs: list[float], all_thetas: list[float]):
        qubits = cudaq.qvector(n_qubits)
        t_per = reps + 1
        for q_idx in range(n_qubits):
            h(qubits[q_idx])
            t_off = q_idx * t_per
            x_off = q_idx * reps
            for l in range(reps):
                ry(all_thetas[t_off + l], qubits[q_idx])
                rz(encoded_xs[x_off + l], qubits[q_idx])
            ry(all_thetas[t_off + reps], qubits[q_idx])

    return kernel
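
# Flattened argument layout shared with _cudaq_run_parallel below (derived
# from the offset arithmetic in the three parallel builders above):
#
#     parallel pz  : all_thetas = t_0 ++ t_1 ++ ..., len(t_k) == 2*(reps+1)
#     parallel real: all_thetas = t_0 ++ t_1 ++ ..., len(t_k) == reps
#     parallel rpz : all_thetas = t_0 ++ t_1 ++ ..., len(t_k) == reps + 1
#                    encoded_xs = e_0 ++ e_1 ++ ..., len(e_k) == reps
#
# so qubit q_idx reads its own angles at offset q_idx * len(t_k).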


def _cudaq_run_parallel(
    all_args,
    ansatz,
    reps,
    preacts_trainable,
    n_qubits,
    shots_count,
    scale_factor=1,
):
    """
    Pack independent single-qubit circuits onto an N-qubit QPU.

    Runs ceil(total / n_qubits) multi-qubit jobs instead of `total` single-qubit jobs.
    """
    total = len(all_args)
    expvals = []

    for start in range(0, total, n_qubits):
        chunk = all_args[start : start + n_qubits]
        chunk_size = len(chunk)

        # Flatten args into parallel kernel format
        if ansatz in ("pz_encoding", "pz") and not preacts_trainable:
            x_vals = [a[0] for a in chunk]
            all_thetas = []
            for a in chunk:
                all_thetas.extend(a[1])
            # Pad if chunk < n_qubits
            actual_n = chunk_size
            if actual_n < n_qubits:
                x_vals.extend([0.0] * (n_qubits - actual_n))
                pad_thetas = [0.0] * (2 * (reps + 1))
                for _ in range(n_qubits - actual_n):
                    all_thetas.extend(pad_thetas)
                actual_n = n_qubits

            par_kernel = (
                _build_cudaq_parallel_pz_folded_kernel(actual_n, reps, scale_factor)
                if scale_factor > 1
                else _build_cudaq_parallel_pz_kernel(actual_n, reps)
            )
            args = (x_vals, all_thetas)

        elif ansatz == "real" and not preacts_trainable:
            x_vals = [a[0] for a in chunk]
            all_thetas = []
            for a in chunk:
                all_thetas.extend(a[1])
            actual_n = chunk_size
            if actual_n < n_qubits:
                x_vals.extend([0.0] * (n_qubits - actual_n))
                for _ in range(n_qubits - actual_n):
                    all_thetas.extend([0.0] * reps)
                actual_n = n_qubits

            par_kernel = (
                _build_cudaq_parallel_real_folded_kernel(actual_n, reps, scale_factor)
                if scale_factor > 1
                else _build_cudaq_parallel_real_kernel(actual_n, reps)
            )
            args = (x_vals, all_thetas)

        elif ansatz in ("rpz_encoding", "rpz") or preacts_trainable:
            encoded_xs = []
            all_thetas = []
            for a in chunk:
                enc, t = a
                if isinstance(enc, list):
                    encoded_xs.extend(enc)
                else:
                    encoded_xs.extend([enc] * reps)
                all_thetas.extend(t)
            actual_n = chunk_size
            if actual_n < n_qubits:
                for _ in range(n_qubits - actual_n):
                    encoded_xs.extend([0.0] * reps)
                    all_thetas.extend([0.0] * (reps + 1))
                actual_n = n_qubits

            par_kernel = (
                _build_cudaq_parallel_rpz_folded_kernel(actual_n, reps, scale_factor)
                if scale_factor > 1
                else _build_cudaq_parallel_rpz_kernel(actual_n, reps)
            )
            args = (encoded_xs, all_thetas)
        else:
            # pz/real circuits with trainable preactivations use a different
            # per-qubit theta layout than the rpz parallel kernel, so they
            # cannot be packed here; run with parallel_qubits=None instead.
            raise NotImplementedError(
                f"Parallel execution not supported for ansatz '{ansatz}' "
                f"with preacts_trainable={preacts_trainable}"
            )

        # Single observe call with Z0 + Z1 + ... + Z_{N-1} Hamiltonian,
        # then extract per-qubit <Z_k> from the result
        spin_sum = cudaq.spin.z(0)
        for q_idx in range(1, actual_n):
            spin_sum += cudaq.spin.z(q_idx)

        if shots_count is not None:
            result = cudaq.observe(par_kernel, spin_sum, *args, shots_count=shots_count)
        else:
            result = cudaq.observe(par_kernel, spin_sum, *args)

        for q_idx in range(chunk_size):
            expvals.append(result.expectation(cudaq.spin.z(q_idx)))

    return expvals
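
# Worked example of the job-count saving (illustrative numbers): with 1000
# circuits packed onto an 8-qubit device,
#
#     total, n_qubits = 1000, 8
#     n_jobs = -(-total // n_qubits)    # ceil(1000 / 8) = 125 observe calls
#
# versus 1000 single-qubit observe calls without packing. The final partial
# chunk is zero-padded up to n_qubits so the kernel width stays constant.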


def _cudaq_evaluate(
    x: torch.Tensor,
    theta: torch.Tensor,
    preacts_weight: torch.Tensor,
    preacts_bias: torch.Tensor,
    reps: int,
    config: dict,
) -> torch.Tensor:
    """
    Evaluate all circuits on CUDA-Q and return expectation values.

    Returns shape: (batch_size, out_dim, in_dim)
    """
    batch, in_dim = x.shape
    ansatz = config["ansatz"]
    preacts_trainable = config["preacts_trainable"]
    out_dim = config["out_dim"]
    shots_count = config["shots"]
    parallel_qubits = config.get("parallel_qubits", None)

    # Broadcasting (same as other solvers)
    if len(theta.shape) != 4:
        theta = theta.unsqueeze(0)
    if theta.shape[1] != in_dim:
        repeat_out = out_dim
        repeat_in = in_dim // theta.shape[1] + 1
        theta = theta.repeat(repeat_out, repeat_in, 1, 1)[:, :in_dim, :, :]

    _needs_encoded_x = preacts_trainable or ansatz in ("rpz_encoding", "rpz")
    encoded_x = None
    if _needs_encoded_x:
        if len(preacts_weight.shape) != 3:
            preacts_weight = preacts_weight.unsqueeze(0)
            preacts_bias = preacts_bias.unsqueeze(0)
        if preacts_weight.shape[1] != in_dim:
            repeat_out = out_dim
            repeat_in = in_dim // preacts_weight.shape[1] + 1
            preacts_weight = preacts_weight.repeat(repeat_out, repeat_in, 1)[
                :, :in_dim, :
            ]
            preacts_bias = preacts_bias.repeat(repeat_out, repeat_in, 1)[:, :in_dim, :]
        encoded_x = torch.einsum("oir,bi->boir", preacts_weight, x).add(preacts_bias)

    x_np = x.detach().cpu()
    theta_np = theta.detach().cpu()
    encoded_x_np = encoded_x.detach().cpu() if encoded_x is not None else None

    # Collect all circuit args first
    all_args = []
    for b in range(batch):
        for o in range(out_dim):
            for i in range(in_dim):
                t = theta_np[o, i].reshape(-1).tolist()

                if ansatz in ("pz_encoding", "pz"):
                    if preacts_trainable:
                        all_args.append((encoded_x_np[b, o, i].tolist(), t))
                    else:
                        all_args.append((float(x_np[b, i]), t))
                elif ansatz in ("rpz_encoding", "rpz"):
                    all_args.append((encoded_x_np[b, o, i].tolist(), t))
                elif ansatz == "real":
                    if preacts_trainable:
                        all_args.append((encoded_x_np[b, o, i].tolist(), t))
                    else:
                        all_args.append((float(x_np[b, i]), t))
                else:
                    raise NotImplementedError(
                        f"Ansatz '{ansatz}' not supported by cudaq_solver"
                    )

    mitigation = config.get("mitigation", {})

    def _run_cudaq(scale_factor=1):
        if parallel_qubits and parallel_qubits > 1:
            return _cudaq_run_parallel(
                all_args,
                ansatz,
                reps,
                preacts_trainable,
                parallel_qubits,
                shots_count,
                scale_factor=scale_factor,
            )
        else:
            kernel = _get_cudaq_kernel(ansatz, reps, preacts_trainable, scale_factor)
            spin_z = cudaq.spin.z(0)
            ev = []
            for args in all_args:
                if shots_count is not None:
                    result = cudaq.observe(
                        kernel, spin_z, *args, shots_count=shots_count
                    )
                else:
                    result = cudaq.observe(kernel, spin_z, *args)
                ev.append(result.expectation())
            return ev

    if mitigation:
        expvals = _apply_mitigation(_run_cudaq, mitigation)
    else:
        expvals = _run_cudaq(1)

    output = torch.tensor(expvals, dtype=x.dtype, device=x.device)
    return output.reshape(batch, out_dim, in_dim)
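
# Shape sketch (illustrative): with batch=4, in_dim=3, out_dim=2 this
# evaluates 4 * 2 * 3 = 24 single-qubit circuits and reshapes the resulting
# <Z> expectation values into a (4, 2, 3) tensor, one entry per
# (sample, output node, input edge).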


class _CudaqParamShift(torch.autograd.Function):
    """Autograd function using parameter-shift rule for CUDA-Q circuits."""

    @staticmethod
    def forward(ctx, x, theta, preacts_w, preacts_b, reps, config):
        ctx.save_for_backward(x, theta, preacts_w, preacts_b)
        ctx.reps = reps
        ctx.config = config
        return _cudaq_evaluate(x, theta, preacts_w, preacts_b, reps, config)

    @staticmethod
    def backward(ctx, grad_output):
        x, theta, preacts_w, preacts_b = ctx.saved_tensors
        reps = ctx.reps
        config = ctx.config
        shift = math.pi / 2

        grad_theta = torch.zeros_like(theta)
        flat_theta = theta.reshape(-1)
        for k in range(flat_theta.numel()):
            theta_plus = flat_theta.clone()
            theta_plus[k] += shift
            theta_minus = flat_theta.clone()
            theta_minus[k] -= shift

            f_plus = _cudaq_evaluate(
                x, theta_plus.reshape(theta.shape), preacts_w, preacts_b, reps, config
            )
            f_minus = _cudaq_evaluate(
                x, theta_minus.reshape(theta.shape), preacts_w, preacts_b, reps, config
            )
            grad_theta.reshape(-1)[k] = (
                grad_output * (f_plus - f_minus) / (2 * math.sin(shift))
            ).sum()

        grad_pw = None
        if preacts_w.requires_grad:
            grad_pw = torch.zeros_like(preacts_w)
            flat_pw = preacts_w.reshape(-1)
            for k in range(flat_pw.numel()):
                pw_plus = flat_pw.clone()
                pw_plus[k] += shift
                pw_minus = flat_pw.clone()
                pw_minus[k] -= shift
                f_plus = _cudaq_evaluate(
                    x, theta, pw_plus.reshape(preacts_w.shape), preacts_b, reps, config
                )
                f_minus = _cudaq_evaluate(
                    x, theta, pw_minus.reshape(preacts_w.shape), preacts_b, reps, config
                )
                grad_pw.reshape(-1)[k] = (
                    grad_output * (f_plus - f_minus) / (2 * math.sin(shift))
                ).sum()

        grad_pb = None
        if preacts_b.requires_grad:
            grad_pb = torch.zeros_like(preacts_b)
            flat_pb = preacts_b.reshape(-1)
            for k in range(flat_pb.numel()):
                pb_plus = flat_pb.clone()
                pb_plus[k] += shift
                pb_minus = flat_pb.clone()
                pb_minus[k] -= shift
                f_plus = _cudaq_evaluate(
                    x, theta, preacts_w, pb_plus.reshape(preacts_b.shape), reps, config
                )
                f_minus = _cudaq_evaluate(
                    x, theta, preacts_w, pb_minus.reshape(preacts_b.shape), reps, config
                )
                grad_pb.reshape(-1)[k] = (
                    grad_output * (f_plus - f_minus) / (2 * math.sin(shift))
                ).sum()

        return None, grad_theta, grad_pw, grad_pb, None, None
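
# The shift rule applied in backward() above (the standard result for gates
# generated by a Pauli, e.g. RY/RZ, with shift = pi/2):
#
#     d<Z>/d(theta_k) = (f(theta_k + pi/2) - f(theta_k - pi/2))
#                       / (2 * sin(pi/2))
#
# Each scalar parameter therefore costs two extra _cudaq_evaluate passes per
# backward call, so gradient evaluation scales linearly with numel(theta).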


def cudaq_solver(
    x: torch.Tensor,
    theta: torch.Tensor,
    preacts_weight: torch.Tensor,
    preacts_bias: torch.Tensor,
    reps: int,
    **kwargs,
) -> torch.Tensor:
    """
    Execute QKAN circuits via NVIDIA CUDA-Q.

    Drop-in replacement for torch_exact_solver using CUDA-Q's GPU-accelerated
    quantum simulation or QPU backends. Circuits are built as CUDA-Q kernels
    and expectation values are computed via cudaq.observe(). Supports training
    via the parameter-shift rule when gradients are needed.

    Args
    ----
    x : torch.Tensor
        shape: (batch_size, in_dim)
    theta : torch.Tensor
        shape: (\\*group, reps+1, n_params) or (\\*group, reps, 1) for real
    preacts_weight : torch.Tensor
        shape: (\\*group, reps)
    preacts_bias : torch.Tensor
        shape: (\\*group, reps)
    reps : int
    ansatz : str
        "pz_encoding", "pz", "rpz_encoding", "rpz", or "real"
    preacts_trainable : bool
    out_dim : int
    target : str, optional
        CUDA-Q target (e.g., "nvidia", "nvidia-mqpu", "qpp-cpu"). If given,
        it is applied via cudaq.set_target(); otherwise configure the target
        before calling.
    machine : str, optional
        Backend machine name, forwarded to cudaq.set_target().
    shots : int, optional
        Number of shots. None for exact statevector expectation.
    parallel_qubits : int, optional
        Pack this many independent single-qubit circuits into each job
        (see _cudaq_run_parallel).
    mitigation : dict, optional
        Error-mitigation settings, forwarded to _apply_mitigation.

    Returns
    -------
    torch.Tensor
        shape: (batch_size, out_dim, in_dim)
    """
    if not _CUDAQ_AVAILABLE:
        raise ImportError(
            "CUDA-Q is required for cudaq_solver. "
            "Install from: https://nvidia.github.io/cuda-quantum/"
        )

    ansatz = kwargs.get("ansatz", "pz_encoding")
    preacts_trainable = kwargs.get("preacts_trainable", False)
    out_dim = kwargs.get("out_dim", x.shape[1])
    shots = kwargs.get("shots", None)
    target = kwargs.get("target", None)
    machine = kwargs.get("machine", None)
    parallel_qubits = kwargs.get("parallel_qubits", None)

    if target is not None:
        target_kwargs = {}
        if machine is not None:
            target_kwargs["machine"] = machine
        cudaq.set_target(target, **target_kwargs)

    config = {
        "ansatz": ansatz,
        "preacts_trainable": preacts_trainable,
        "out_dim": out_dim,
        "shots": shots,
        "target": target,
        "parallel_qubits": parallel_qubits,
        "mitigation": kwargs.get("mitigation", {}),
    }

    needs_grad = theta.requires_grad or x.requires_grad
    if preacts_trainable:
        needs_grad = (
            needs_grad or preacts_weight.requires_grad or preacts_bias.requires_grad
        )

    if needs_grad:
        return _CudaqParamShift.apply(
            x, theta, preacts_weight, preacts_bias, reps, config
        )
    else:
        return _cudaq_evaluate(x, theta, preacts_weight, preacts_bias, reps, config)
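
# Illustrative end-to-end usage (a sketch; assumes CUDA-Q is installed and
# uses the default "pz_encoding" ansatz; the tensor sizes are arbitrary):
#
#     import torch
#     from qkan.solver.cudaq_solver import cudaq_solver
#
#     batch, in_dim, out_dim, reps = 4, 3, 2, 2
#     x = torch.rand(batch, in_dim)
#     theta = torch.rand(out_dim, in_dim, reps + 1, 2, requires_grad=True)
#     pw = torch.zeros(out_dim, in_dim, reps)  # unused without preactivation
#     pb = torch.zeros(out_dim, in_dim, reps)
#     y = cudaq_solver(x, theta, pw, pb, reps, out_dim=out_dim)  # (4, 2, 3)
#     y.sum().backward()   # parameter-shift: 2 * theta.numel() extra passes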