Source code for qkan.utils

# MIT License
#
# Copyright (c) 2024 Ziming Liu
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""
Create dataset for regression task.

Adapted from [KindXiaoming/pykan@GitHub 91a2f63](https://github.com/KindXiaoming/pykan/tree/91a2f633be2d435b081ef0ef52a7205c7e7bea9e)
"""

import numpy as np
import sympy
import torch

# sigmoid = sympy.Function('sigmoid')
# name: (torch implementation, sympy implementation)

# singularity protection functions
f_inv = lambda x, y_th: (
    (x_th := 1 / y_th),
    y_th / x_th * x * (torch.abs(x) < x_th)
    + torch.nan_to_num(1 / x) * (torch.abs(x) >= x_th),
)
f_inv2 = lambda x, y_th: (
    (x_th := 1 / y_th ** (1 / 2)),
    y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1 / x**2) * (torch.abs(x) >= x_th),
)
f_inv3 = lambda x, y_th: (
    (x_th := 1 / y_th ** (1 / 3)),
    y_th / x_th * x * (torch.abs(x) < x_th)
    + torch.nan_to_num(1 / x**3) * (torch.abs(x) >= x_th),
)
f_inv4 = lambda x, y_th: (
    (x_th := 1 / y_th ** (1 / 4)),
    y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1 / x**4) * (torch.abs(x) >= x_th),
)
f_inv5 = lambda x, y_th: (
    (x_th := 1 / y_th ** (1 / 5)),
    y_th / x_th * x * (torch.abs(x) < x_th)
    + torch.nan_to_num(1 / x**5) * (torch.abs(x) >= x_th),
)
f_sqrt = lambda x, y_th: (
    (x_th := 1 / y_th**2),
    x_th / y_th * x * (torch.abs(x) < x_th)
    + torch.nan_to_num(torch.sqrt(torch.abs(x)) * torch.sign(x))
    * (torch.abs(x) >= x_th),
)
f_power1d5 = lambda x, y_th: torch.abs(x) ** 1.5
f_invsqrt = lambda x, y_th: (
    (x_th := 1 / y_th**2),
    y_th * (torch.abs(x) < x_th)
    + torch.nan_to_num(1 / torch.sqrt(torch.abs(x))) * (torch.abs(x) >= x_th),
)
f_log = lambda x, y_th: (
    (x_th := torch.e ** (-y_th)),
    -y_th * (torch.abs(x) < x_th)
    + torch.nan_to_num(torch.log(torch.abs(x))) * (torch.abs(x) >= x_th),
)
f_tan = lambda x, y_th: (
    (clip := x % torch.pi),
    (delta := torch.pi / 2 - torch.arctan(y_th)),
    -y_th / delta * (clip - torch.pi / 2) * (torch.abs(clip - torch.pi / 2) < delta)
    + torch.nan_to_num(torch.tan(clip)) * (torch.abs(clip - torch.pi / 2) >= delta),
)
f_arctanh = lambda x, y_th: (
    (delta := 1 - torch.tanh(y_th) + 1e-4),
    y_th * torch.sign(x) * (torch.abs(x) > 1 - delta)
    + torch.nan_to_num(torch.arctanh(x)) * (torch.abs(x) <= 1 - delta),
)
f_arcsin = lambda x, y_th: (
    (),
    torch.pi / 2 * torch.sign(x) * (torch.abs(x) > 1)
    + torch.nan_to_num(torch.arcsin(x)) * (torch.abs(x) <= 1),
)
f_arccos = lambda x, y_th: (
    (),
    torch.pi / 2 * (1 - torch.sign(x)) * (torch.abs(x) > 1)
    + torch.nan_to_num(torch.arccos(x)) * (torch.abs(x) <= 1),
)
f_exp = lambda x, y_th: (
    (x_th := torch.log(y_th)),
    y_th * (x > x_th) + torch.exp(x) * (x <= x_th),
)

SYMBOLIC_LIB = {
    "x": (lambda x: x, lambda x: x, 1, lambda x, y_th: ((), x)),
    "x^2": (lambda x: x**2, lambda x: x**2, 2, lambda x, y_th: ((), x**2)),
    "x^3": (lambda x: x**3, lambda x: x**3, 3, lambda x, y_th: ((), x**3)),
    "x^4": (lambda x: x**4, lambda x: x**4, 3, lambda x, y_th: ((), x**4)),
    "x^5": (lambda x: x**5, lambda x: x**5, 3, lambda x, y_th: ((), x**5)),
    "1/x": (lambda x: 1 / x, lambda x: 1 / x, 2, f_inv),
    "1/x^2": (lambda x: 1 / x**2, lambda x: 1 / x**2, 2, f_inv2),
    "1/x^3": (lambda x: 1 / x**3, lambda x: 1 / x**3, 3, f_inv3),
    "1/x^4": (lambda x: 1 / x**4, lambda x: 1 / x**4, 4, f_inv4),
    "1/x^5": (lambda x: 1 / x**5, lambda x: 1 / x**5, 5, f_inv5),
    "sqrt": (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
    "x^0.5": (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
    "x^1.5": (
        lambda x: torch.sqrt(x) ** 3,
        lambda x: sympy.sqrt(x) ** 3,
        4,
        f_power1d5,
    ),
    "1/sqrt(x)": (
        lambda x: 1 / torch.sqrt(x),
        lambda x: 1 / sympy.sqrt(x),
        2,
        f_invsqrt,
    ),
    "1/x^0.5": (lambda x: 1 / torch.sqrt(x), lambda x: 1 / sympy.sqrt(x), 2, f_invsqrt),
    "exp": (lambda x: torch.exp(x), lambda x: sympy.exp(x), 2, f_exp),
    "log": (lambda x: torch.log(x), lambda x: sympy.log(x), 2, f_log),
    "abs": (
        lambda x: torch.abs(x),
        lambda x: sympy.Abs(x),
        3,
        lambda x, y_th: ((), torch.abs(x)),
    ),
    "sin": (
        lambda x: torch.sin(x),
        lambda x: sympy.sin(x),
        2,
        lambda x, y_th: ((), torch.sin(x)),
    ),
    "cos": (
        lambda x: torch.cos(x),
        lambda x: sympy.cos(x),
        2,
        lambda x, y_th: ((), torch.cos(x)),
    ),
    "tan": (lambda x: torch.tan(x), lambda x: sympy.tan(x), 3, f_tan),
    "tanh": (
        lambda x: torch.tanh(x),
        lambda x: sympy.tanh(x),
        3,
        lambda x, y_th: ((), torch.tanh(x)),
    ),
    "sgn": (
        lambda x: torch.sign(x),
        lambda x: sympy.sign(x),
        3,
        lambda x, y_th: ((), torch.sign(x)),
    ),
    "arcsin": (lambda x: torch.arcsin(x), lambda x: sympy.asin(x), 4, f_arcsin),
    "arccos": (lambda x: torch.arccos(x), lambda x: sympy.acos(x), 4, f_arccos),
    "arctan": (
        lambda x: torch.arctan(x),
        lambda x: sympy.atan(x),
        4,
        lambda x, y_th: ((), torch.arctan(x)),
    ),
    "arctanh": (lambda x: torch.arctanh(x), lambda x: sympy.atanh(x), 4, f_arctanh),
    "0": (lambda x: x * 0, lambda x: x * 0, 0, lambda x, y_th: ((), x * 0)),
    "gaussian": (
        lambda x: torch.exp(-(x**2)),
        lambda x: sympy.exp(-(x**2)),
        3,
        lambda x, y_th: ((), torch.exp(-(x**2))),
    ),
}


[docs] def create_dataset( f, n_var=2, f_mode="col", ranges=[-1, 1], train_num=1000, test_num=1000, normalize_input=False, normalize_label=False, device="cpu", seed=0, ): """ Create dataset Args: f: function the symbolic formula used to create the synthetic dataset ranges: list or np.array; shape (2,) or (n_var, 2) the range of input variables. Default: [-1,1]. train_num: int the number of training samples. Default: 1000. test_num: int the number of test samples. Default: 1000. normalize_input: bool If True, apply normalization to inputs. Default: False. normalize_label: bool If True, apply normalization to labels. Default: False. device: str device. Default: 'cpu'. seed: int random seed. Default: 0. Returns: dataset: dict Train/test inputs/labels are dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label'] """ np.random.seed(seed) torch.manual_seed(seed) if len(np.array(ranges).shape) == 1: ranges = np.array(ranges * n_var).reshape(n_var, 2) else: ranges = np.array(ranges) train_input = torch.zeros(train_num, n_var) test_input = torch.zeros(test_num, n_var) for i in range(n_var): train_input[:, i] = ( torch.rand( train_num, ) * (ranges[i, 1] - ranges[i, 0]) + ranges[i, 0] ) test_input[:, i] = ( torch.rand( test_num, ) * (ranges[i, 1] - ranges[i, 0]) + ranges[i, 0] ) if f_mode == "col": train_label = f(train_input) test_label = f(test_input) elif f_mode == "row": train_label = f(train_input.T) test_label = f(test_input.T) else: print(f"f_mode {f_mode} not recognized") # if has only 1 dimension if len(train_label.shape) == 1: train_label = train_label.unsqueeze(dim=1) test_label = test_label.unsqueeze(dim=1) def normalize(data, mean, std): return (data - mean) / std if normalize_input: mean_input = torch.mean(train_input, dim=0, keepdim=True) std_input = torch.std(train_input, dim=0, keepdim=True) train_input = normalize(train_input, mean_input, std_input) test_input = normalize(test_input, mean_input, std_input) if normalize_label: mean_label = torch.mean(train_label, dim=0, keepdim=True) std_label = torch.std(train_label, dim=0, keepdim=True) train_label = normalize(train_label, mean_label, std_label) test_label = normalize(test_label, mean_label, std_label) dataset = {} dataset["train_input"] = train_input.to(device) dataset["test_input"] = test_input.to(device) dataset["train_label"] = train_label.to(device) dataset["test_label"] = test_label.to(device) return dataset