# Source code for qkan.feynman

# MIT License
#
# Copyright (c) 2024 Ziming Liu
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""
Get Feynman dataset from the given name.
Return the symbol, expression, function, and ranges of the dataset.

Adapted from [KindXiaoming/pykan@GitHub 11984f4](https://github.com/KindXiaoming/pykan/tree/11984f49216fc254c1dd40b5bee0a069ee63114a)
"""

from typing import Union as TyUnion

import torch
from sympy import *  # noqa: F403

# Integer dataset indices 0 through 119, exposed as a module-level range.
# NOTE(review): get_feynman_dataset below also has a case for index 120
# ("Schwarz 13.132 (Klein-Nishina)"), which range(120) excludes -- confirm
# whether leaving it out is intentional.
dataset_range = range(120)


[docs] def get_feynman_dataset(name: TyUnion[str, int]): """ Get Feynman dataset from the given name. Args: name (str | int): The name of the dataset. Returns: tuple: The symbol, expression, function, and ranges of the dataset. """ tpi = torch.tensor(torch.pi) ranges: list[list[TyUnion[int, float, torch.Tensor]]] match name: case "test" | 0: symbol = x, y = symbols("x, y") expr = (x + y) * sin(exp(2 * y)) f = lambda x: (x[:, [0]] + x[:, [1]]) * torch.sin(torch.exp(2 * x[:, [1]])) ranges = [[-1, 1]] case "I.6.20a" | 1: symbol = theta = symbols("theta") symbol = [symbol] expr = exp(-(theta**2) / 2) / sqrt(2 * pi) f = lambda x: torch.exp(-(x[:, [0]] ** 2) / 2) / torch.sqrt(2 * tpi) ranges = [[-3, 3]] case "I.6.20" | 2: symbol = theta, sigma = symbols("theta sigma") expr = exp(-(theta**2) / (2 * sigma**2)) / sqrt(2 * pi * sigma**2) f = lambda x: torch.exp( -(x[:, [0]] ** 2) / (2 * x[:, [1]] ** 2) ) / torch.sqrt(2 * tpi * x[:, [1]] ** 2) ranges = [[-1, 1], [0.5, 2]] case "I.6.20b" | 3: symbol = theta, theta1, sigma = symbols("theta theta1 sigma") expr = exp(-((theta - theta1) ** 2) / (2 * sigma**2)) / sqrt( 2 * pi * sigma**2 ) f = lambda x: torch.exp( -((x[:, [0]] - x[:, [1]]) ** 2) / (2 * x[:, [2]] ** 2) ) / torch.sqrt(2 * tpi * x[:, [2]] ** 2) ranges = [[-1.5, 1.5], [-1.5, 1.5], [0.5, 2]] case "I.8.4" | 4: symbol = x1, x2, y1, y2 = symbols("x1 x2 y1 y2") expr = sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) f = lambda x: torch.sqrt( (x[:, [1]] - x[:, [0]]) ** 2 + (x[:, [3]] - x[:, [2]]) ** 2 ) ranges = [[-1, 1], [-1, 1], [-1, 1], [-1, 1]] case "I.9.18" | 5: symbol = G, m1, m2, x1, x2, y1, y2, z1, z2 = symbols( "G m1 m2 x1 x2 y1 y2 z1 z2" ) expr = G * m1 * m2 / ((x2 - x1) ** 2 + (y2 - y1) ** 2 + (z2 - z1) ** 2) f = ( lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] / ( (x[:, [3]] - x[:, [4]]) ** 2 + (x[:, [5]] - x[:, [6]]) ** 2 + (x[:, [7]] - x[:, [8]]) ** 2 ) ) ranges = [ [-1, 1], [-1, 1], [-1, 1], [-1, -0.5], [0.5, 1], [-1, -0.5], [0.5, 1], [-1, -0.5], [0.5, 1], ] case "I.10.7" | 
6: symbol = m0, v, c = symbols("m0 v c") expr = m0 / sqrt(1 - v**2 / c**2) f = lambda x: x[:, [0]] / torch.sqrt(1 - x[:, [1]] ** 2 / x[:, [2]] ** 2) ranges = [[0, 1], [0, 1], [1, 2]] case "I.11.19" | 7: symbol = x1, y1, x2, y2, x3, y3 = symbols("x1 y1 x2 y2 x3 y3") expr = x1 * y1 + x2 * y2 + x3 * y3 f = ( lambda x: x[:, [0]] * x[:, [1]] + x[:, [2]] * x[:, [3]] + x[:, [4]] * x[:, [5]] ) ranges = [[-1, 1]] case "I.12.1" | 8: symbol = mu, Nn = symbols("mu N_n") expr = mu * Nn f = lambda x: x[:, [0]] * x[:, [1]] ranges = [[-1, 1]] case "I.12.2" | 9: symbol = q1, q2, eps, r = symbols("q1 q2 epsilon r") expr = q1 * q2 / (4 * pi * eps * r**2) f = lambda x: x[:, [0]] * x[:, [1]] / (4 * tpi * x[:, [2]] * x[:, [3]] ** 2) ranges = [[-1, 1], [-1, 1], [0.5, 2], [0.5, 2]] case "I.12.4" | 10: symbol = q1, eps, r = symbols("q1 epsilon r") expr = q1 / (4 * pi * eps * r**2) f = lambda x: x[:, [0]] / (4 * tpi * x[:, [1]] * x[:, [2]] ** 2) ranges = [[-1, 1], [0.5, 2], [0.5, 2]] case "I.12.5" | 11: symbol = q2, Ef = symbols("q2, E_f") expr = q2 * Ef f = lambda x: x[:, [0]] * x[:, [1]] ranges = [[-1, 1]] case "I.12.11" | 12: symbol = q, Ef, B, v, theta = symbols("q E_f B v theta") expr = q * (Ef + B * v * sin(theta)) f = lambda x: x[:, [0]] * ( x[:, [1]] + x[:, [2]] * x[:, [3]] * torch.sin(x[:, [4]]) ) ranges = [[-1, 1], [-1, 1], [-1, 1], [-1, 1], [0, 2 * tpi]] case "I.13.4" | 13: symbol = m, v, u, w = symbols("m u v w") expr = 1 / 2 * m * (v**2 + u**2 + w**2) f = ( lambda x: 1 / 2 * x[:, [0]] * (x[:, [1]] ** 2 + x[:, [2]] ** 2 + x[:, [3]] ** 2) ) ranges = [[-1, 1], [-1, 1], [-1, 1], [-1, 1]] case "I.13.12" | 14: symbol = G, m1, m2, r1, r2 = symbols("G m1 m2 r1 r2") expr = G * m1 * m2 * (1 / r2 - 1 / r1) f = ( lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] * (1 / x[:, [4]] - 1 / x[:, [3]]) ) ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2], [0.5, 2]] case "I.14.3" | 15: symbol = m, g, z = symbols("m g z") expr = m * g * z f = lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] ranges = [[0, 1], [0, 1], 
[-1, 1]] case "I.14.4" | 16: symbol = ks, x = symbols("k_s x") expr = 1 / 2 * ks * x**2 f = lambda x: 1 / 2 * x[:, [0]] * x[:, [1]] ** 2 ranges = [[0, 1], [-1, 1]] case "I.15.3x" | 17: symbol = x, u, t, c = symbols("x u t c") expr = (x - u * t) / sqrt(1 - u**2 / c**2) f = lambda x: (x[:, [0]] - x[:, [1]] * x[:, [2]]) / torch.sqrt( 1 - x[:, [1]] ** 2 / x[:, [3]] ** 2 ) ranges = [[-1, 1], [-1, 1], [-1, 1], [1, 2]] case "I.15.3t" | 18: symbol = t, u, x, c = symbols("t u x c") expr = (t - u * x / c**2) / sqrt(1 - u**2 / c**2) f = lambda x: ( x[:, [0]] - x[:, [1]] * x[:, [2]] / x[:, [3]] ** 2 ) / torch.sqrt(1 - x[:, [1]] ** 2 / x[:, [3]] ** 2) ranges = [[-1, 1], [-1, 1], [-1, 1], [1, 2]] case "I.15.10" | 19: symbol = m0, v, c = symbols("m0 v c") expr = m0 * v / sqrt(1 - v**2 / c**2) f = ( lambda x: x[:, [0]] * x[:, [1]] / torch.sqrt(1 - x[:, [1]] ** 2 / x[:, [2]] ** 2) ) ranges = [[-1, 1], [-0.9, 0.9], [1.1, 2]] case "I.16.6" | 20: symbol = u, v, c = symbols("u v c") expr = (u + v) / (1 + u * v / c**2) f = ( lambda x: x[:, [0]] * x[:, [1]] / (1 + x[:, [0]] * x[:, [1]] / x[:, [2]] ** 2) ) ranges = [[-0.8, 0.8], [-0.8, 0.8], [1, 2]] case "I.18.4" | 21: symbol = m1, r1, m2, r2 = symbols("m1 r1 m2 r2") expr = (m1 * r1 + m2 * r2) / (m1 + m2) f = lambda x: (x[:, [0]] * x[:, [1]] + x[:, [2]] * x[:, [3]]) / ( x[:, [0]] + x[:, [2]] ) ranges = [[0.5, 1], [-1, 1], [0.5, 1], [-1, 1]] case "I.18.4" | 22: symbol = r, F, theta = symbols("r F theta") expr = r * F * sin(theta) f = lambda x: x[:, [0]] * x[:, [1]] * torch.sin(x[:, [2]]) ranges = [[-1, 1], [-1, 1], [0, 2 * tpi]] case "I.18.16" | 23: symbol = m, r, v, theta = symbols("m r v theta") expr = m * r * v * sin(theta) f = lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] * torch.sin(x[:, [3]]) ranges = [[-1, 1], [-1, 1], [-1, 1], [0, 2 * tpi]] case "I.24.6" | 24: symbol = m, omega, omega0, x = symbols("m omega omega_0 x") expr = 1 / 4 * m * (omega**2 + omega0**2) * x**2 f = ( lambda x: 1 / 4 * x[:, [0]] * (x[:, [1]] ** 2 + x[:, [2]] ** 
2) * x[:, [3]] ** 2 ) ranges = [[0, 1], [-1, 1], [-1, 1], [-1, 1]] case "I.25.13" | 25: symbol = q, C = symbols("q C") expr = q / C f = lambda x: x[:, [0]] / x[:, [1]] ranges = [[-1, 1], [0.5, 2]] case "I.26.2" | 26: symbol = n, theta2 = symbols("n theta2") expr = asin(n * sin(theta2)) f = lambda x: torch.arcsin(x[:, [0]] * torch.sin(x[:, [1]])) ranges = [[0, 0.99], [0, 2 * tpi]] case "I.27.6" | 27: symbol = d1, d2, n = symbols("d1 d2 n") expr = 1 / (1 / d1 + n / d2) f = lambda x: 1 / (1 / x[:, [0]] + x[:, [2]] / x[:, [1]]) ranges = [[0.5, 2], [1, 2], [0.5, 2]] case "I.29.4" | 28: symbol = omega, c = symbols("omega c") expr = omega / c f = lambda x: x[:, [0]] / x[:, [1]] ranges = [[0, 1], [0.5, 2]] case "I.29.16" | 29: symbol = x1, x2, theta1, theta2 = symbols("x1 x2 theta1 theta2") expr = sqrt(x1**2 + x2**2 - 2 * x1 * x2 * cos(theta1 - theta2)) f = lambda x: torch.sqrt( x[:, [0]] ** 2 + x[:, [1]] ** 2 - 2 * x[:, [0]] * x[:, [1]] * torch.cos(x[:, [2]] - x[:, [3]]) ) ranges = [[-1, 1], [-1, 1], [0, 2 * tpi], [0, 2 * tpi]] case "I.30.3" | 30: symbol = I0, n, theta = symbols("I_0 n theta") expr = I0 * sin(n * theta / 2) ** 2 / sin(theta / 2) ** 2 f = ( lambda x: x[:, [0]] * torch.sin(x[:, [1]] * x[:, [2]] / 2) ** 2 / torch.sin(x[:, [2]] / 2) ** 2 ) ranges = [[0, 1], [0, 4], [0.4 * tpi, 1.6 * tpi]] case "I.30.5" | 31: symbol = lamb, n, d = symbols("lambda n d") expr = asin(lamb / (n * d)) f = lambda x: torch.arcsin(x[:, [0]] / (x[:, [1]] * x[:, [2]])) ranges = [[-1, 1], [1, 1.5], [1, 1.5]] case "I.32.5" | 32: symbol = q, a, eps, c = symbols("q a epsilon c") expr = q**2 * a**2 / (eps * c**3) f = lambda x: x[:, [0]] ** 2 * x[:, [1]] ** 2 / (x[:, [2]] * x[:, [3]] ** 3) ranges = [[-1, 1], [-1, 1], [0.5, 2], [0.5, 2]] case "I.32.17" | 33: symbol = eps, c, Ef, r, omega, omega0 = symbols( "epsilon c E_f r omega omega_0" ) expr = nsimplify( (1 / 2 * eps * c * Ef**2) * (8 * pi * r**2 / 3) * (omega**4 / (omega**2 - omega0**2) ** 2) ) f = ( lambda x: (1 / 2 * x[:, [0]] * x[:, 
[1]] * x[:, [2]] ** 2) * (8 * tpi * x[:, [3]] ** 2 / 3) * (x[:, [4]] ** 4 / (x[:, [4]] ** 2 - x[:, [5]] ** 2) ** 2) ) ranges = [[0, 1], [0, 1], [-1, 1], [0, 1], [0, 1], [1, 2]] case "I.34.8" | 34: symbol = q, V, B, p = symbols("q V B p") expr = q * V * B / p f = lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] / x[:, [3]] ranges = [[-1, 1], [-1, 1], [-1, 1], [0.5, 2]] case "I.34.10" | 35: symbol = omega0, v, c = symbols("omega_0 v c") expr = omega0 / (1 - v / c) f = lambda x: x[:, [0]] / (1 - x[:, [1]] / x[:, [2]]) ranges = [[0, 1], [0, 0.9], [1.1, 2]] case "I.34.14" | 36: symbol = omega0, v, c = symbols("omega_0 v c") expr = omega0 * (1 + v / c) / sqrt(1 - v**2 / c**2) f = ( lambda x: x[:, [0]] * (1 + x[:, [1]] / x[:, [2]]) / torch.sqrt(1 - x[:, [1]] ** 2 / x[:, [2]] ** 2) ) ranges = [[0, 1], [-0.9, 0.9], [1.1, 2]] case "I.34.27" | 37: symbol = hbar, omega = symbols("hbar omega") expr = hbar * omega f = lambda x: x[:, [0]] * x[:, [1]] ranges = [[-1, 1], [-1, 1]] case "I.37.4" | 38: symbol = I1, I2, delta = symbols("I_1 I_2 delta") expr = I1 + I2 + 2 * sqrt(I1 * I2) * cos(delta) f = ( lambda x: x[:, [0]] + x[:, [1]] + 2 * torch.sqrt(x[:, [0]] * x[:, [1]]) * torch.cos(x[:, [2]]) ) ranges = [[0.1, 1], [0.1, 1], [0, 2 * tpi]] case "I.38.12" | 39: symbol = eps, hbar, m, q = symbols("epsilon hbar m q") expr = 4 * pi * eps * hbar**2 / (m * q**2) f = ( lambda x: 4 * tpi * x[:, [0]] * x[:, [1]] ** 2 / (x[:, [2]] * x[:, [3]] ** 2) ) ranges = [[0, 1], [0, 1], [0.5, 2], [0.5, 2]] case "I.39.10" | 40: symbol = pF, V = symbols("p_F V") expr = 3 / 2 * pF * V f = lambda x: 3 / 2 * x[:, [0]] * x[:, [1]] ranges = [[0, 1], [0, 1]] case "I.39.11" | 41: symbol = gamma, pF, V = symbols("gamma p_F V") expr = pF * V / (gamma - 1) f = lambda x: 1 / (x[:, [0]] - 1) * x[:, [1]] * x[:, [2]] ranges = [[1.5, 3], [0, 1], [0, 1]] case "I.39.22" | 42: symbol = n, kb, T, V = symbols("n k_b T V") expr = n * kb * T / V f = lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] / x[:, [3]] ranges = [[0, 1], [0, 1], 
[0, 1], [0.5, 2]] case "I.40.1" | 43: symbol = n0, m, g, x, kb, T = symbols("n_0 m g x k_b T") expr = n0 * exp(-m * g * x / (kb * T)) f = lambda x: x[:, [0]] * torch.exp( -x[:, [1]] * x[:, [2]] * x[:, [3]] / (x[:, [4]] * x[:, [5]]) ) ranges = [[0, 1], [-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2]] case "I.41.16" | 44: symbol = hbar, omega, c, kb, T = symbols("hbar omega c k_b T") expr = hbar * omega**3 / (pi**2 * c**2 * (exp(hbar * omega / (kb * T)) - 1)) f = ( lambda x: x[:, [0]] * x[:, [1]] ** 3 / ( tpi**2 * x[:, [2]] ** 2 * (torch.exp(x[:, [0]] * x[:, [1]] / (x[:, [3]] * x[:, [4]])) - 1) ) ) ranges = [[0.5, 1], [0.5, 1], [0.5, 2], [0.5, 2], [0.5, 2]] case "I.43.16" | 45: symbol = mu, q, Ve, d = symbols("mu q V_e d") expr = mu * q * Ve / d f = lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] / x[:, [3]] ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2]] case "I.43.31" | 46: symbol = mu, kb, T = symbols("mu k_b T") expr = mu * kb * T f = lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] ranges = [[0, 1], [0, 1], [0, 1]] case "I.43.43" | 47: symbol = gamma, kb, v, A = symbols("gamma k_b v A") expr = kb * v / A / (gamma - 1) f = lambda x: 1 / (x[:, [0]] - 1) * x[:, [1]] * x[:, [2]] / x[:, [3]] ranges = [[1.5, 3], [0, 1], [0, 1], [0.5, 2]] case "I.44.4" | 48: symbol = n, kb, T, V1, V2 = symbols("n k_b T V_1 V_2") expr = n * kb * T * log(V2 / V1) f = ( lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] * torch.log(x[:, [4]] / x[:, [3]]) ) ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2], [0.5, 2]] case "I.47.23" | 49: symbol = gamma, p, rho = symbols("gamma p rho") expr = sqrt(gamma * p / rho) f = lambda x: torch.sqrt(x[:, [0]] * x[:, [1]] / x[:, [2]]) ranges = [[0.1, 1], [0.1, 1], [0.5, 2]] case "I.48.20" | 50: symbol = m, v, c = symbols("m v c") expr = m * c**2 / sqrt(1 - v**2 / c**2) f = ( lambda x: x[:, [0]] * x[:, [2]] ** 2 / torch.sqrt(1 - x[:, [1]] ** 2 / x[:, [2]] ** 2) ) ranges = [[0, 1], [-0.9, 0.9], [1.1, 2]] case "I.50.26" | 51: symbol = x1, alpha, omega, t = symbols("x_1 alpha omega t") 
expr = x1 * (cos(omega * t) + alpha * cos(omega * t) ** 2) f = lambda x: x[:, [0]] * ( torch.cos(x[:, [2]] * x[:, [3]]) + x[:, [1]] * torch.cos(x[:, [2]] * x[:, [3]]) ** 2 ) ranges = [[0, 1], [0, 1], [0, 2 * tpi], [0, 1]] case "II.2.42" | 52: symbol = kappa, T1, T2, A, d = symbols("kappa T_1 T_2 A d") expr = kappa * (T2 - T1) * A / d f = lambda x: x[:, [0]] * (x[:, [2]] - x[:, [1]]) * x[:, [3]] / x[:, [4]] ranges = [[0, 1], [0, 1], [0, 1], [0, 1], [0.5, 2]] case "II.3.24" | 53: symbol = P, r = symbols("P r") expr = P / (4 * pi * r**2) f = lambda x: x[:, [0]] / (4 * tpi * x[:, [1]] ** 2) ranges = [[0, 1], [0.5, 2]] case "II.4.23" | 54: symbol = q, eps, r = symbols("q epsilon r") expr = q / (4 * pi * eps * r) f = lambda x: x[:, [0]] / (4 * tpi * x[:, [1]] * x[:, [2]]) ranges = [[0, 1], [0.5, 2], [0.5, 2]] case "II.6.11" | 55: symbol = eps, pd, theta, r = symbols("epsilon p_d theta r") expr = 1 / (4 * pi * eps) * pd * cos(theta) / r**2 f = ( lambda x: 1 / (4 * tpi * x[:, [0]]) * x[:, [1]] * torch.cos(x[:, [2]]) / x[:, [3]] ** 2 ) ranges = [[0.5, 2], [0, 1], [0, 2 * tpi], [0.5, 2]] case "II.6.15a" | 56: symbol = eps, pd, z, x, y, r = symbols("epsilon p_d z x y r") expr = 3 / (4 * pi * eps) * pd * z / r**5 * sqrt(x**2 + y**2) f = ( lambda x: 3 / (4 * tpi * x[:, [0]]) * x[:, [1]] * x[:, [2]] / x[:, [5]] ** 5 * torch.sqrt(x[:, [3]] ** 2 + x[:, [4]] ** 2) ) ranges = [[0.5, 2], [0, 1], [0, 1], [0, 1], [0, 1], [0.5, 2]] case "II.6.15b" | 57: symbol = eps, pd, r, theta = symbols("epsilon p_d r theta") expr = 3 / (4 * pi * eps) * pd / r**3 * cos(theta) * sin(theta) f = ( lambda x: 3 / (4 * tpi * x[:, [0]]) * x[:, [1]] / x[:, [2]] ** 3 * torch.cos(x[:, [3]]) * torch.sin(x[:, [3]]) ) ranges = [[0.5, 2], [0, 1], [0.5, 2], [0, 2 * tpi]] case "II.8.7" | 58: symbol = q, eps, d = symbols("q epsilon d") expr = 3 / 5 * q**2 / (4 * pi * eps * d) f = lambda x: 3 / 5 * x[:, [0]] ** 2 / (4 * tpi * x[:, [1]] * x[:, [2]]) ranges = [[0, 1], [0.5, 2], [0.5, 2]] case "II.8.31" | 59: symbol = 
eps, Ef = symbols("epsilon E_f") expr = 1 / 2 * eps * Ef**2 f = lambda x: 1 / 2 * x[:, [0]] * x[:, [1]] ** 2 ranges = [[0, 1], [0, 1]] case "I.10.9" | 60: symbol = sigma, eps, chi = symbols("sigma epsilon chi") expr = sigma / eps / (1 + chi) f = lambda x: x[:, [0]] / x[:, [1]] / (1 + x[:, [2]]) ranges = [[0, 1], [0.5, 2], [0, 1]] case "II.11.3" | 61: symbol = q, Ef, m, omega0, omega = symbols("q E_f m omega_o omega") expr = q * Ef / (m * (omega0**2 - omega**2)) f = ( lambda x: x[:, [0]] * x[:, [1]] / (x[:, [2]] * (x[:, [3]] ** 2 - x[:, [4]] ** 2)) ) ranges = [[0, 1], [0, 1], [0.5, 2], [1.5, 3], [0, 1]] case "II.11.17" | 62: symbol = n0, pd, Ef, theta, kb, T = symbols("n_0 p_d E_f theta k_b T") expr = n0 * (1 + pd * Ef * cos(theta) / (kb * T)) f = lambda x: x[:, [0]] * ( 1 + x[:, [1]] * x[:, [2]] * torch.cos(x[:, [3]]) / (x[:, [4]] * x[:, [5]]) ) ranges = [[0, 1], [-1, 1], [-1, 1], [0, 2 * tpi], [0.5, 2], [0.5, 2]] case "II.11.20" | 63: symbol = n, pd, Ef, kb, T = symbols("n p_d E_f k_b T") expr = n * pd**2 * Ef / (3 * kb * T) f = ( lambda x: x[:, [0]] * x[:, [1]] ** 2 * x[:, [2]] / (3 * x[:, [3]] * x[:, [4]]) ) ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2], [0.5, 2]] case "II.11.27" | 64: symbol = n, alpha, eps, Ef = symbols("n alpha epsilon E_f") expr = n * alpha / (1 - n * alpha / 3) * eps * Ef f = ( lambda x: x[:, [0]] * x[:, [1]] / (1 - x[:, [0]] * x[:, [1]] / 3) * x[:, [2]] * x[:, [3]] ) ranges = [[0, 1], [0, 2], [0, 1], [0, 1]] case "II.11.28" | 65: symbol = n, alpha = symbols("n alpha") expr = 1 + n * alpha / (1 - n * alpha / 3) f = lambda x: 1 + x[:, [0]] * x[:, [1]] / (1 - x[:, [0]] * x[:, [1]] / 3) ranges = [[0, 1], [0, 2]] case "II.13.17" | 66: symbol = eps, c, l, r = symbols("epsilon c l r") expr = 1 / (4 * pi * eps * c**2) * (2 * l / r) f = ( lambda x: 1 / (4 * tpi * x[:, [0]] * x[:, [1]] ** 2) * (2 * x[:, [2]] / x[:, [3]]) ) ranges = [[0.5, 2], [0.5, 2], [0, 1], [0.5, 2]] case "II.13.23" | 67: symbol = rho, v, c = symbols("rho v c") expr = rho / sqrt(1 - 
v**2 / c**2) f = lambda x: x[:, [0]] / torch.sqrt(1 - x[:, [1]] ** 2 / x[:, [2]] ** 2) ranges = [[0, 1], [0, 1], [1, 2]] case "II.13.34" | 68: symbol = rho, v, c = symbols("rho v c") expr = rho * v / sqrt(1 - v**2 / c**2) f = ( lambda x: x[:, [0]] * x[:, [1]] / torch.sqrt(1 - x[:, [1]] ** 2 / x[:, [2]] ** 2) ) ranges = [[0, 1], [0, 1], [1, 2]] case "II.15.4" | 69: symbol = muM, B, theta = symbols("mu_M B theta") expr = -muM * B * cos(theta) f = lambda x: -x[:, [0]] * x[:, [1]] * torch.cos(x[:, [2]]) ranges = [[0, 1], [0, 1], [0, 2 * tpi]] case "II.15.5" | 70: symbol = pd, Ef, theta = symbols("p_d E_f theta") expr = -pd * Ef * cos(theta) f = lambda x: -x[:, [0]] * x[:, [1]] * torch.cos(x[:, [2]]) ranges = [[0, 1], [0, 1], [0, 2 * tpi]] case "II.21.32" | 71: symbol = q, eps, r, v, c = symbols("q epsilon r v c") expr = q / (4 * pi * eps * r * (1 - v / c)) f = lambda x: x[:, [0]] / ( 4 * tpi * x[:, [1]] * x[:, [2]] * (1 - x[:, [3]] / x[:, [4]]) ) ranges = [[0, 1], [0.5, 2], [0.5, 2], [0, 1], [1, 2]] case "II.24.17" | 72: symbol = omega, c, d = symbols("omega c d") expr = sqrt(omega**2 / c**2 - pi**2 / d**2) f = lambda x: torch.sqrt( x[:, [0]] ** 2 / x[:, [1]] ** 2 - tpi**2 / x[:, [2]] ** 2 ) ranges = [[1, 1.5], [0.75, 1], [1 * tpi, 1.5 * tpi]] case "II.27.16" | 73: symbol = eps, c, Ef = symbols("epsilon c E_f") expr = eps * c * Ef**2 f = lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] ** 2 ranges = [[0, 1], [0, 1], [-1, 1]] case "II.27.18" | 74: symbol = eps, Ef = symbols("epsilon E_f") expr = eps * Ef**2 f = lambda x: x[:, [0]] * x[:, [1]] ** 2 ranges = [[0, 1], [-1, 1]] case "II.34.2a" | 75: symbol = q, v, r = symbols("q v r") expr = q * v / (2 * pi * r) f = lambda x: x[:, [0]] * x[:, [1]] / (2 * tpi * x[:, [2]]) ranges = [[0, 1], [0, 1], [0.5, 2]] case "II.34.2" | 76: symbol = q, v, r = symbols("q v r") expr = q * v * r / 2 f = lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] / 2 ranges = [[0, 1], [0, 1], [0, 1]] case "II.34.11" | 77: symbol = g, q, B, m = symbols("g q B m") 
expr = g * q * B / (2 * m) f = lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] / (2 * x[:, [3]]) ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2]] case "II.34.29a" | 78: symbol = q, h, m = symbols("q h m") expr = q * h / (4 * pi * m) f = lambda x: x[:, [0]] * x[:, [1]] / (4 * tpi * x[:, [2]]) ranges = [[0, 1], [0, 1], [0.5, 2]] case "II.34.29b" | 79: symbol = g, mu, B, J, hbar = symbols("g mu B J hbar") expr = g * mu * B * J / hbar f = lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] * x[:, [3]] / x[:, [4]] ranges = [[0, 1], [0, 1], [0, 1], [0, 1], [0.5, 2]] case "II.35.18" | 80: symbol = n0, mu, B, kb, T = symbols("n0 mu B k_b T") expr = n0 / (exp(mu * B / (kb * T)) + exp(-mu * B / (kb * T))) f = lambda x: x[:, [0]] / ( torch.exp(x[:, [1]] * x[:, [2]] / (x[:, [3]] * x[:, [4]])) + torch.exp(-x[:, [1]] * x[:, [2]] / (x[:, [3]] * x[:, [4]])) ) ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2], [0.5, 2]] case "II.35.21" | 81: symbol = n, mu, B, kb, T = symbols("n mu B k_b T") expr = n * mu * tanh(mu * B / (kb * T)) f = ( lambda x: x[:, [0]] * x[:, [1]] * torch.tanh(x[:, [1]] * x[:, [2]] / (x[:, [3]] * x[:, [4]])) ) ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2], [0.5, 2]] case "II.36.38" | 82: symbol = mu, B, kb, T, alpha, M, eps, c = symbols( "mu B k_b T alpha M epsilon c" ) expr = mu * B / (kb * T) + mu * alpha * M / (eps * c**2 * kb * T) f = lambda x: x[:, [0]] * x[:, [1]] / (x[:, [2]] * x[:, [3]]) + x[ :, [0] ] * x[:, [4]] * x[:, [5]] / ( x[:, [6]] * x[:, [7]] ** 2 * x[:, [2]] * x[:, [3]] ) ranges = [ [0, 1], [0, 1], [0.5, 2], [0.5, 2], [0, 1], [0, 1], [0.5, 2], [0.5, 2], ] case "II.37.1" | 83: symbol = mu, chi, B = symbols("mu chi B") expr = mu * (1 + chi) * B f = lambda x: x[:, [0]] * (1 + x[:, [1]]) * x[:, [2]] ranges = [[0, 1], [0, 1], [0, 1]] case "II.38.3" | 84: symbol = Y, A, x, d = symbols("Y A x d") expr = Y * A * x / d f = lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] / x[:, [3]] ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2]] case "II.38.14" | 85: symbol = Y, sigma = symbols("Y 
sigma") expr = Y / (2 * (1 + sigma)) f = lambda x: x[:, [0]] / (2 * (1 + x[:, [1]])) ranges = [[0, 1], [0, 1]] case "III.4.32" | 86: symbol = hbar, omega, kb, T = symbols("hbar omega k_b T") expr = 1 / (exp(hbar * omega / (kb * T)) - 1) f = lambda x: 1 / ( torch.exp(x[:, [0]] * x[:, [1]] / (x[:, [2]] * x[:, [3]])) - 1 ) ranges = [[0.5, 1], [0.5, 1], [0.5, 2], [0.5, 2]] case "III.4.33" | 87: symbol = hbar, omega, kb, T = symbols("hbar omega k_b T") expr = hbar * omega / (exp(hbar * omega / (kb * T)) - 1) f = ( lambda x: x[:, [0]] * x[:, [1]] / (torch.exp(x[:, [0]] * x[:, [1]] / (x[:, [2]] * x[:, [3]])) - 1) ) ranges = [[0, 1], [0, 1], [0.5, 2], [0.5, 2]] case "III.7.38" | 88: symbol = mu, B, hbar = symbols("mu B hbar") expr = 2 * mu * B / hbar f = lambda x: 2 * x[:, [0]] * x[:, [1]] / x[:, [2]] ranges = [[0, 1], [0, 1], [0.5, 2]] case "III.8.54" | 89: symbol = E, t, hbar = symbols("E t hbar") expr = sin(E * t / hbar) ** 2 f = lambda x: torch.sin(x[:, [0]] * x[:, [1]] / x[:, [2]]) ** 2 ranges = [[0, 2 * tpi], [0, 1], [0.5, 2]] case "III.9.52" | 90: symbol = pd, Ef, t, hbar, omega, omega0 = symbols( "p_d E_f t hbar omega omega_0" ) expr = ( pd * Ef * t / hbar * sin((omega - omega0) * t / 2) ** 2 / ((omega - omega0) * t / 2) ** 2 ) f = ( lambda x: x[:, [0]] * x[:, [1]] * x[:, [2]] / x[:, [3]] * torch.sin((x[:, [4]] - x[:, [5]]) * x[:, [2]] / 2) ** 2 / ((x[:, [4]] - x[:, [5]]) * x[:, [2]] / 2) ** 2 ) ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2], [0, tpi], [0, tpi]] case "III.10.19" | 91: symbol = mu, Bx, By, Bz = symbols("mu B_x B_y B_z") expr = mu * sqrt(Bx**2 + By**2 + Bz**2) f = lambda x: x[:, [0]] * torch.sqrt( x[:, [1]] ** 2 + x[:, [2]] ** 2 + x[:, [3]] ** 2 ) ranges = [[0, 1], [0, 1], [0, 1], [0, 1]] case "III.12.43" | 92: symbol = n, hbar = symbols("n hbar") expr = n * hbar f = lambda x: x[:, [0]] * x[:, [1]] ranges = [[0, 1], [0, 1]] case "III.13.18" | 93: symbol = E, d, k, hbar = symbols("E d k hbar") expr = 2 * E * d**2 * k / hbar f = lambda x: 2 * x[:, [0]] * 
x[:, [1]] ** 2 * x[:, [2]] / x[:, [3]] ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2]] case "III.14.14" | 94: symbol = I0, q, Ve, kb, T = symbols("I_0 q V_e k_b T") expr = I0 * (exp(q * Ve / (kb * T)) - 1) f = lambda x: x[:, [0]] * ( torch.exp(x[:, [1]] * x[:, [2]] / (x[:, [3]] * x[:, [4]])) - 1 ) ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2], [0.5, 2]] case "III.15.12" | 95: symbol = U, k, d = symbols("U k d") expr = 2 * U * (1 - cos(k * d)) f = lambda x: 2 * x[:, [0]] * (1 - torch.cos(x[:, [1]] * x[:, [2]])) ranges = [[0, 1], [0, 2 * tpi], [0, 1]] case "III.15.14" | 96: symbol = hbar, E, d = symbols("hbar E d") expr = hbar**2 / (2 * E * d**2) f = lambda x: x[:, [0]] ** 2 / (2 * x[:, [1]] * x[:, [2]] ** 2) ranges = [[0, 1], [0.5, 2], [0.5, 2]] case "III.15.27" | 97: symbol = alpha, n, d = symbols("alpha n d") expr = 2 * pi * alpha / (n * d) f = lambda x: 2 * tpi * x[:, [0]] / (x[:, [1]] * x[:, [2]]) ranges = [[0, 1], [0.5, 2], [0.5, 2]] case "III.17.37" | 98: symbol = beta, alpha, theta = symbols("beta alpha theta") expr = beta * (1 + alpha * cos(theta)) f = lambda x: x[:, [0]] * (1 + x[:, [1]] * torch.cos(x[:, [2]])) ranges = [[0, 1], [0, 1], [0, 2 * tpi]] case "III.19.51" | 99: symbol = m, q, eps, hbar, n = symbols("m q epsilon hbar n") expr = -m * q**4 / (2 * (4 * pi * eps) ** 2 * hbar**2) * 1 / n**2 f = ( lambda x: -x[:, [0]] * x[:, [1]] ** 4 / (2 * (4 * tpi * x[:, [2]]) ** 2 * x[:, [3]] ** 2) * 1 / x[:, [4]] ** 2 ) ranges = [[0, 1], [0, 1], [0.5, 2], [0.5, 2], [0.5, 2]] case "III.21.20" | 100: symbol = rho, q, A, m = symbols("rho q A m") expr = -rho * q * A / m f = lambda x: -x[:, [0]] * x[:, [1]] * x[:, [2]] / x[:, [3]] ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2]] case "Rutherforld scattering" | 101: symbol = Z1, Z2, alpha, hbar, c, E, theta = symbols( "Z_1 Z_2 alpha hbar c E theta" ) expr = (Z1 * Z2 * alpha * hbar * c / (4 * E * sin(theta / 2) ** 2)) ** 2 f = ( lambda x: ( x[:, [0]] * x[:, [1]] * x[:, [2]] * x[:, [3]] * x[:, [4]] / (4 * x[:, [5]] * torch.sin(x[:, 
[6]] / 2) ** 2) ) ** 2 ) ranges = [ [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0.5, 2], [0.1 * tpi, 0.9 * tpi], ] case "Friedman equation" | 102: symbol = G, rho, kf, c, af = symbols("G rho k_f c a_f") expr = sqrt(8 * pi * G / 3 * rho - kf * c**2 / af**2) f = lambda x: torch.sqrt( 8 * tpi * x[:, [0]] / 3 * x[:, [1]] - x[:, [2]] * x[:, [3]] ** 2 / x[:, [4]] ** 2 ) ranges = [[1, 2], [1, 2], [0, 1], [0, 1], [1, 2]] case "Compton scattering" | 103: symbol = E, m, c, theta = symbols("E m c theta") expr = E / (1 + E / (m * c**2) * (1 - cos(theta))) f = lambda x: x[:, [0]] / ( 1 + x[:, [0]] / (x[:, [1]] * x[:, [2]] ** 2) * (1 - torch.cos(x[:, [3]])) ) ranges = [[0, 1], [0.5, 2], [0.5, 2], [0, 2 * tpi]] case "Radiated gravitational wave power" | 104: symbol = G, c, m1, m2, r = symbols("G c m_1 m_2 r") expr = -32 / 5 * G**4 / c**5 * (m1 * m2) ** 2 * (m1 + m2) / r**5 f = ( lambda x: -32 / 5 * x[:, [0]] ** 4 / x[:, [1]] ** 5 * (x[:, [2]] * x[:, [3]]) ** 2 * (x[:, [2]] + x[:, [3]]) / x[:, [4]] ** 5 ) ranges = [[0, 1], [0.5, 2], [0, 1], [0, 1], [0.5, 2]] case "Relativistic aberration" | 105: symbol = theta2, v, c = symbols("theta_2 v c") expr = acos((cos(theta2) - v / c) / (1 - v / c * cos(theta2))) f = lambda x: torch.arccos( (torch.cos(x[:, [0]]) - x[:, [1]] / x[:, [2]]) / (1 - x[:, [1]] / x[:, [2]] * torch.cos(x[:, [0]])) ) ranges = [[0, tpi], [0, 1], [1, 2]] case "N-slit diffraction" | 106: symbol = I0, alpha, delta, N = symbols("I_0 alpha delta N") expr = ( I0 * (sin(alpha / 2) / (alpha / 2) * sin(N * delta / 2) / sin(delta / 2)) ** 2 ) f = ( lambda x: x[:, [0]] * ( torch.sin(x[:, [1]] / 2) / (x[:, [1]] / 2) * torch.sin(x[:, [3]] * x[:, [2]] / 2) / torch.sin(x[:, [2]] / 2) ) ** 2 ) ranges = [[0, 1], [0.1 * tpi, 0.9 * tpi], [0.1 * tpi, 0.9 * tpi], [0.5, 1]] case "Goldstein 3.16" | 107: symbol = m, E, U, L, r = symbols("m E U L r") expr = sqrt(2 / m * (E - U - L**2 / (2 * m * r**2))) f = lambda x: torch.sqrt( 2 / x[:, [0]] * ( x[:, [1]] - x[:, [2]] - x[:, [3]] ** 2 / (2 * x[:, 
[0]] * x[:, [4]] ** 2) ) ) ranges = [[1, 2], [2, 3], [0, 1], [0, 1], [1, 2]] case "Goldstein 3.55" | 108: symbol = m, kG, L, E, theta1, theta2 = symbols("m k_G L E theta_1 theta_2") expr = ( m * kG / L**2 * (1 + sqrt(1 + 2 * E * L**2 / (m * kG**2)) * cos(theta1 - theta2)) ) f = ( lambda x: x[:, [0]] * x[:, [1]] / x[:, [2]] ** 2 * ( 1 + torch.sqrt( 1 + 2 * x[:, [3]] * x[:, [2]] ** 2 / (x[:, [0]] * x[:, [1]] ** 2) ) * torch.cos(x[:, [4]] - x[:, [5]]) ) ) ranges = [[0.5, 2], [0.5, 2], [0.5, 2], [0, 1], [0, 2 * tpi], [0, 2 * tpi]] case "Goldstein 3.64 (ellipse)" | 109: symbol = d, alpha, theta1, theta2 = symbols("d alpha theta_1 theta_2") expr = d * (1 - alpha**2) / (1 + alpha * cos(theta2 - theta1)) f = ( lambda x: x[:, [0]] * (1 - x[:, [1]] ** 2) / (1 + x[:, [1]] * torch.cos(x[:, [2]] - x[:, [3]])) ) ranges = [[0, 1], [0, 0.9], [0, 2 * tpi], [0, 2 * tpi]] case "Goldstein 3.74 (Kepler)" | 110: symbol = d, G, m1, m2 = symbols("d G m_1 m_2") expr = 2 * pi * d ** (3 / 2) / sqrt(G * (m1 + m2)) f = ( lambda x: 2 * tpi * x[:, [0]] ** (3 / 2) / torch.sqrt(x[:, [1]] * (x[:, [2]] + x[:, [3]])) ) ranges = [[0, 1], [0.5, 2], [0.5, 2], [0.5, 2]] case "Goldstein 3.99" | 111: symbol = eps, E, L, m, Z1, Z2, q = symbols("epsilon E L m Z_1 Z_2 q") expr = sqrt(1 + 2 * eps**2 * E * L**2 / (m * (Z1 * Z2 * q**2) ** 2)) f = lambda x: torch.sqrt( 1 + 2 * x[:, [0]] ** 2 * x[:, [1]] * x[:, [2]] ** 2 / (x[:, [3]] * (x[:, [4]] * x[:, [5]] * x[:, [6]] ** 2) ** 2) ) ranges = [[0, 1], [0, 1], [0, 1], [0.5, 2], [0.5, 2], [0.5, 2], [0.5, 2]] case "Goldstein 8.56" | 112: symbol = p, q, A, c, m, Ve = symbols("p q A c m V_e") expr = sqrt((p - q * A) ** 2 * c**2 + m**2 * c**4) + q * Ve f = ( lambda x: torch.sqrt( (x[:, [0]] - x[:, [1]] * x[:, [2]]) ** 2 * x[:, [3]] ** 2 + x[:, [4]] ** 2 * x[:, [3]] ** 4 ) + x[:, [1]] * x[:, [5]] ) ranges = [[0, 1]] case "Goldstein 12.80" | 113: symbol = m, p, omega, x, alpha, y = symbols("m p omega x alpha y") expr = 1 / (2 * m) * (p**2 + m**2 * omega**2 * x**2 * (1 + 
alpha * y / x)) f = ( lambda x: 1 / (2 * x[:, [0]]) * ( x[:, [1]] ** 2 + x[:, [0]] ** 2 * x[:, [2]] ** 2 * x[:, [3]] ** 2 * (1 + x[:, [4]] * x[:, [3]] / x[:, [5]]) ) ) ranges = [[0.5, 2], [0, 1], [0, 1], [0, 1], [0, 1], [0.5, 2]] case "Jackson 2.11" | 114: symbol = q, eps, y, Ve, d = symbols("q epsilon y V_e d") expr = ( q / (4 * pi * eps * y**2) * (4 * pi * eps * Ve * d - q * d * y**3 / (y**2 - d**2) ** 2) ) f = ( lambda x: x[:, [0]] / (4 * tpi * x[:, [1]] * x[:, [2]] ** 2) * ( 4 * tpi * x[:, [1]] * x[:, [3]] * x[:, [4]] - x[:, [0]] * x[:, [4]] * x[:, [2]] ** 3 / (x[:, [2]] ** 2 - x[:, [4]] ** 2) ** 2 ) ) ranges = [[0, 1], [0.5, 2], [1, 2], [0, 1], [0, 1]] case "Jackson 3.45" | 115: symbol = q, r, d, alpha = symbols("q r d alpha") expr = q / sqrt(r**2 + d**2 - 2 * d * r * cos(alpha)) f = lambda x: x[:, [0]] / torch.sqrt( x[:, [1]] ** 2 + x[:, [2]] ** 2 - 2 * x[:, [1]] * x[:, [2]] * torch.cos(x[:, [3]]) ) ranges = [[0, 1], [0, 1], [0, 1], [0, 2 * tpi]] case "Jackson 4.60" | 116: symbol = Ef, theta, alpha, d, r = symbols("E_f theta alpha d r") expr = Ef * cos(theta) * ((alpha - 1) / (alpha + 2) * d**3 / r**2 - r) f = ( lambda x: x[:, [0]] * torch.cos(x[:, [1]]) * ( (x[:, [2]] - 1) / (x[:, [2]] + 2) * x[:, [3]] ** 3 / x[:, [4]] ** 2 - x[:, [4]] ) ) ranges = [[0, 1], [0, 2 * tpi], [0, 2], [0, 1], [0.5, 2]] case "Jackson 11.38 (Doppler)" | 117: symbol = omega, v, c, theta = symbols("omega v c theta") expr = sqrt(1 - v**2 / c**2) / (1 + v / c * cos(theta)) * omega f = ( lambda x: torch.sqrt(1 - x[:, [1]] ** 2 / x[:, [2]] ** 2) / (1 + x[:, [1]] / x[:, [2]] * torch.cos(x[:, [3]])) * x[:, [0]] ) ranges = [[0, 1], [0, 1], [1, 2], [0, 2 * tpi]] case "Weinberg 15.2.1" | 118: symbol = G, c, kf, af, H = symbols("G c k_f a_f H") expr = 3 / (8 * pi * G) * (c**2 * kf / af**2 + H**2) f = ( lambda x: 3 / (8 * tpi * x[:, [0]]) * (x[:, [1]] ** 2 * x[:, [2]] / x[:, [3]] ** 2 + x[:, [4]] ** 2) ) ranges = [[0.5, 2], [0, 1], [0, 1], [0.5, 2], [0, 1]] case "Weinberg 15.2.2" | 119: symbol = 
G, c, kf, af, H, alpha = symbols("G c k_f a_f H alpha") expr = ( -1 / (8 * pi * G) * (c**4 * kf / af**2 + c**2 * H**2 * (1 - 2 * alpha)) ) f = ( lambda x: -1 / (8 * tpi * x[:, [0]]) * ( x[:, [1]] ** 4 * x[:, [2]] / x[:, [3]] ** 2 + x[:, [1]] ** 2 * x[:, [4]] ** 2 * (1 - 2 * x[:, [5]]) ) ) ranges = [[0.5, 2], [0, 1], [0, 1], [0.5, 2], [0, 1], [0, 1]] case "Schwarz 13.132 (Klein-Nishina)" | 120: symbol = alpha, hbar, m, c, omega0, omega, theta = symbols( "alpha hbar m c omega_0 omega theta" ) expr = ( pi * alpha**2 * hbar**2 / m**2 / c**2 * (omega0 / omega) ** 2 * (omega0 / omega + omega / omega0 - sin(theta) ** 2) ) f = ( lambda x: tpi * x[:, [0]] ** 2 * x[:, [1]] ** 2 / x[:, [2]] ** 2 / x[:, [3]] ** 2 * (x[:, [4]] / x[:, [5]]) ** 2 * ( x[:, [4]] / x[:, [5]] + x[:, [5]] / x[:, [4]] - torch.sin(x[:, [6]]) ** 2 ) ) ranges = [ [0, 1], [0, 1], [0.5, 2], [0.5, 2], [0.5, 2], [0.5, 2], [0, 2 * tpi], ] return symbol, expr, f, ranges