Source code for qkan.torch_qc

"""
Synchronous processing of quantum circuits with PyTorch.

Features:
    - Single-qubit quantum circuits (Faster than other libraries)
    - Two-qubit quantum circuits

Code author: Jiun-Cheng Jiang (Jim137@GitHub)
Contact: [jcjiang@phys.ntu.edu.tw](mailto:jcjiang@phys.ntu.edu.tw)
"""

import torch


class TorchGates:
    """
    Factory of batched quantum gate tensors.

    Every gate is returned as a tensor whose two leading axes index the
    matrix row and column, and whose trailing axes follow the ``shape`` /
    ``theta`` argument (``(out_dim, in_dim)`` throughout this module), so
    that one independent gate matrix exists per trailing position.

    Two-qubit gates use the same basis ordering as ``tensor_product``:
    basis index ``2 * q0 + q1``, i.e. qubit 0 is the first tensor factor.
    """

    @staticmethod
    def identity_gate(shape, device=None, dtype=None) -> torch.Tensor:
        """
        Identity gate.

        shape: (out_dim, in_dim)
        device: optional device for the result (default: torch default device)
        dtype: optional dtype for the result (default: torch default dtype)

        return: torch.Tensor, shape: (2, 2, out_dim, in_dim)
        """
        one = torch.ones(*shape, device=device, dtype=dtype)
        zero = torch.zeros(*shape, device=device, dtype=dtype)
        return torch.stack(
            [
                torch.stack([one, zero]),
                torch.stack([zero, one]),
            ],
        )

    # Alias for identity_gate; same arguments and return value.
    i_gate = identity_gate

    @staticmethod
    def rx_gate(theta: torch.Tensor, dtype=torch.complex64) -> torch.Tensor:
        """
        Rotation-X gate.

        theta: torch.Tensor, shape: (out_dim, in_dim)
        dtype: dtype of the returned gate

        return: torch.Tensor, shape: (2, 2, out_dim, in_dim)
        """
        # Cast both entries so the result has exactly `dtype`, whatever the
        # precision of `theta`.
        cos = torch.cos(theta / 2).to(dtype)
        jsin = (1j * torch.sin(-theta / 2)).to(dtype)
        return torch.stack(
            [
                torch.stack([cos, jsin]),
                torch.stack([jsin, cos]),
            ],
        )

    @staticmethod
    def ry_gate(theta: torch.Tensor, dtype=torch.complex64) -> torch.Tensor:
        """
        Rotation-Y gate.

        theta: torch.Tensor, shape: (out_dim, in_dim)
        dtype: dtype of the returned gate

        return: torch.Tensor, shape: (2, 2, out_dim, in_dim)
        """
        cos = torch.cos(theta / 2)
        sin = torch.sin(theta / 2)
        return torch.stack(
            [
                torch.stack([cos, -sin]),
                torch.stack([sin, cos]),
            ],
        ).to(dtype)

    @staticmethod
    def rz_gate(theta: torch.Tensor, dtype=torch.complex64) -> torch.Tensor:
        """
        Rotation-Z gate.

        theta: torch.Tensor, shape: (out_dim, in_dim)
        dtype: dtype of the returned gate

        return: torch.Tensor, shape: (2, 2, out_dim, in_dim)
        """
        phase = torch.exp(-0.5j * theta)
        # Same dtype as `phase`, so stacking needs no implicit promotion.
        zero = torch.zeros_like(phase)
        return torch.stack(
            [
                torch.stack([phase, zero]),
                torch.stack([zero, torch.conj(phase)]),
            ],
        ).to(dtype)

    @staticmethod
    def h_gate(shape, device, dtype=torch.complex64) -> torch.Tensor:
        """
        Hadamard gate.

        shape: (out_dim, in_dim)
        device: device on which the gate is created
        dtype: dtype of the returned gate

        return: torch.Tensor, shape: (2, 2, out_dim, in_dim)
        """
        inv_sqrt2 = 1 / torch.sqrt(torch.ones(*shape, device=device) * 2)
        return torch.stack(
            [
                torch.stack([inv_sqrt2, inv_sqrt2]),
                torch.stack([inv_sqrt2, -inv_sqrt2]),
            ],
        ).to(dtype)

    @staticmethod
    def s_gate(shape, device=None, dtype=None) -> torch.Tensor:
        """
        Phase (S) gate, diag(1, i).

        shape: (out_dim, in_dim)
        device: optional device for the result (default: torch default device)
        dtype: optional dtype for the result (default: the complex dtype
            matching torch's default floating dtype)

        return: torch.Tensor, shape: (2, 2, out_dim, in_dim)
        """
        imag_unit = 1j * torch.ones(*shape, device=device)
        one = torch.ones_like(imag_unit)
        zero = torch.zeros_like(imag_unit)
        gate = torch.stack(
            [
                torch.stack([one, zero]),
                torch.stack([zero, imag_unit]),
            ],
        )
        return gate if dtype is None else gate.to(dtype)

    @staticmethod
    def acrx_gate(theta: torch.Tensor, dtype=torch.complex64) -> torch.Tensor:
        """
        Complex extension of RX(acos(theta)) gate.

        *Note: Physically unrealizable.*

        theta: torch.Tensor, shape: (out_dim, in_dim)
        dtype: dtype of the returned gate

        return: torch.Tensor, shape: (2, 2, out_dim, in_dim)
        """
        sq = torch.square(theta)
        # Off-diagonal entry: i * sqrt(1 - theta^2) where theta^2 < 1,
        # otherwise the real value sqrt(theta^2 - 1).
        diag = torch.sqrt(torch.abs(1 - sq)) * torch.where(sq < 1, 1j, 1)
        theta = theta.to(diag.dtype)
        return torch.stack(
            [
                torch.stack([theta, diag]),
                torch.stack([diag, theta]),
            ],
        ).to(dtype)

    @staticmethod
    def tensor_product(gate, another_gate, dtype=torch.complex64):
        """
        Compute tensor product of two gates.

        Arguments
        ---------
            :gate: torch.Tensor, shape: (2, 2, out_dim, in_dim)
            :another_gate: torch.Tensor, shape: (2, 2, out_dim, in_dim)
            :dtype: dtype of the returned gate

        return: torch.Tensor, shape: (4, 4, out_dim, in_dim)
        """
        shape = gate.shape[2:]
        left = gate.reshape(2, 2, -1)
        right = another_gate.reshape(2, 2, -1)
        # Batched Kronecker product without a Python loop:
        # paired[a, c, b, d, n] = left[a, b, n] * right[c, d, n], and merging
        # (a, c) and (b, d) gives row 2a + c, column 2b + d -- the same layout
        # torch.kron produces for each position n.
        paired = left[:, None, :, None, :] * right[None, :, None, :, :]
        return paired.reshape(4, 4, *shape).to(dtype)

    @staticmethod
    def cx_gate(shape, control: int, device, dtype=torch.complex64) -> torch.Tensor:
        """
        2-qubits CX (CNOT) gate.

        shape: (out_dim, in_dim)
        control: int, index of the control qubit (0 or 1); the other qubit
            is the target
        device: device on which the gate is created
        dtype: dtype of the returned gate

        return: torch.Tensor, shape: (4, 4, out_dim, in_dim)

        raises: AssertionError if control is not 0 or 1
        """
        # Raised explicitly (same exception type as the former assert) so the
        # check is not stripped under `python -O`.
        if control not in (0, 1):
            raise AssertionError("Control qubit must be 0 or 1.")
        gate = torch.zeros(4, 4, *shape, dtype=dtype, device=device)
        if control == 0:
            # Qubit 0 controls: |10> <-> |11>, |00> and |01> unchanged.
            gate[0, 0] = 1.0
            gate[1, 1] = 1.0
            gate[2, 3] = 1.0
            gate[3, 2] = 1.0
        else:
            # Qubit 1 controls: |01> <-> |11>, |00> and |10> unchanged.
            # (Transposing the control-0 matrix would not work: it is
            # symmetric, so the transpose is the same gate.)
            gate[0, 0] = 1.0
            gate[2, 2] = 1.0
            gate[1, 3] = 1.0
            gate[3, 1] = 1.0
        return gate

    @staticmethod
    def cz_gate(shape, device, dtype=torch.complex64) -> torch.Tensor:
        """
        2-qubits CZ gate, diag(1, 1, 1, -1).

        shape: (out_dim, in_dim)
        device: device on which the gate is created
        dtype: dtype of the returned gate

        return: torch.Tensor, shape: (4, 4, out_dim, in_dim)
        """
        gate = torch.zeros(4, 4, *shape, dtype=dtype, device=device)
        gate[0, 0] = 1.0
        gate[1, 1] = 1.0
        gate[2, 2] = 1.0
        gate[3, 3] = -1.0
        return gate
