"""
Solvers for simulating QKAN layers.

This module provides solvers for quantum neural networks using PyTorch or
PennyLane.
Code author: Jiun-Cheng Jiang (Jim137@GitHub)
Contact: [jcjiang@phys.ntu.edu.tw](mailto:jcjiang@phys.ntu.edu.tw)
"""
import numpy as np
import torch
from .torch_qc import StateVector, TorchGates
def qml_solver(x: torch.Tensor, theta: torch.Tensor, reps: int, **kwargs):
"""
Single-qubit data reuploading circuit using PennyLane.
Args
----
x : torch.Tensor
shape: (batch_size, in_dim)
theta : torch.Tensor
shape: (reps, 2)
reps : int
qml_device : str
default: "default.qubit"
"""
import pennylane as qml # type: ignore
qml_device: str = kwargs.get("qml_device", "default.qubit")
dev = qml.device(qml_device, wires=1)
@qml.qnode(dev, interface="torch")
def circuit(x: torch.Tensor, theta: torch.Tensor):
"""
Args
----
x : torch.Tensor
shape: (batch_size, in_dim)
theta : torch.Tensor
shape: (reps, 2)
"""
qml.RY(np.pi / 2, wires=0)
for l in range(reps):
qml.RZ(theta[l, 0], wires=0)
qml.RY(theta[l, 1], wires=0)
qml.RZ(x, wires=0)
qml.RZ(theta[reps, 0], wires=0)
qml.RY(theta[reps, 1], wires=0)
return qml.expval(qml.PauliZ(0))
return circuit(x, theta)
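

# Minimal usage sketch for ``qml_solver`` (illustrative only; assumes PennyLane
# broadcasts a 1-D batch of scalar inputs and reps = 3, so theta has shape
# (reps + 1, 2)):
#
#     x = torch.rand(8)                       # batch of 8 scalar inputs
#     theta = torch.rand(4, 2)                # (reps + 1, 2) trainable angles
#     expvals = qml_solver(x, theta, reps=3)  # Pauli-Z expectation values

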
def torch_exact_solver(
x: torch.Tensor,
theta: torch.Tensor,
preacts_weight: torch.Tensor,
preacts_bias: torch.Tensor,
reps: int,
**kwargs,
) -> torch.Tensor:
"""
Single-qubit data reuploading circuit.
Args
----
x : torch.Tensor
shape: (batch_size, in_dim)
theta : torch.Tensor
preacts_weight : torch.Tensor
shape: (out_dim, in_dim, reps)
preacts_bias : torch.Tensor
shape: (out_dim, in_dim, reps)
reps : int
ansatz : str
options: ["pz_encoding", "px_encoding"], default: "pz_encoding"
n_group : int
number of neurons in a group, default: in_dim of x
Returns
-------
torch.Tensor
shape: (batch_size, out_dim, in_dim)
"""
batch, in_dim = x.shape
device = x.device
ansatz = kwargs.get("ansatz", "pz_encoding")
# group = kwargs.get("group", in_dim)
preacts_trainable = kwargs.get("preacts_trainable", False)
fast_measure = kwargs.get("fast_measure", True)
out_dim = preacts_weight.shape[0]
encoded_x = [
torch.einsum("oi,bi->boi", preacts_weight[:, :, l], x).add(
preacts_bias[:, :, l]
)
for l in range(reps)
] # len: reps, shape: (batch_size, out_dim, in_dim)
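    # Bring theta to 4-D shape (out_dim, in_dim, reps + 1, 2): add a leading
    # axis if it is missing, then tile grouped parameters across out_dim and
    # in_dim.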
if len(theta.shape) != 4:
theta = theta.unsqueeze(0)
if theta.shape[1] != in_dim:
repeat_out = out_dim
repeat_in = in_dim // theta.shape[1] + 1
theta = theta.repeat(repeat_out, repeat_in, 1, 1)[:, :in_dim, :, :]
def pz_encoding(theta: torch.Tensor):
"""
Args
----
theta : torch.Tensor
shape: (out_dim, n_group, reps, 2)
"""
psi = StateVector(
x.shape[0],
theta.shape[0],
theta.shape[1],
device=device,
) # psi.state: torch.Tensor, shape: (batch_size, out_dim, in_dim, 2)
psi.h()
if not preacts_trainable:
rug = TorchGates.rz_gate(x)
for l in range(reps):
psi.rz(theta[:, :, l, 0])
psi.ry(theta[:, :, l, 1])
if not preacts_trainable:
psi.state = torch.einsum("mnbi,boin->boim", rug, psi.state)
else:
psi.state = torch.einsum(
"mnboi,boin->boim",
TorchGates.rz_gate(encoded_x[l]),
psi.state,
)
psi.rz(theta[:, :, reps, 0])
psi.ry(theta[:, :, reps, 1])
return psi.measure_z(fast_measure) # shape: (batch_size, out_dim, in_dim)
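    # Reduced PZ-encoding variant: one trainable RY per layer instead of the
    # RZ/RY pair used by pz_encoding.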
    def rpz_encoding(theta: torch.Tensor):
        """
        Args
        ----
        theta : torch.Tensor
            shape: (out_dim, in_dim, reps + 1, 2)
        """
psi = StateVector(
x.shape[0],
theta.shape[0],
theta.shape[1],
device=device,
)
psi.h()
for l in range(reps):
psi.ry(theta[:, :, l, 0])
psi.state = torch.einsum(
"mnboi,boin->boim",
TorchGates.rz_gate(encoded_x[l]),
psi.state,
)
psi.ry(theta[:, :, reps, 0])
return psi.measure_z(fast_measure) # shape: (batch_size, out_dim, in_dim)
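    # PX-encoding variant: trainable RZ rotations, with the input uploaded
    # through RX(arccos(encoded_x)).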
def px_encoding(theta: torch.Tensor):
"""
Args
----
theta: torch.Tensor
shape: (out_dim, n_group, reps, 1)
"""
psi = StateVector(
x.shape[0],
theta.shape[0],
theta.shape[1],
device=device,
        ) # psi.state: torch.Tensor, shape: (batch_size, out_dim, in_dim, 2)
psi.h()
for l in range(reps):
psi.rz(theta[:, :, l, 0])
psi.state = torch.einsum(
"mnboi,boin->boim",
TorchGates.rx_gate(
                    torch.acos(
                        # torch.sin(
                        encoded_x[l]
                        # )
                        # uncomment sin() to keep the acos argument within [-1, 1]
                    )
),
psi.state,
)
"""
# complex extension implementation
psi.state = torch.einsum(
"mnboi,boin->boim",
TorchGates.acrx_gate(
torch.einsum("oi,bi->boi", preacts_weight[:, :, l], x)
),
psi.state,
)
"""
psi.rz(theta[:, :, reps, 0])
return psi.measure_z(fast_measure) # shape: (batch_size, out_dim, in_dim)
if ansatz == "pz_encoding":
circuit = pz_encoding
elif ansatz == "rpz_encoding":
        circuit = rpz_encoding
elif ansatz == "px_encoding":
circuit = px_encoding
elif callable(ansatz):
circuit = ansatz
else:
        raise NotImplementedError(f"Unknown ansatz: {ansatz!r}")
    out = circuit(theta)  # shape: (batch_size, out_dim, in_dim)
    return out
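

# Minimal usage sketch for ``torch_exact_solver`` (illustrative only; shapes
# are assumed from the docstring, with batch_size=8, in_dim=3, out_dim=2,
# reps=4):
#
#     x = torch.rand(8, 3)
#     theta = torch.rand(2, 3, 5, 2)   # (out_dim, in_dim, reps + 1, 2)
#     w = torch.rand(2, 3, 4)          # (out_dim, in_dim, reps)
#     b = torch.rand(2, 3, 4)          # (out_dim, in_dim, reps)
#     y = torch_exact_solver(x, theta, w, b, reps=4)
#     # y.shape == (8, 2, 3), i.e. (batch_size, out_dim, in_dim)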