Files
cutlass/python/CuTeDSL/cutlass/cute/math.py
Nandor Licker ea46e277d2 Add absf and floor to cute.math (#3156)
The ops are already exposed by the underlying dialect.
2026-04-17 08:54:24 +08:00

622 lines
20 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Callable, Union
from .typing import Numeric
from .tensor import TensorSSA
from cutlass._mlir.dialects import math, arith
from cutlass.cutlass_dsl import dsl_user_op
def _math_op(func: Callable, fastmath: bool, *args, **kwargs):
    """Validate the operands and apply an MLIR math op to them.

    Every operand must be either a ``TensorSSA`` or a floating-point
    ``Numeric``, and all operands must be instances of the first operand's
    type. Tensor operands are passed to ``func`` as-is and the result is
    re-wrapped in a ``TensorSSA``; scalar operands are first converted with
    ``ir_value()`` and the op result is returned unchanged.

    :param func: The MLIR op builder to invoke
    :param fastmath: Selects ``arith.FastMathFlags.fast`` when truthy,
        ``arith.FastMathFlags.none`` otherwise
    :param args: The input tensors or scalars
    :param kwargs: Extra keyword arguments (loc, ip) forwarded to the MLIR op
    :return: A ``TensorSSA`` carrying the first operand's shape and dtype for
        tensor inputs, otherwise whatever ``func`` returns
    :raises TypeError: If an operand is neither a ``TensorSSA`` nor a
        floating-point ``Numeric``, or if the operands are of mixed types
    """
    lead = args[0]
    expected_type = type(lead)

    for operand in args:
        supported = isinstance(operand, TensorSSA) or (
            isinstance(operand, Numeric) and type(operand).is_float
        )
        if not supported:
            raise TypeError(
                f"Expected a TensorSSA or Numeric(Float), but got {type(operand)}"
            )
        if not isinstance(operand, expected_type):
            raise TypeError(
                f"Expected all inputs to be of type {expected_type}, but got {type(operand)}"
            )

    if fastmath:
        fastmath_flag = arith.FastMathFlags.fast
    else:
        fastmath_flag = arith.FastMathFlags.none

    if not isinstance(lead, TensorSSA):
        # Scalar path: the op builder works on raw IR values.
        scalars = [operand.ir_value() for operand in args]
        return func(*scalars, fastmath=fastmath_flag, **kwargs)

    # Tensor path: the result keeps the first operand's shape and dtype.
    result = func(*args, fastmath=fastmath_flag, **kwargs)
    return TensorSSA(result, lead.shape, lead.dtype)
@dsl_user_op
def absf(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise absolute value of ``a``.

    Dispatches ``math.absf`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Absolute value of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = absf(y)  # Compute absolute value
    """
    return _math_op(math.absf, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def acos(
    a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None
) -> Union[TensorSSA, Numeric]:
    """Compute element-wise arc cosine of the input.

    Dispatches ``math.acos`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar. The function is registered with
    ``dsl_user_op`` like every other wrapper in this module that takes
    ``loc`` and ``ip``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Enable fast math optimizations, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point for IR generation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Arc cosine of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = acos(y)  # Compute arc cosine
    """
    return _math_op(math.acos, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def asin(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise arc sine of ``a``.

    Dispatches ``math.asin`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Arc sine of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = asin(y)  # Compute arc sine
    """
    return _math_op(math.asin, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def atan(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise arc tangent of ``a``.

    Dispatches ``math.atan`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Arc tangent of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = atan(y)  # Compute arc tangent
    """
    return _math_op(math.atan, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def atan2(
    a: Union[TensorSSA, Numeric],
    b: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise two-argument arc tangent of ``a`` and ``b``.

    ``atan2(a, b)`` is the angle, in radians, between the positive x-axis and
    the point with coordinates ``(b, a)``. Dispatches ``math.atan2`` through
    ``_math_op``; both operands must be of the same kind (both tensors or
    both floating-point scalars).

    :param a: First operand (y-coordinates)
    :type a: Union[TensorSSA, Numeric]
    :param b: Second operand (x-coordinates)
    :type b: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Element-wise arc tangent of ``a / b``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If an operand is neither a TensorSSA nor a
        floating-point Numeric, or if the operand types differ

    Example:

    .. code-block::

        y = cute.make_rmem_tensor(ptr1, layout).load()  # y coordinates
        x = cute.make_rmem_tensor(ptr2, layout).load()  # x coordinates
        theta = atan2(y, x)  # Compute angles
    """
    return _math_op(math.atan2, fastmath, a, b, loc=loc, ip=ip)
@dsl_user_op
def copysign(
    a: Union[TensorSSA, Numeric],
    b: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Combine the magnitude of ``a`` with the sign of ``b``, element-wise.

    Dispatches ``math.copysign`` through ``_math_op``; both operands must be
    of the same kind (both tensors or both floating-point scalars).

    :param a: Operand supplying the magnitude
    :type a: Union[TensorSSA, Numeric]
    :param b: Operand supplying the sign
    :type b: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Values with the magnitude of ``a`` and the sign of ``b``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If an operand is neither a TensorSSA nor a
        floating-point Numeric, or if the operand types differ

    Example:

    .. code-block::

        mag = cute.make_rmem_tensor(ptr1, layout).load()  # magnitudes
        sgn = cute.make_rmem_tensor(ptr2, layout).load()  # signs
        result = copysign(mag, sgn)  # Combine magnitude and sign
    """
    return _math_op(math.copysign, fastmath, a, b, loc=loc, ip=ip)
@dsl_user_op
def cos(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise cosine of ``a``.

    Dispatches ``math.cos`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand (in radians)
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Cosine of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = cos(y)  # Compute cosine
    """
    return _math_op(math.cos, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def erf(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise error function of ``a``.

    The error function is defined as:

        erf(x) = 2/√π ∫[0 to x] exp(-t²) dt

    Dispatches ``math.erf`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Error function value for each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = erf(y)  # Compute error function
    """
    return _math_op(math.erf, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def exp(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise exponential of ``a``.

    Dispatches ``math.exp`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Exponential of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = exp(y)  # Compute exponential
    """
    return _math_op(math.exp, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def exp2(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise base-2 exponential of ``a``.

    Dispatches ``math.exp2`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: 2 raised to the power of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = exp2(y)  # Compute 2^x
    """
    return _math_op(math.exp2, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def floor(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise floor of ``a``.

    Dispatches ``math.floor`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Largest integer less than or equal to each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = floor(y)  # Compute floor
    """
    return _math_op(math.floor, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def log(
    a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None
) -> Union[TensorSSA, Numeric]:
    """Compute element-wise natural logarithm of the input.

    Dispatches ``math.log`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar. The function is registered with
    ``dsl_user_op`` like every other wrapper in this module that takes
    ``loc`` and ``ip``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Enable fast math optimizations, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point for IR generation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Natural logarithm of each element of the input
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = log(y)  # Compute natural logarithm
    """
    return _math_op(math.log, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def log2(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise base-2 logarithm of ``a``.

    Dispatches ``math.log2`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Base-2 logarithm of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = log2(y)  # Compute log base 2
    """
    return _math_op(math.log2, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def log10(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise base-10 logarithm of ``a``.

    Dispatches ``math.log10`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Base-10 logarithm of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = log10(y)  # Compute log base 10
    """
    return _math_op(math.log10, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def rsqrt(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise reciprocal square root of ``a``.

    Evaluates 1/√x for each element. Dispatches ``math.rsqrt`` through
    ``_math_op``, so ``a`` may be a tensor value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Reciprocal square root of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = rsqrt(y)  # Compute 1/√x
    """
    return _math_op(math.rsqrt, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def sin(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise sine of ``a``.

    Dispatches ``math.sin`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand (in radians)
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Sine of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = sin(y)  # Compute sine
    """
    return _math_op(math.sin, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def sqrt(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise square root of ``a``.

    Dispatches ``math.sqrt`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Square root of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = sqrt(y)  # Compute square root
    """
    return _math_op(math.sqrt, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def tan(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise tangent of ``a``.

    Dispatches ``math.tan`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand (in radians)
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Tangent of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = tan(y)  # Compute tangent
    """
    return _math_op(math.tan, fastmath, a, loc=loc, ip=ip)
@dsl_user_op
def tanh(
    a: Union[TensorSSA, Numeric],
    fastmath: bool = False,
    *,
    loc=None,
    ip=None,
) -> Union[TensorSSA, Numeric]:
    """Return the element-wise hyperbolic tangent of ``a``.

    Dispatches ``math.tanh`` through ``_math_op``, so ``a`` may be a tensor
    value or a floating-point scalar.

    :param a: Tensor or floating-point scalar operand
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Request fast-math flags on the emitted op, defaults to False
    :type fastmath: bool, optional
    :param loc: Source location forwarded to the MLIR op, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point forwarded to the MLIR op, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Hyperbolic tangent of each element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    :raises TypeError: If ``a`` is neither a TensorSSA nor a floating-point Numeric

    Example:

    .. code-block::

        x = cute.make_rmem_tensor(layout)  # Create tensor
        y = x.load()  # Load values
        z = tanh(y)  # Compute hyperbolic tangent
    """
    return _math_op(math.tanh, fastmath, a, loc=loc, ip=ip)
# Names exported by ``from <module> import *``: one entry per math wrapper
# defined in this file.
__all__ = [
    "absf",
    "acos",
    "asin",
    "atan",
    "atan2",
    "copysign",
    "cos",
    "erf",
    "exp",
    "exp2",
    "floor",
    "log",
    "log10",
    "log2",
    "rsqrt",
    "sin",
    "sqrt",
    "tan",
    "tanh",
]