Source code for qkan.daruan.daruan

# Copyright (c) 2025, Jiun-Cheng Jiang. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from typing import Literal

import numpy as np
import torch
import torch.nn as nn

from .torch_qc import StateVector, TorchGates


def torch_exact_solver(
    x: torch.Tensor,
    theta: torch.Tensor,
    preacts_weight: torch.Tensor,
    preacts_bias: torch.Tensor,
    reps: int,
    **kwargs,
) -> torch.Tensor:
    """
    x: torch.Tensor, shape: (batch_size, dim)
    theta: torch.Tensor, shape: (dim, reps, 2)
    reps: int
    device: str, default: "cpu"
    no_sum: bool, default: False
    ansatz: str, options: ["pz_encoding", "px_encoding"], default: "pz_encoding"
    """
    device = kwargs.get("device", "cpu")
    ansatz = kwargs.get("ansatz", "pz_encoding")

    encoded_x = [
        torch.einsum("i,bi->bi", preacts_weight[:, l], x).add(preacts_bias[:, l])
        for l in range(reps)
    ]

    def pz_encoding(encoded_x: list[torch.Tensor], theta: torch.Tensor):
        """
        x: torch.Tensor, shape: (batch_size, dim)
        theta: torch.Tensor, shape: (dim, reps, 2)
        """
        psi = StateVector(
            x.shape[0],
            theta.shape[0],
            device=device,
        )  # psi.state: torch.Tensor, shape: (batch_size, dim, 2)
        psi.h()
        for l in range(reps):
            psi.rz(theta[:, l, 0])
            psi.ry(theta[:, l, 1])
            psi.state = torch.einsum(
                "mnbi,bin->bim",
                TorchGates.rz_gate(encoded_x[l]),
                psi.state,
            )

        psi.rz(theta[:, reps, 0])
        psi.ry(theta[:, reps, 1])
        return psi.measure_z()  # shape: (batch_size, dim)

    if ansatz == "pz_encoding":
        circuit = pz_encoding
    else:
        raise NotImplementedError()
    return circuit(encoded_x, theta)  # shape: (batch_size, dim)


[docs] class DARUAN(nn.Module): def __init__( self, dim: int, reps: int, device="cpu", solver: Literal["qml", "exact"] = "exact", ansatz="pz_encoding", preact_trainable: bool = False, postact_weight_trainable: bool = False, postact_bias_trainable: bool = False, seed=0, ): super(DARUAN, self).__init__() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) self.dim = dim self.reps = reps self.device = device self.solver: Literal["qml", "exact"] = solver self.ansatz = ansatz if ansatz == "pz_encoding": self.theta = nn.Parameter( nn.init.xavier_normal_(torch.empty(dim, reps + 1, 2, device=device)) ) elif ansatz == "px_encoding": self.theta = nn.Parameter( nn.init.xavier_normal_(torch.empty(dim, reps + 1, 1, device=device)) ) else: raise NotImplementedError() self.preacts_trainable = preact_trainable self.preacts_weight = nn.Parameter( torch.ones(dim, reps, device=device), requires_grad=preact_trainable, ) self.preacts_bias = nn.Parameter( torch.zeros(dim, reps, device=device), requires_grad=preact_trainable, ) self.postact_weight_trainable = postact_weight_trainable self.postact_weights = nn.Parameter( torch.ones(dim, device=device), requires_grad=postact_weight_trainable, ) self.postact_bias_trainable = postact_bias_trainable self.postact_bias = nn.Parameter( torch.zeros(dim, device=device), requires_grad=postact_bias_trainable, )
[docs] def forward(self, x: torch.Tensor): assert x.shape[1] == self.dim, "Invalid input dimension" x = x.to(self.device) if self.solver == "qml": raise NotImplementedError elif self.solver == "exact": postacts = torch_exact_solver( x, self.theta, self.preacts_weight, self.preacts_bias, self.reps, device=self.device, ansatz=self.ansatz, ) else: assert False, "Invalid solver" x = torch.einsum("bi,i->bi", postacts, self.postact_weights).add( self.postact_bias ) return x