class StateVector:
    """
    1-qubit state vector.

    StateVector.state: torch.Tensor, shape: (batch_size, out_dim, in_dim, 2)

    Every (batch, out, in) position holds an independent single-qubit state,
    initialised to |0>. Gate methods replace ``self.state`` with a new tensor
    rather than modifying it in place.
    """

    state: torch.Tensor

    def __init__(
        self,
        batch_size: int,
        out_dim: int,
        in_dim: int,
        device="cpu",
        dtype=torch.complex64,
    ):
        """
        Arguments
        ---------
            :batch_size: int
            :out_dim: int
            :in_dim: int
            :device: device holding the state, default: "cpu"
            :dtype: complex dtype of the amplitudes, default: torch.complex64
        """
        self.device = device
        self.dtype = dtype
        self.batch_size = batch_size
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.state = torch.zeros(
            batch_size, out_dim, in_dim, 2, dtype=dtype, device=self.device
        )
        self.state[:, :, :, 0] = 1.0

    def _apply(self, gate: torch.Tensor, is_dagger: bool = False) -> None:
        """Contract a (2, 2, out_dim, in_dim) gate (or its dagger) with the state."""
        if is_dagger:
            gate = torch.conj_physical(gate).transpose(0, 1)
        # new[b, o, i, m] = sum_n gate[m, n, o, i] * state[b, o, i, n]
        self.state = torch.einsum("mnoi,boin->boim", gate, self.state)

    def _scratch_copy(self) -> "StateVector":
        """Return a new StateVector with the same device, dtype and amplitudes."""
        # dtype is forwarded so a complex128 state is not downcast to the
        # constructor's complex64 default.
        twin = StateVector(
            self.batch_size, self.out_dim, self.in_dim, self.device, self.dtype
        )
        twin.state.copy_(self.state)
        return twin

    def measure_z(self, fast_measure: bool = True) -> torch.Tensor:
        """
        Measure the state vector in the Z basis.

        Arguments
        ---------
            :fast_measure: bool, default: True. If True, for state
                |ψ⟩ = α|0⟩ + β|1⟩, return |α| - |β|; if False, return
                |α|^2 - |β|^2. The True variant is a quantum-inspired
                shortcut and is faster.

        return: torch.Tensor, shape: (batch_size, out_dim, in_dim)
        """
        alpha = self.state[:, :, :, 0].abs()
        beta = self.state[:, :, :, 1].abs()
        if fast_measure:
            return alpha - beta
        return torch.square(alpha) - torch.square(beta)

    def measure_x(self) -> torch.Tensor:
        """
        Measure the state vector in the X basis.

        Applies H to a copy of the state and measures Z (with the default
        fast_measure=True); ``self.state`` is left untouched.

        return: torch.Tensor, shape: (batch_size, out_dim, in_dim)
        """
        tmp_state = self._scratch_copy()
        tmp_state.h()
        return tmp_state.measure_z()

    def measure_y(self) -> torch.Tensor:
        """
        Measure the state vector in the Y basis.

        Applies S† then H to a copy of the state and measures Z (with the
        default fast_measure=True); ``self.state`` is left untouched.

        return: torch.Tensor, shape: (batch_size, out_dim, in_dim)
        """
        tmp_state = self._scratch_copy()
        tmp_state.s(is_dagger=True)
        tmp_state.h()
        return tmp_state.measure_z()

    def s(self, is_dagger: bool = False):
        """
        Apply Phase gate (or S gate) to the state vector.

        Arguments
        ---------
            :is_dagger: bool, default: False
        """
        # Cast to the state's dtype as well as its device, matching what the
        # other gate methods do through their dtype argument.
        gate = TorchGates.s_gate(self.state.shape[1:3]).to(
            device=self.device, dtype=self.dtype
        )
        self._apply(gate, is_dagger)

    def h(self, is_dagger: bool = False):
        """
        Apply Hadamard gate to the state vector.

        Arguments
        ---------
            :is_dagger: bool, default: False
        """
        gate = TorchGates.h_gate(self.state.shape[1:3], self.device, dtype=self.dtype)
        self._apply(gate, is_dagger)

    def rx(self, theta: torch.Tensor, is_dagger: bool = False):
        """
        Apply Rotation-X gate to the state vector.

        Arguments
        ---------
            :theta: torch.Tensor, shape: (out_dim, in_dim)
            :is_dagger: bool, default: False
        """
        self._apply(TorchGates.rx_gate(theta, dtype=self.dtype), is_dagger)

    def ry(self, theta: torch.Tensor, is_dagger: bool = False):
        """
        Apply Rotation-Y gate to the state vector.

        Arguments
        ---------
            :theta: torch.Tensor, shape: (out_dim, in_dim)
            :is_dagger: bool, default: False
        """
        self._apply(TorchGates.ry_gate(theta, dtype=self.dtype), is_dagger)

    def rz(self, theta: torch.Tensor, is_dagger: bool = False):
        """
        Apply Rotation-Z gate to the state vector.

        Arguments
        ---------
            :theta: torch.Tensor, shape: (out_dim, in_dim)
            :is_dagger: bool, default: False
        """
        self._apply(TorchGates.rz_gate(theta, dtype=self.dtype), is_dagger)
