mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 17:00:05 +00:00
622 lines
20 KiB
Python
622 lines
20 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
from typing import Callable, Union
|
|
|
|
from .typing import Numeric
|
|
from .tensor import TensorSSA
|
|
|
|
from cutlass._mlir.dialects import math, arith
|
|
from cutlass.cutlass_dsl import dsl_user_op
|
|
|
|
|
|
def _math_op(func: Callable, fastmath: bool, *args, **kwargs):
    """Validate the operands and route ``func`` to the tensor or scalar path.

    Every positional operand must be either a ``TensorSSA`` or a ``Numeric``
    whose type reports ``is_float``, and all operands must be instances of
    the type of the first operand.

    :param func: MLIR op builder to invoke; it is called with the operands,
        a ``fastmath`` keyword and the forwarded ``kwargs``
    :param fastmath: When True the op is built with
        ``arith.FastMathFlags.fast``, otherwise ``arith.FastMathFlags.none``
    :param args: The input tensors or scalars (at least one is required)
    :param kwargs: Extra keyword arguments (loc, ip) forwarded to the MLIR op
    :return: For tensor operands, a ``TensorSSA`` wrapping the op result with
        the shape and dtype of the first operand; for scalar operands, the
        value returned by ``func`` applied to the operands' IR values
    :raises TypeError: If an operand is neither a TensorSSA nor a
        floating-point Numeric, or if the operands differ in type
    """
    first = args[0]
    expected_type = type(first)

    for operand in args:
        # Tensors are always acceptable; scalars must be floating-point.
        if not isinstance(operand, TensorSSA):
            if not (isinstance(operand, Numeric) and type(operand).is_float):
                raise TypeError(
                    f"Expected a TensorSSA or Numeric(Float), but got {type(operand)}"
                )
        # Mixing operand kinds (or scalar widths) is rejected outright.
        if not isinstance(operand, expected_type):
            raise TypeError(
                f"Expected all inputs to be of type {expected_type}, but got {type(operand)}"
            )

    if fastmath:
        flag = arith.FastMathFlags.fast
    else:
        flag = arith.FastMathFlags.none

    if not isinstance(first, TensorSSA):
        # Scalar path: the op builder works on raw IR values.
        ir_operands = [operand.ir_value() for operand in args]
        return func(*ir_operands, fastmath=flag, **kwargs)

    # Tensor path: re-wrap the result using the first operand's metadata.
    raw = func(*args, fastmath=flag, **kwargs)
    return TensorSSA(raw, first.shape, first.dtype)