class DQStateVector:
    """
    2-qubit state vector.

    DQStateVector.state: torch.Tensor, shape: (batch_size, out_dim, in_dim, 4)

    The last axis is the computational basis index ``2 * q0 + q1``: qubit 0
    is the first tensor factor (as in ``apply_gate``) and qubit 1 the second.
    Every (batch, out, in) position is initialised to |00>.
    """

    state: torch.Tensor

    def __init__(
        self,
        batch_size: int,
        out_dim: int,
        in_dim: int,
        device="cpu",
        dtype=torch.complex64,
    ):
        """
        Arguments
        ---------
            :batch_size: int
            :out_dim: int
            :in_dim: int
            :device: device holding the state, default: "cpu"
            :dtype: complex dtype of the amplitudes, default: torch.complex64
        """
        self.device = device
        # Kept so that gates are built in the same dtype as the state.
        self.dtype = dtype
        self.batch_size = batch_size
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.state = torch.zeros(
            batch_size, out_dim, in_dim, 4, dtype=dtype, device=self.device
        )
        self.state[:, :, :, 0] = 1.0

    def _apply(self, gate: torch.Tensor) -> None:
        """Contract a (4, 4, out_dim, in_dim) gate with the state."""
        # new[b, o, i, m] = sum_n gate[m, n, o, i] * state[b, o, i, n]
        self.state = torch.einsum("mnoi,boin->boim", gate, self.state)

    def measure_z(self, target: int = 0) -> torch.Tensor:
        """
        Measure one qubit of the state vector in the Z basis.

        Returns the sum of amplitude magnitudes |a_k| over basis states where
        the target qubit is 0, minus the sum over those where it is 1.

        Arguments
        ---------
            :target: int, default: 0. Qubit to measure; any value other
                than 0 selects qubit 1.

        return: torch.Tensor, shape: (batch_size, out_dim, in_dim)
        """
        mags = self.state.abs()
        if target == 0:
            # Qubit 0 is the high bit of the basis index: it is 1 for |10>, |11>.
            return mags[..., 0] + mags[..., 1] - mags[..., 2] - mags[..., 3]
        # Qubit 1 is the low bit of the basis index: it is 1 for |01>, |11>.
        return mags[..., 0] - mags[..., 1] + mags[..., 2] - mags[..., 3]

    def cx(self, control: int):
        """
        Apply CX (CNOT) gate to the state vector.

        Arguments
        ---------
            :control: int, index of the control qubit (0 or 1)
        """
        cx_gate = TorchGates.cx_gate(
            self.state.shape[1:3], control, self.device, dtype=self.dtype
        )
        self._apply(cx_gate)

    def cz(self):
        """
        Apply CZ gate to the state vector.
        """
        cz_gate = TorchGates.cz_gate(
            self.state.shape[1:3], self.device, dtype=self.dtype
        )
        self._apply(cz_gate)

    def apply_gate(self, gate: torch.Tensor, target: int = 0):
        """
        Apply a single-qubit gate to one qubit of the state vector.

        Arguments
        ---------
            :gate: torch.Tensor, shape: (2, 2, out_dim, in_dim)
            :target: int, default: 0. Qubit the gate acts on; any value
                other than 0 selects qubit 1. The other qubit gets identity.
        """
        # The identity must be on the state's device and in its dtype before
        # it is combined with `gate`.
        eye = TorchGates.identity_gate(self.state.shape[1:3]).to(
            device=self.device, dtype=self.dtype
        )
        if target == 0:
            full_gate = TorchGates.tensor_product(gate, eye, dtype=self.dtype)
        else:
            full_gate = TorchGates.tensor_product(eye, gate, dtype=self.dtype)
        self._apply(full_gate)

    def apply_2gates(self, gate1: torch.Tensor, gate2: torch.Tensor):
        """
        Apply two single-qubit gates at once: gate1 on qubit 0, gate2 on qubit 1.

        Arguments
        ---------
            :gate1: torch.Tensor, shape: (2, 2, out_dim, in_dim)
            :gate2: torch.Tensor, shape: (2, 2, out_dim, in_dim)
        """
        self._apply(TorchGates.tensor_product(gate1, gate2, dtype=self.dtype))

    def hh(self):
        """
        Apply Hadamard gate to both qubits of the state vector.
        """
        h_gate = TorchGates.h_gate(
            self.state.shape[1:3], self.device, dtype=self.dtype
        )
        self.apply_2gates(h_gate, h_gate)