|
|
|
|
|
|
@dsl_user_op
def absf(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise floating-point absolute value of ``a``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Absolute value of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = absf(y)                        # Compute absolute value
    """
    return _math_op(math.absf, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def acos(
    a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None
) -> Union[TensorSSA, Numeric]:
    """Compute element-wise arc cosine of the input.

    Decorated with ``dsl_user_op`` like every other public op in this module
    that accepts the keyword-only ``loc``/``ip`` arguments.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point for IR generation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Arc cosine of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = acos(y)                        # Compute arc cosine
    """
    return _math_op(math.acos, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def asin(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise arc sine of ``a``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Arc sine of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = asin(y)                        # Compute arc sine
    """
    return _math_op(math.asin, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def atan(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise arc tangent of ``a``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Arc tangent of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = atan(y)                        # Compute arc tangent
    """
    return _math_op(math.atan, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def atan2(
    a: Union[TensorSSA, Numeric],
    b: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise two-argument arc tangent ``atan2(a, b)``.

    ``atan2(a, b)`` is the angle, in radians, between the positive x-axis
    and the point whose coordinates are ``(b, a)``. Both operands must be of
    the same type.

    :param a: First input (y-coordinates), tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param b: Second input (x-coordinates), tensor or floating-point scalar
    :type b: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Arc tangent of ``a / b`` for each pair of elements
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If an operand is neither a TensorSSA nor a
        floating-point Numeric, or if ``a`` and ``b`` differ in type

    Example:

    .. code-block::

        y = cute.make_rmem_tensor(ptr1, layout).load()  # y coordinates
        x = cute.make_rmem_tensor(ptr2, layout).load()  # x coordinates
        theta = atan2(y, x)                             # Compute angles
    """
    return _math_op(math.atan2, fastmath, a, b, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def copysign(
    a: Union[TensorSSA, Numeric], b: Union[TensorSSA, Numeric], fastmath: bool = False,
    *, loc=None, ip=None
) -> Union[TensorSSA, Numeric]:
    """Combine the magnitude of ``a`` with the sign of ``b``, element-wise.

    Both operands must be of the same type.

    :param a: Input providing the magnitude, tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param b: Input providing the sign, tensor or floating-point scalar
    :type b: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Values carrying the magnitude of ``a`` and the sign of ``b``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If an operand is neither a TensorSSA nor a
        floating-point Numeric, or if ``a`` and ``b`` differ in type

    Example:

    .. code-block::

        mag = cute.make_rmem_tensor(ptr1, layout).load()  # magnitudes
        sgn = cute.make_rmem_tensor(ptr2, layout).load()  # signs
        result = copysign(mag, sgn)  # Combine magnitude and sign
    """
    return _math_op(math.copysign, fastmath, a, b, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def cos(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise cosine of ``a``.

    :param a: Input tensor or floating-point scalar (in radians)
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Cosine of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = cos(y)                         # Compute cosine
    """
    return _math_op(math.cos, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def erf(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise error function of ``a``.

    The error function is defined as:
    erf(x) = 2/√π ∫[0 to x] exp(-t²) dt

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Error function value for each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = erf(y)                         # Compute error function
    """
    return _math_op(math.erf, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def exp(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise exponential of ``a``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Exponential of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = exp(y)                         # Compute exponential
    """
    return _math_op(math.exp, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def exp2(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise base-2 exponential of ``a``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: 2 raised to the power of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = exp2(y)                        # Compute 2^x
    """
    return _math_op(math.exp2, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def floor(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise floor of ``a``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Largest integer less than or equal to each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = floor(y)                       # Compute floor
    """
    return _math_op(math.floor, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def log(
    a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None
) -> Union[TensorSSA, Numeric]:
    """Compute element-wise natural logarithm of the input.

    Decorated with ``dsl_user_op`` like every other public op in this module
    that accepts the keyword-only ``loc``/``ip`` arguments.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point for IR generation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Natural logarithm of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = log(y)                         # Compute natural logarithm
    """
    return _math_op(math.log, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def log2(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise base-2 logarithm of ``a``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Base-2 logarithm of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = log2(y)                        # Compute log base 2
    """
    return _math_op(math.log2, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def log10(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise base-10 logarithm of ``a``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Base-10 logarithm of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = log10(y)                       # Compute log base 10
    """
    return _math_op(math.log10, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def rsqrt(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise reciprocal square root of ``a``.

    Computes 1/√x for each element.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Reciprocal square root of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = rsqrt(y)                       # Compute 1/√x
    """
    return _math_op(math.rsqrt, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def sin(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise sine of ``a``.

    :param a: Input tensor or floating-point scalar (in radians)
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Sine of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = sin(y)                         # Compute sine
    """
    return _math_op(math.sin, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def sqrt(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise square root of ``a``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Square root of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = sqrt(y)                        # Compute square root
    """
    return _math_op(math.sqrt, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def tan(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise tangent of ``a``.

    :param a: Input tensor or floating-point scalar (in radians)
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Tangent of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = tan(y)                         # Compute tangent
    """
    return _math_op(math.tan, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
@dsl_user_op
def tanh(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise hyperbolic tangent of ``a``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Build the op with the ``fast`` fast-math flag instead
        of ``none``, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Hyperbolic tangent of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point
        Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()                       # Load values
        z = tanh(y)                        # Compute hyperbolic tangent
    """
    return _math_op(math.tanh, fastmath, a, loc=loc, ip=ip)
|
|
|
|
|
|
# Public API of this module: one entry per math wrapper defined above.
__all__ = [
    "absf",
    "acos",
    "asin",
    "atan",
    "atan2",
    "copysign",
    "cos",
    "erf",
    "exp",
    "exp2",
    "floor",
    "log",
    "log10",
    "log2",
    "rsqrt",
    "sin",
    "sqrt",
    "tan",
    "tanh",
]
|