Release v4.0.0 (#2294)

Author: Kihiro Bando
Date: 2025-05-13 15:55:29 -04:00
Committed by: GitHub
Parent: ad7b2f5e84
Commit: f115c3f854
299 changed files with 51495 additions and 4413 deletions

python/CuTeDSL/EULA.txt Normal file

@@ -0,0 +1,188 @@
NVIDIA Software License Agreement
IMPORTANT NOTICE: PLEASE READ AND AGREE BEFORE USING THE SOFTWARE
This software license agreement (“Agreement”) is a legal agreement between you, whether an individual or entity, (“you”) and NVIDIA Corporation (“NVIDIA”) and governs the use of the NVIDIA CUTLASS DSLs software and materials that NVIDIA delivers to you under this Agreement (“Software”).
NVIDIA and you are each a “party” and collectively the “parties.”
This Agreement can be accepted only by an adult of legal age of majority in the country in which the Software is used.
If you don’t have the required age or authority to accept this Agreement, or if you don’t accept all the terms and conditions of this Agreement, do not use the Software.
1. License Grants
1.1. License Grant to You. The Software made available by NVIDIA to you is licensed, not sold.
Subject to the terms of this Agreement, NVIDIA grants you a limited, non-exclusive, revocable, non-transferable, and non-sublicensable (except as expressly granted in this Agreement), license to:
a. install and use copies of the Software,
b. configure the Software using configuration files provided (if applicable),
c. modify and create derivative works of any sample or example source code NVIDIA delivers to you as part of the Software (“Derivatives”) (if applicable), and
d. distribute python files in the Software package in source format as incorporated into a software application subject to the following distribution requirements:
i. Your application must have material additional functionality, beyond the included portions of the Software.
ii. The distributable portions of the Software shall only be accessed by your application.
iii. The following notice shall be included in modifications and derivative works of sample source code distributed: “This software contains source code provided by NVIDIA Corporation.”
iv. Unless a developer tool is identified in this Agreement as distributable, it is delivered for your internal use only.
v. The terms under which you distribute your application must be consistent with the terms of this Agreement, including (without limitation) terms relating to the license grant and license restrictions and protection of NVIDIA’s intellectual property rights.
vi. Additionally, you agree that you will protect the privacy, security and legal rights of your application users.
The foregoing (a) through (d) are, collectively, the “Purpose”, and the developed applications are only for use in systems with NVIDIA GPUs.
1.2. License Grant to NVIDIA. Subject to the terms of this Agreement, you grant NVIDIA and its affiliates a non-exclusive, perpetual, irrevocable, sublicensable, worldwide, royalty-free, fully paid-up and transferable license, under your intellectual property rights, to publicly perform, publicly display, reproduce, use, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution), import, create derivative works of and otherwise commercialize and exploit at NVIDIA’s discretion any Derivatives created by or for you.
You may, but are not required to, deliver any Derivatives to NVIDIA.
2. License Restrictions
Your license to use the Software and Derivatives is restricted as stated in this Section 2 (“License Restrictions”).
You will cooperate with NVIDIA and, upon NVIDIA’s written request, you will confirm in writing and provide reasonably requested information to verify your compliance with the terms of this Agreement.
You may not:
2.1. Use the Software or Derivatives for any purpose other than the Purpose;
2.2. Sell, rent, sublicense, transfer, distribute or otherwise make available to others (except authorized users as stated in Section 3 (“Authorized Users”)) any portion of the Software or Derivatives, except as expressly granted in Section 1.1 (“License Grant to You”);
2.3. Reverse engineer, decompile, or disassemble the Software components provided in binary form, nor attempt in any other manner to obtain source code of such Software;
2.4. Modify or create derivative works of the Software, except as expressly granted in Section 1.1 (“License Grant to You”);
2.5. Change or remove copyright or other proprietary notices in the Software;
2.6. Bypass, disable, or circumvent any technical limitation, encryption, security, digital rights management or authentication mechanism in the Software;
2.7. Use the Software or Derivatives in any manner that would cause them to become subject to an open source software license, subject to the terms in Section 6 (“Components Under Other Licenses”);
2.8. Use the Software or Derivatives in violation of any applicable law or regulation in relevant jurisdictions;
2.9. Indicate that a product or service developed with the Software or Derivatives is sponsored or endorsed by NVIDIA;
2.10. Replace any NVIDIA software components in the Software that are governed by this Agreement with other software that implements NVIDIA APIs;
2.11. Reverse engineer, decompile or disassemble any portion of the output generated using Software elements for the purpose of translating such output artifacts to target a non-NVIDIA platform; or
3. Authorized Users
You may allow employees and contractors of your entity or of your subsidiary(ies), and for educational institutions also enrolled students, to internally access and use the Software as authorized by this Agreement from your secure network to perform the work authorized by this Agreement on your behalf.
You are responsible for the compliance with the terms of this Agreement by your authorized users.
Any act or omission that if committed by you would constitute a breach of this Agreement will be deemed to constitute a breach of this Agreement if committed by your authorized users.
4. Pre-Release
Software versions identified as alpha, beta, preview, early access or otherwise as pre-release (“Pre-Release”) may not be fully functional, may contain errors or design flaws, and may have reduced or different security, privacy, availability and reliability standards relative to NVIDIA commercial offerings.
You use Pre-Release Software at your own risk. NVIDIA did not design or test the Software for use in production or business-critical systems.
NVIDIA may choose not to make available a commercial version of Pre-Release Software.
NVIDIA may also choose to abandon development and terminate the availability of Pre-Release Software at any time without liability.
5. Updates
NVIDIA may at any time and at its option, change, discontinue, or deprecate any part, or all, of the Software, or change or remove features or functionality, or make available patches, workarounds or other updates to the Software.
Unless the updates are provided with their separate governing terms, they are deemed part of the Software licensed to you under this Agreement, and your continued use of the Software is deemed acceptance of such changes.
6. Components Under Other Licenses
The Software may include or be distributed with components provided with separate legal notices or terms that accompany the components, such as open source software licenses and other license terms (“Other Licenses”).
The components are subject to the applicable Other Licenses, including any proprietary notices, disclaimers, requirements and extended use rights;
except that this Agreement will prevail regarding the use of third-party open source software, unless a third-party open source software license requires its license terms to prevail.
Open source software license means any software, data or documentation subject to any license identified as an open source license by the Open Source Initiative (http://opensource.org), Free Software Foundation (http://www.fsf.org) or other similar open source organization or listed by the Software Package Data Exchange (SPDX) Workgroup under the Linux Foundation (http://www.spdx.org).
7. Ownership
7.1. NVIDIA Ownership. The Software, including all intellectual property rights, is and will remain the sole and exclusive property of NVIDIA or its licensors.
Except as expressly granted in this Agreement, (a) NVIDIA reserves all rights, interests and remedies in connection with the Software, and (b) no other license or right is granted to you by implication, estoppel or otherwise.
7.2. Your Ownership. Subject to the rights of NVIDIA and its suppliers in the Software, which continue to be licensed as stated in this Agreement, even when incorporated in your products or services, and to the extent permitted by applicable law, as between you and NVIDIA, you hold all rights, title and interest in and to your products, services and Derivatives you develop as permitted in this Agreement including their respective intellectual property rights.
8. Feedback
You may, but you are not obligated to, provide suggestions, requests, fixes, modifications, enhancements, or other feedback regarding the Software (collectively, “Feedback”).
Feedback, even if designated as confidential by you, will not create any confidentiality obligation for NVIDIA or its affiliates.
If you provide Feedback, you grant NVIDIA, its affiliates and its designees a non-exclusive, perpetual, irrevocable, sublicensable, worldwide, royalty-free, fully paid-up and transferable license, under your intellectual property rights, to publicly perform, publicly display, reproduce, use, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution), import, create derivative works of and otherwise commercialize and exploit the Feedback at NVIDIA’s discretion.
9. Termination
9.1. Termination. This Agreement will automatically terminate without notice from NVIDIA if you fail to comply with any of the terms in this Agreement or if you commence or participate in any legal proceeding against NVIDIA with respect to the Software.
Additionally, either party may terminate this Agreement at any time with thirty (30) days’ advance written notice to the other party.
9.2. Effect of Termination. Upon any expiration or termination of this Agreement, you will promptly (a) stop using and return, delete or destroy NVIDIA confidential information and all Software received under this Agreement, and (b) delete or destroy Derivatives created under this Agreement, unless an authorized NVIDIA representative provides prior written approval that you may keep a copy of the Derivatives solely for archival purposes.
Upon written request, you will certify in writing that you have complied with your obligations under this Section 9.2 (“Effect of Termination”).
9.3. Survival. Section 1.2 (“License Grant to NVIDIA”), Section 5 (“Updates”), Section 6 (“Components Under Other Licenses”), Section 7 (“Ownership”), Section 8 (“Feedback”), Section 9.2 (“Effect of Termination”), Section 9.3 (“Survival”), Section 10 (“Disclaimer of Warranties”), Section 11 (“Limitation of Liability”), Section 12 (“Use in Mission Critical Applications”), Section 13 (“Governing Law and Jurisdiction”), Section 14 (“Indemnity”) and Section 15 (“General”) will survive any expiration or termination of this Agreement.
10. Disclaimer of Warranties
THE SOFTWARE IS PROVIDED BY NVIDIA AS-IS AND WITH ALL FAULTS. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, NVIDIA DISCLAIMS ALL WARRANTIES AND REPRESENTATIONS OF ANY KIND, WHETHER
EXPRESS, IMPLIED OR STATUTORY, RELATING TO OR ARISING UNDER THIS AGREEMENT, INCLUDING, WITHOUT LIMITATION, THE WARRANTIES OF TITLE, NONINFRINGEMENT, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, USAGE OF TRADE AND COURSE OF DEALING. NVIDIA DOES NOT WARRANT OR ASSUME RESPONSIBILITY FOR THE ACCURACY OR COMPLETENESS OF ANY THIRD-PARTY INFORMATION, TEXT, GRAPHICS, LINKS CONTAINED IN THE SOFTWARE.
WITHOUT LIMITING THE FOREGOING, NVIDIA DOES NOT WARRANT THAT THE SOFTWARE WILL MEET YOUR REQUIREMENTS, ANY DEFECTS OR ERRORS WILL BE CORRECTED, ANY CERTAIN CONTENT WILL BE AVAILABLE; OR THAT THE SOFTWARE IS FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS. NO INFORMATION OR ADVICE GIVEN BY NVIDIA WILL IN ANY WAY INCREASE THE SCOPE OF ANY WARRANTY EXPRESSLY PROVIDED IN THIS AGREEMENT.
NVIDIA does not warrant or assume responsibility for the accuracy or completeness of any third-party information, text, graphics or links contained in the Software.
11. Limitations of Liability
11.1. EXCLUSIONS. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT WILL NVIDIA BE LIABLE FOR ANY (i) INDIRECT, PUNITIVE, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES, OR (ii) DAMAGES FOR (a) THE COST OF PROCURING SUBSTITUTE GOODS, OR (b) LOSS OF PROFITS, REVENUES, USE, DATA OR GOODWILL ARISING OUT OF OR RELATED TO THIS AGREEMENT, WHETHER BASED ON BREACH OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY, OR OTHERWISE, AND EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES AND EVEN IF A PARTY’S REMEDIES FAIL THEIR ESSENTIAL PURPOSE.
11.2. DAMAGES CAP. ADDITIONALLY, TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, NVIDIA’S TOTAL CUMULATIVE AGGREGATE LIABILITY FOR ANY AND ALL LIABILITIES, OBLIGATIONS OR CLAIMS ARISING OUT OF OR RELATED TO THIS AGREEMENT WILL NOT EXCEED FIVE U.S. DOLLARS (US$5).
12. Use in Mission Critical Applications
You acknowledge that the Software provided under this Agreement is not designed or tested by NVIDIA for use in any system or application where the use or failure of such system or application developed with NVIDIA’s Software could result in injury, death or catastrophic damage (each, a “Mission Critical Application”).
Examples of Mission Critical Applications include use in avionics, navigation, autonomous vehicle applications, AI solutions for automotive products, military, medical, life support or other mission-critical or life-critical applications.
NVIDIA will not be liable to you or any third party, in whole or in part, for any claims or damages arising from these uses.
You are solely responsible for ensuring that systems and applications developed with the Software include sufficient safety and redundancy features and comply with all applicable legal and regulatory standards and requirements.
13. Governing Law and Jurisdiction
This Agreement will be governed in all respects by the laws of the United States and the laws of the State of Delaware, without regard to conflict of laws principles or the United Nations Convention on Contracts for the International Sale of Goods.
The state and federal courts residing in Santa Clara County, California will have exclusive jurisdiction over any dispute or claim arising out of or related to this Agreement, and the parties irrevocably consent to personal jurisdiction and venue in those courts;
except that either party may apply for injunctive remedies or an equivalent type of urgent legal relief in any jurisdiction.
14. Indemnity
By using the Software you agree to defend, indemnify and hold harmless NVIDIA and its affiliates and their respective officers, directors, employees and agents from and against any claims, disputes, demands, liabilities, damages, losses, costs and expenses arising out of or in any way connected with (i) products or services that have been developed or deployed with or use the Software, or claims that they violate laws, or infringe, violate, or misappropriate any third party right;
or (ii) use of the Software in breach of the terms of this Agreement.
15. General
15.1. Independent Contractors.
The parties are independent contractors, and this Agreement does not create a joint venture, partnership, agency, or other form of business association between the parties.
Neither party will have the power to bind the other party or incur any obligation on its behalf without the other party’s prior written consent.
Nothing in this Agreement prevents either party from participating in similar arrangements with third parties.
15.2. No Assignment.
NVIDIA may assign, delegate or transfer its rights or obligations under this Agreement by any means or operation of law.
You may not, without NVIDIA’s prior written consent, assign, delegate or transfer any of your rights or obligations under this Agreement by any means or operation of law, and any attempt to do so is null and void.
15.3. No Waiver.
No failure or delay by a party to enforce any term or obligation of this Agreement will operate as a waiver by that party, or prevent the enforcement of such term or obligation later.
15.4. Trade Compliance.
You agree to comply with all applicable export, import, trade and economic sanctions laws and regulations, as amended, including without limitation U.S. Export Administration Regulations and Office of Foreign Assets Control regulations.
You confirm (a) your understanding that export or reexport of certain NVIDIA products or technologies may require a license or other approval from appropriate authorities and (b) that you will not export or reexport any products or technology, directly or indirectly, without first obtaining any required license or other approval from appropriate authorities, (i) to any countries that are subject to any U.S. or local export restrictions (currently including, but not necessarily limited to, Belarus, Cuba, Iran, North Korea, Russia, Syria, the Region of Crimea, Donetsk People’s Republic Region and Luhansk People’s Republic Region);
(ii) to any end-user who you know or have reason to know will utilize them in the design, development or production of nuclear, chemical or biological weapons, missiles, rocket systems, unmanned air vehicles capable of a maximum range of at least 300 kilometers, regardless of payload, or intended for military end-use, or any weapons of mass destruction;
(iii) to any end-user who has been prohibited from participating in the U.S. or local export transactions by any governing authority;
or (iv) to any known military or military-intelligence end-user or for any known military or military-intelligence end-use in accordance with U.S. trade compliance laws and regulations.
15.5. Government Rights.
The Software, documentation and technology (“Protected Items”) are “Commercial products” as this term is defined at 48 C.F.R. 2.101, consisting of “commercial computer software” and “commercial computer software documentation” as such terms are used in, respectively, 48 C.F.R. 12.212 and 48 C.F.R. 227.7202 & 252.227-7014(a)(1).
Before any Protected Items are supplied to the U.S. Government, you will (i) inform the U.S. Government in writing that the Protected Items are and must be treated as commercial computer software and commercial computer software documentation developed at private expense;
(ii) inform the U.S. Government that the Protected Items are provided subject to the terms of the Agreement;
and (iii) mark the Protected Items as commercial computer software and commercial computer software documentation developed at private expense.
In no event will you permit the U.S. Government to acquire rights in Protected Items beyond those specified in 48 C.F.R. 52.227-19(b)(1)-(2) or 252.227-7013(c) except as expressly approved by NVIDIA in writing.
15.6. Notices.
Please direct your legal notices or other correspondence to legalnotices@nvidia.com with a copy mailed to NVIDIA Corporation, 2788 San Tomas Expressway, Santa Clara, California 95051, United States of America, Attention: Legal Department.
If NVIDIA needs to contact you, you consent to receive the notices by email and agree that such notices will satisfy any legal communication requirements.
15.7. Severability.
If a court of competent jurisdiction rules that a provision of this Agreement is unenforceable, that provision will be deemed modified to the extent necessary to make it enforceable and the remainder of this Agreement will continue in full force and effect.
15.8. Amendment.
Any amendment to this Agreement must be in writing and signed by authorized representatives of both parties.
15.9. Construction.
The headings in the Agreement are included solely for convenience and are not intended to affect the meaning or interpretation of the Agreement.
As required by the context of the Agreement, the singular of a term includes the plural and vice versa.
15.10. Force Majeure.
Neither party will be liable during any period where an event or circumstance prevents or delays that party from performing its obligations under this Agreement and that event or circumstance: (i) is not within the reasonable control of that party and is not the result of that party’s negligence, and (ii) cannot be overcome or avoided by that party using reasonably diligent efforts.
15.11. Entire Agreement.
Regarding the subject matter of this Agreement, the parties agree that (a) this Agreement constitutes the entire and exclusive agreement between the parties and supersedes all prior and contemporaneous communications and (b) any additional or different terms or conditions, whether contained in purchase orders, order acknowledgments, invoices or otherwise, will not be binding and are null and void.
(v. May 8, 2025)


@@ -0,0 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
# Local module imports
from .dsl import *
from .runtime import *
from ._mlir_helpers import lru_cache_ir
from .env_manager import get_str_env_var, detect_gpu_arch


@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides MLIR Dialect helper functions
"""
from . import arith
from .lru_cache_ir import lru_cache_ir
__all__ = ["arith", "lru_cache_ir"]
try:
from . import gpu
__all__.extend(["gpu"])
except ImportError:
pass


@@ -0,0 +1,691 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides MLIR Arith Dialect helper functions
"""
import array
import numpy as np
from ..common import *
from ..._mlir import ir # type: ignore
from ..._mlir.extras import types as T # type: ignore
from ..._mlir.dialects import arith, nvgpu, math, builtin # type: ignore
from .lru_cache_ir import lru_cache_ir
# =============================================================================
# Arith Dialect Helper functions
# =============================================================================
def recast_type(src_type, res_elem_type) -> ir.Type:
if isinstance(src_type, T.VectorType):
if src_type.scalable:
res_type = T.vector(
*src_type.shape,
res_elem_type,
scalable=src_type.scalable,
scalable_dims=src_type.scalable_dims,
)
else:
res_type = T.vector(*src_type.shape, res_elem_type)
elif isinstance(src_type, T.RankedTensorType):
res_type = T.RankedTensorType.get(
element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides
)
elif isinstance(src_type, T.UnrankedTensorType):
res_type = T.UnrankedTensorType.get(element_type=res_elem_type)
elif isinstance(src_type, T.MemRefType):
res_type = T.MemRefType.get(
element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides
)
else:
res_type = res_elem_type
return res_type
def is_scalar(ty) -> bool:
return not isinstance(
ty, (T.VectorType, T.RankedTensorType, T.UnrankedTensorType, T.MemRefType)
)
def element_type(ty) -> ir.Type:
if not is_scalar(ty):
return ty.element_type
else:
return ty
def is_narrow_precision(ty) -> bool:
narrow_types = {
T.f8E8M0FNU(),
T.f8E4M3FN(),
T.f8E4M3(),
T.f8E5M2(),
T.f8E4M3B11FNUZ(),
T.f4E2M1FN(),
T.f6E3M2FN(),
T.f6E2M3FN(),
}
return ty in narrow_types
def is_float_type(ty) -> bool:
return (
arith._is_float_type(ty)
# TODO-upstream: the upstream float-type predicate is incomplete. Patch here and fix in upstream later
or is_narrow_precision(ty)
or ty in (T.bf16(), T.tf32())
)
def truncf_to_narrow(res_ty, src, loc, ip):
res_elem_ty = element_type(res_ty)
if res_elem_ty == T.f8E8M0FNU():
rnd = nvgpu.RoundingMode.RP
else:
rnd = nvgpu.RoundingMode.RN
return nvgpu.cvt_fptrunc(res_ty, src, rnd=rnd, loc=loc, ip=ip)
def extf_from_narrow(res_ty, src, loc, ip):
src_elem_ty = element_type(src.type)
# When source type is E8M0, temporary element type has to be bf16
tmp_elem_ty = T.bf16() if src_elem_ty == T.f8E8M0FNU() else T.f16()
tmp_ty = recast_type(src.type, tmp_elem_ty)
# narrow -> bf16/f16 -> target type
tmp = nvgpu.cvt_fpext(tmp_ty, src, loc=loc, ip=ip)
return arith.extf(res_ty, tmp, loc=loc, ip=ip)
def bitcast(src, res_elem_type, *, loc=None, ip=None):
res_type = recast_type(src.type, res_elem_type)
return arith.bitcast(res_type, src, loc=loc, ip=ip)
def cvtf(src, res_elem_type, *, loc=None, ip=None):
src_elem_type = element_type(src.type)
if res_elem_type == src_elem_type:
return src
res_type = recast_type(src.type, res_elem_type)
# Treat TF32 as F32 and use i32 as intermediate data
# TODO-upstream: update arith to support tf32 <-> f32 conversion
if src_elem_type == T.tf32():
# tf32 -> i32
tmp_type = recast_type(src.type, T.i32())
src = builtin.unrealized_conversion_cast([tmp_type], [src], loc=loc, ip=ip)
# i32 -> f32
src = bitcast(src, T.f32(), loc=loc, ip=ip)
# f32 -> X with `cvtf` recursively
return cvtf(src, res_elem_type, loc=loc, ip=ip)
if res_elem_type == T.tf32():
# X -> f32 with `cvtf` recursively
tmp = cvtf(src, T.f32(), loc=loc, ip=ip)
# f32 -> i32
tmp = bitcast(tmp, T.i32(), loc=loc, ip=ip)
# i32 -> tf32
return builtin.unrealized_conversion_cast([res_type], [tmp], loc=loc, ip=ip)
if res_elem_type.width > src_elem_type.width:
if is_narrow_precision(src_elem_type):
return extf_from_narrow(res_type, src, loc, ip)
else:
return arith.extf(res_type, src, loc=loc, ip=ip)
else:
tmp_mlir_type = recast_type(src.type, T.f32())
# f16 -- extf -> f32 -- truncf -> bf16
# TODO-upstream: update arith to support bf16 <-> f16 conversion?
if (src_elem_type == T.f16() and res_elem_type == T.bf16()) or (
src_elem_type == T.bf16() and res_elem_type == T.f16()
):
tmp = arith.extf(tmp_mlir_type, src, loc=loc, ip=ip)
return arith.truncf(res_type, tmp, loc=loc, ip=ip)
# {f8, f6, f4} -> f16, f32, ...
elif is_narrow_precision(res_elem_type):
return truncf_to_narrow(res_type, src, loc, ip)
else:
return arith.truncf(res_type, src, loc=loc, ip=ip)
def fptoi(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None):
res_type = recast_type(src.type, res_elem_type)
# TODO-upstream: update arith to support this kind of conversion
if element_type(src.type) in (T.tf32(), T.bf16()):
src = cvtf(src, T.f32(), loc=loc, ip=ip)
if signed:
return arith.fptosi(res_type, src, loc=loc, ip=ip)
else:
return arith.fptoui(res_type, src, loc=loc, ip=ip)
def itofp(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None):
res_type = recast_type(src.type, res_elem_type)
orig_res_type = res_type
# TODO-upstream: update arith to support this kind of conversion
if res_elem_type in (T.tf32(), T.bf16()):
res_type = recast_type(src.type, T.f32())
if signed and element_type(src.type).width > 1:
res = arith.sitofp(res_type, src, loc=loc, ip=ip)
else:
res = arith.uitofp(res_type, src, loc=loc, ip=ip)
if orig_res_type == res_type:
return res
return cvtf(res, element_type(orig_res_type), loc=loc, ip=ip)
def int_to_int(a, dst_elem_type, *, loc=None, ip=None):
src_signed = a.signed
dst_signed = dst_elem_type.signed
src_width = element_type(a.type).width
dst_width = dst_elem_type.width
dst_mlir_type = recast_type(a.type, dst_elem_type.mlir_type)
if dst_width == src_width:
return a
elif src_signed and not dst_signed:
# Signed -> Unsigned
if dst_width > src_width:
return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
else:
return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip)
elif src_signed == dst_signed:
# Same signedness
if dst_width > src_width:
if src_signed and src_width > 1:
return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip)
else:
return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
else:
return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip)
else:
# Unsigned -> Signed
if dst_width > src_width:
return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
else:
# For truncation from unsigned to signed, we need to handle overflow
# First truncate to the target width
trunc = arith.trunci(dst_mlir_type, a, loc=loc, ip=ip)
# Then reinterpret as signed
if dst_signed:
return arith.bitcast(dst_mlir_type, trunc, loc=loc, ip=ip)
return trunc
# =============================================================================
# Arith Ops Emitter Helpers
# - assuming type of lhs and rhs match each other
# - op name matches python module operator
# =============================================================================
def _cast(res_elem_ty, src, is_signed=None, *, loc=None, ip=None):
"""
This function provides a simplified interface to the upstream op builders.
Because casts are element-wise ops that cannot change the shape, the shaped
result type is derived from `src`, so the upstream form
    arith.truncf(T.vector(shape, new_type), src)
is simplified to
    arith.truncf(new_type, src)
"""
if isinstance(src, ir.Value):
src_ty = src.type
else:
src_ty = type(src).mlir_type
src = src.ir_value()
src_elem_ty = element_type(src_ty)
if src_elem_ty == res_elem_ty:
return src
elif is_float_type(src_elem_ty) and is_float_type(res_elem_ty):
# float-to-float
return cvtf(src, res_elem_ty, loc=loc, ip=ip)
elif arith._is_integer_like_type(src_elem_ty) and arith._is_integer_like_type(
res_elem_ty
):
if src_elem_ty.width >= res_elem_ty.width:
cast_op = arith.trunci
else:
if is_signed:
cast_op = arith.extsi
else:
cast_op = arith.extui
res_ty = recast_type(src_ty, res_elem_ty)
return cast_op(res_ty, src, loc=loc, ip=ip)
elif is_float_type(src_elem_ty) and arith._is_integer_like_type(res_elem_ty):
return fptoi(src, is_signed, res_elem_ty, loc=loc, ip=ip)
elif arith._is_integer_like_type(src_elem_ty) and is_float_type(res_elem_ty):
return itofp(src, is_signed, res_elem_ty, loc=loc, ip=ip)
else:
raise DSLRuntimeError(
f"cast from {src_elem_ty} to {res_elem_ty} is not supported"
)
@lru_cache_ir()
def const(value, ty=None, *, loc=None, ip=None):
"""
Generates dynamic expression for constant values.
"""
from ..typing import Numeric, NumericMeta
from ..dsl import is_dynamic_expression, _numpy_type_to_mlir_type
if isinstance(value, Numeric):
value = value.value
# Early return
if is_dynamic_expression(value) and (
value.type.isinstance(value.type) or T.bool().isinstance(value.type)
):
return value
# Assume type
if ty is None:
if isinstance(value, float):
ty = T.f32()
elif isinstance(value, bool):
ty = T.bool()
elif isinstance(value, int):
ty = T.i32()
elif isinstance(value, np.ndarray):
ty = T.vector(*value.shape, _numpy_type_to_mlir_type(value.dtype))
value = array.array(value.dtype.kind, value.flatten().tolist())
else:
raise DSLNotImplemented(f"{type(value)} is not supported")
elif isinstance(ty, NumericMeta):
ty = ty.mlir_type
elif isinstance(ty, ir.Type):
if ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty):
elem_ty = ty.element_type
if isinstance(elem_ty, ir.IntegerType):
attr = ir.IntegerAttr.get(elem_ty, value)
else:
attr = ir.FloatAttr.get(elem_ty, value)
value = ir.DenseElementsAttr.get_splat(ty, attr)
elif arith._is_float_type(ty) and isinstance(value, (bool, int)):
value = float(value)
elif arith._is_integer_like_type(ty) and isinstance(value, float):
value = int(value)
# Python scalars whose type already matches `ty` fall through unchanged
elif not (arith._is_float_type(ty) or arith._is_integer_like_type(ty)):
raise DSLNotImplemented(f"type {ty} is not supported")
return arith.constant(ty, value, loc=loc, ip=ip)
def _dispatch_to_rhs_r_op(op):
"""Decorator that dispatches to the right-hand-side's reverse operation.
If the other operand is not an ArithValue or is a subclass (more specific)
of ArithValue, this allows proper method resolution for binary operations.
"""
def wrapper(self, other, **kwargs):
if not isinstance(other, ArithValue):
if not isinstance(other, (int, float, bool)):
# allows to call other.__rmul__
return NotImplemented
return op(self, other, **kwargs)
return wrapper
def _binary_op(op):
"""
Decorator to check if the 'other' argument is an ArithValue.
If not, returns NotImplemented.
"""
def wrapper(self, other, **kwargs):
# When we reach this point, `self` has already been cast to the base `ArithValue` type
if isinstance(other, (int, float, bool)):
other = const(other, self.type).with_signedness(self.signed)
# Call the original function
# If sub-class doesn't implement overloaded arithmetic, cast to base class
return op(self, other, **kwargs)
return wrapper
# Operator overloading
@ir.register_value_caster(ir.Float4E2M1FNType.static_typeid)
@ir.register_value_caster(ir.Float6E2M3FNType.static_typeid)
@ir.register_value_caster(ir.Float6E3M2FNType.static_typeid)
@ir.register_value_caster(ir.Float8E4M3FNType.static_typeid)
@ir.register_value_caster(ir.Float8E4M3B11FNUZType.static_typeid)
@ir.register_value_caster(ir.Float8E5M2Type.static_typeid)
@ir.register_value_caster(ir.Float8E4M3Type.static_typeid)
@ir.register_value_caster(ir.Float8E8M0FNUType.static_typeid)
@ir.register_value_caster(ir.BF16Type.static_typeid)
@ir.register_value_caster(ir.F16Type.static_typeid)
@ir.register_value_caster(ir.FloatTF32Type.static_typeid)
@ir.register_value_caster(ir.F32Type.static_typeid)
@ir.register_value_caster(ir.F64Type.static_typeid)
@ir.register_value_caster(ir.IntegerType.static_typeid)
@ir.register_value_caster(ir.VectorType.static_typeid)
@ir.register_value_caster(ir.RankedTensorType.static_typeid)
class ArithValue(ir.Value):
"""Overloads operators for MLIR's Arith dialects binary operations."""
def __init__(self, v, signed: Union[bool, None] = None):
if isinstance(v, int):
v = arith.constant(self.type, v)
super().__init__(v)
elem_ty = element_type(self.type)
self.is_float = arith._is_float_type(elem_ty)
# The arith dialect considers `1` in `i1` to be `-1`; treat i1 as unsigned in the DSL
self.signed = signed and elem_ty.width > 1
def with_signedness(self, signed: Union[bool, None]):
return type(self)(self, signed)
def __neg__(self, *, loc=None, ip=None):
if self.type == T.bool():
raise TypeError(
"Negation, the operator `-` is not supported for boolean type"
)
if self.is_float:
return arith.negf(self, loc=loc, ip=ip)
else:
c0 = arith.constant(self.type, 0, loc=loc, ip=ip)
return arith.subi(c0, self, loc=loc, ip=ip)
@_binary_op
def __pow__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float and other.is_float:
return math.powf(self, other, loc=loc, ip=ip)
elif self.is_float and not other.is_float:
return math.fpowi(self, other, loc=loc, ip=ip)
elif not self.is_float and other.is_float:
lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip)
rhs = cvtf(other, T.f32(), loc=loc, ip=ip)
return math.powf(lhs, rhs, loc=loc, ip=ip)
elif not self.is_float and not other.is_float:
return math.ipowi(self, other, loc=loc, ip=ip)
else:
raise DSLNotImplemented(f"Unsupported '{self} ** {other}'")
@_binary_op
def __rpow__(self, other, *, loc=None, ip=None) -> "ArithValue":
return other.__pow__(self, loc=loc, ip=ip)
# arith operators
@_dispatch_to_rhs_r_op
@_binary_op
def __add__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.addf(self, other, loc=loc, ip=ip)
else:
return arith.addi(self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __sub__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.subf(self, other, loc=loc, ip=ip)
else:
return arith.subi(self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __mul__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.mulf(self, other, loc=loc, ip=ip)
else:
return arith.muli(self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __truediv__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.divf(self, other, loc=loc, ip=ip)
else:
lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip)
rhs = itofp(other, other.signed, T.f32(), loc=loc, ip=ip)
return arith.divf(lhs, rhs, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __floordiv__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
q = arith.divf(self, other, loc=loc, ip=ip)
return math.floor(q, loc=loc, ip=ip)
elif self.signed:
return arith.floordivsi(self, other, loc=loc, ip=ip)
else:
return arith.divui(self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.remf(self, other, loc=loc, ip=ip)
elif self.signed:
return arith.remsi(self, other, loc=loc, ip=ip)
else:
return arith.remui(self, other, loc=loc, ip=ip)
@_binary_op
def __radd__(self, other, *, loc=None, ip=None) -> "ArithValue":
return other.__add__(self, loc=loc, ip=ip)
@_binary_op
def __rsub__(self, other, *, loc=None, ip=None) -> "ArithValue":
return other.__sub__(self, loc=loc, ip=ip)
@_binary_op
def __rmul__(self, other, *, loc=None, ip=None) -> "ArithValue":
return other.__mul__(self, loc=loc, ip=ip)
@_binary_op
def __rtruediv__(self, other, *, loc=None, ip=None) -> "ArithValue":
return other.__truediv__(self, loc=loc, ip=ip)
@_binary_op
def __rfloordiv__(self, other, *, loc=None, ip=None) -> "ArithValue":
return other.__floordiv__(self, loc=loc, ip=ip)
@_binary_op
def __rmod__(self, other, *, loc=None, ip=None) -> "ArithValue":
return other.__mod__(self, loc=loc, ip=ip)
# Comparison operators (comparison doesn't have right-hand-side variants)
@_dispatch_to_rhs_r_op
@_binary_op
def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip)
elif self.signed:
return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __le__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip)
elif self.signed:
return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __eq__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OEQ, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.eq, self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __ne__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
# In Python, bool(float("nan")) is True, so use unordered comparison here
return arith.cmpf(arith.CmpFPredicate.UNE, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.ne, self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OGT, self, other, loc=loc, ip=ip)
elif self.signed:
return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip)
elif self.signed:
return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip)
# Unary operators
def __invert__(self, *, loc=None, ip=None) -> "ArithValue":
return arith.xori(self, arith.constant(self.type, -1), loc=loc, ip=ip)
# Bitwise operations
@_dispatch_to_rhs_r_op
@_binary_op
def __and__(self, other, *, loc=None, ip=None) -> "ArithValue":
return arith.andi(self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __or__(self, other, *, loc=None, ip=None) -> "ArithValue":
return arith.ori(self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __xor__(self, other, *, loc=None, ip=None) -> "ArithValue":
return arith.xori(self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.signed:
return arith.shrsi(self, other, loc=loc, ip=ip)
else:
return arith.shrui(self, other, loc=loc, ip=ip)
@_dispatch_to_rhs_r_op
@_binary_op
def __lshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
return arith.shli(self, other, loc=loc, ip=ip)
@_binary_op
def __rand__(self, other, *, loc=None, ip=None) -> "ArithValue":
return arith.andi(other, self, loc=loc, ip=ip)
@_binary_op
def __ror__(self, other, *, loc=None, ip=None) -> "ArithValue":
return arith.ori(other, self, loc=loc, ip=ip)
@_binary_op
def __rxor__(self, other, *, loc=None, ip=None) -> "ArithValue":
return arith.xori(other, self, loc=loc, ip=ip)
@_binary_op
def __rrshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
return other.__rshift__(self, loc=loc, ip=ip)
@_binary_op
def __rlshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
return other.__lshift__(self, loc=loc, ip=ip)
def __hash__(self):
return super().__hash__()
def __str__(self):
return super().__str__().replace(ir.Value.__name__, ArithValue.__name__)
def __repr__(self):
return self.__str__()
def _min(lhs, rhs, *, loc=None, ip=None):
"""
This function provides a unified interface for building arith min
Assuming the operands have the same type
"""
from ..dsl import is_dynamic_expression
if not is_dynamic_expression(lhs):
if not is_dynamic_expression(rhs):
return min(lhs, rhs)
else:
lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip)
else:
if not is_dynamic_expression(rhs):
rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)
if arith._is_integer_like_type(lhs.type):
if lhs.signed:
return arith.minsi(lhs, rhs, loc=loc, ip=ip)
else:
return arith.minui(lhs, rhs, loc=loc, ip=ip)
else:
return arith.minimumf(lhs, rhs, loc=loc, ip=ip)
def _max(lhs, rhs, *, loc=None, ip=None):
"""
This function provides a unified interface for building arith max
Assuming the operands have the same type
"""
from ..dsl import is_dynamic_expression
if not is_dynamic_expression(lhs):
if not is_dynamic_expression(rhs):
return max(lhs, rhs)
else:
lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip)
else:
if not is_dynamic_expression(rhs):
rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)
if arith._is_integer_like_type(lhs.type):
if lhs.signed:
return arith.maxsi(lhs, rhs, loc=loc, ip=ip)
else:
return arith.maxui(lhs, rhs, loc=loc, ip=ip)
else:
return arith.maximumf(lhs, rhs, loc=loc, ip=ip)
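# Minimal usage sketch, assuming an active MLIR context and insertion point
# (as arranged by the DSL runtime before user code is traced). The function
# below is illustrative only and is not part of this module's API.
def _example_arith_usage():
    x = const(2.0)                  # arith.constant : f32, wrapped as ArithValue
    y = const(0.5)
    z = x * y + const(1.0)          # arith.mulf / arith.addf via operator overloads
    half = cvtf(z, T.f16())         # f32 -> f16 through arith.truncf
    is_pos = z > const(0.0)         # arith.cmpf OGT -> i1
    return _min(z, const(4.0)), half, is_pos  # arith.minimumf for float operands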


@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides MLIR GPU Dialect helper functions
"""
from ..._mlir import ir
from ..._mlir.dialects import gpu, arith, scf
from ..._mlir.extras import types as T
from ..common import *
# =============================================================================
# GPU Dialect Helper functions
# =============================================================================
def create_async_token():
token_ty = gpu.AsyncTokenType.get()
token = gpu.wait(token_ty, [])
return token
def printf(fmt, *args, threadNumber=-1):
"""Generate gpu.printf OP predicated on threadNumber"""
type_formats = []
for arg in args:
ty_format = None
if ir.IndexType.isinstance(arg.type):
ty_format = "%llu"
if ir.IntegerType.isinstance(arg.type):
width = ir.IntegerType(arg.type).width
if width == 64:
ty_format = "%llu"
elif width == 32:
ty_format = "%d"
elif width == 1:
ty_format = "%i"
if ir.F32Type.isinstance(arg.type):
ty_format = "%f"
if ty_format is None:
raise DSLNotImplemented(arg.type)
type_formats.append(ty_format)
if threadNumber == -1:
gpu.printf(fmt.format(*type_formats) + "\n", args)
if threadNumber != -1:
tidx = gpu.thread_id(gpu.Dimension.x)
predicate = arith.cmpi(
arith.CmpIPredicate.eq, tidx, arith.constant(T.index(), threadNumber)
)
if_op = scf.IfOp(predicate)
with ir.InsertionPoint(if_op.then_block):
gpu.printf(fmt.format(*type_formats) + "\n", args)
scf.yield_([])
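# Minimal usage sketch, assuming it runs inside a gpu.launch / gpu.func body
# that the DSL runtime has already opened. The function and argument names
# below are illustrative only.
def _example_printf(index_value, f32_value):
    # Print from every thread; "{}" placeholders are replaced with printf
    # formats inferred from the operand MLIR types.
    printf("idx={} val={}", index_value, f32_value)
    # Print only from thread 0 of the block.
    printf("thread 0 only, val={}", f32_value, threadNumber=0)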


@@ -0,0 +1,76 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides @lru_cache_ir
It extends functools.lru_cache with IR Context awareness.
Example usage:
from lru_cache_ir import lru_cache_ir
@lru_cache_ir(maxsize=128, typed=False)
def make_layout(...):
...
"""
from functools import lru_cache, wraps
from ..._mlir import ir # type: ignore
def get_ir_context(func):
"""
Return the IR context under which the decorated function is being called.
Currently the context is the pair (MLIRContext, InsertionPoint).
"""
try:
if ir:
return (ir.Context.current, ir.InsertionPoint.current)
else:
return None
except ValueError:
return None
def lru_cache_ir(maxsize=128, typed=True):
"""
Applies an LRU cache to a given function, with awareness of IR context.
Usage is similar to functools.lru_cache; the current IR context (from `get_ir_context`) is added to the cache key automatically.
:param maxsize: Max cache size, same as functools.lru_cache
:param typed: Whether params are type-sensitive, default to True as IR is type-sensitive
"""
def decorator(func):
# Use functools.lru_cache with a custom wrapper to control the key generation
@lru_cache(maxsize=maxsize, typed=typed)
def cached_func(context, *args, **kwargs):
return func(*args, **kwargs)
@wraps(func)
def wrapper(*args, **kwargs):
try:
# Call the cached function with the context
return cached_func(get_ir_context(func), *args, **kwargs)
except (RuntimeError, TypeError):
return func(*args, **kwargs)
# Expose cache-related methods for introspection
wrapper.cache_clear = cached_func.cache_clear
wrapper.cache_info = cached_func.cache_info
return wrapper
return decorator
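# Minimal usage sketch; `_example_make_zero` is illustrative only. The cache
# key implicitly includes the current (Context, InsertionPoint), so IR built
# at one insertion point is never reused at another.
@lru_cache_ir(maxsize=64)
def _example_make_zero(ty):
    # Assumes an active MLIR context/insertion point when called, and that the
    # arith dialect bindings are importable as in the sibling helper modules.
    from ..._mlir.dialects import arith
    return arith.constant(ty, 0)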


@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides MLIR's OP helper functions
"""
import inspect
from functools import wraps
from ..._mlir import ir
def dsl_user_op(opFunc):
@wraps(opFunc)
def wrapper(*args, **kwargs):
loc = kwargs.pop("loc", None)
if loc is None:
frame = inspect.currentframe().f_back
file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0)
loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc)
res_or_list = opFunc(*args, **kwargs, loc=loc)
return res_or_list
return wrapper
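# Minimal usage sketch; `_example_add` is illustrative only. Because of
# @dsl_user_op, `loc` defaults to the caller's file/line/function when the
# caller does not pass loc= explicitly.
@dsl_user_op
def _example_add(a, b, *, loc=None, ip=None):
    # Assumes the arith dialect bindings are importable as in the sibling
    # helper modules, and an active insertion point when called.
    from ..._mlir.dialects import arith
    return arith.addi(a, b, loc=loc, ip=ip)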


@@ -0,0 +1,584 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides helper functions used by the code that the preprocessor generates.
The preprocessor reads through Python's AST and rewrites the input code.
"""
from typing import Callable, Iterator, Optional, overload
from .utils.logger import log
from .common import *
from ._mlir_helpers.arith import ArithValue
class Executor:
"""
The Executor class handles dynamic and compile-time (constexpr) execution
of "for" loops and "if-else-elif" statements.
Methods:
set_functions: Assigns the functions for checking loop bounds and
conditional evaluation.
for_dynamic: Generates MLIR for OP
for_constexpr: Executes a for loop at JIT compile-time
for_execute: Decides whether to execute the loop at compile-time or generate MLIR for OP based on the provided bounds.
if_dynamic: Generates MLIR if OP
if_constexpr: Executes an if statement at JIT compile-time via the Python interpreter
if_execute: Decides whether to execute the if statement at compile-time or generate MLIR if OP based on the predicate.
"""
def __init__(self):
self._is_dynamic_expression = None
self._loop_execute_range_dynamic = None
self._if_dynamic = None
self._while_dynamic = None
def set_functions(
self,
is_dynamic_expression: Callable,
loop_execute_range_dynamic: Callable,
if_dynamic: Callable,
while_dynamic: Callable,
):
self._is_dynamic_expression = is_dynamic_expression
self._loop_execute_range_dynamic = loop_execute_range_dynamic
self._if_dynamic = if_dynamic
self._while_dynamic = while_dynamic
@staticmethod
def convert_to_list(x):
"""This function is used to convert x to a list.
If x is None, return an empty list.
If x is not a list, return a list containing x.
Otherwise, return x itself.
"""
if x is None:
return []
if not isinstance(x, list):
return [x]
return x
@staticmethod
def converge_ret_val(res):
"""This function is used to converge res (the return value) of the function.
If res is None, return None.
If res is a list and has only one element, return the element.
Otherwise, return res itself.
"""
if res is None:
return res
elif isinstance(res, list) and len(res) == 1:
return res[0]
return res
def for_dynamic(
self,
func: Callable,
start,
stop,
step,
used_args: list,
iter_args: list,
iter_arg_names: list,
unroll: int = -1,
unroll_full: bool = False,
):
log().info("start [%s] stop [%s] step [%s]", start, stop, step)
return self._loop_execute_range_dynamic(
func,
start,
stop,
step,
used_args,
iter_args,
iter_arg_names,
unroll,
unroll_full,
)
@staticmethod
def for_constexpr(
func: Callable,
start: int,
stop: int,
step: int,
used_args: list,
iter_args: list,
):
log().info("start [%s] stop [%s] step [%s]", start, stop, step)
loop_results = iter_args
log().debug("iter_args [%s]", iter_args)
for i in range(start, stop, step):
log().debug("i [%s] iter_args [%s]", i, iter_args)
loop_results = func(i, *used_args, *loop_results)
log().debug("loop_results [%s]", loop_results)
if loop_results is None:
loop_results = []
if not isinstance(loop_results, list):
loop_results = [loop_results]
log().debug("done loop_results [%s]", loop_results)
return Executor.converge_ret_val(loop_results)
def for_execute(
self,
func,
start,
stop,
step,
used_args=[],
iter_args=[],
iter_arg_names=[],
unroll=-1,
unroll_full=False,
is_range_constexpr=None,
):
assert (
self._loop_execute_range_dynamic and self._is_dynamic_expression
), "Functions must be set before execution."
log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
any_dynamic_expression = (
self._is_dynamic_expression(start)
or self._is_dynamic_expression(stop)
or self._is_dynamic_expression(step)
)
if is_range_constexpr is None:
if not any_dynamic_expression:
return self.for_constexpr(func, start, stop, step, used_args, iter_args)
else:
return self.for_dynamic(
func,
start,
stop,
step,
used_args,
iter_args,
iter_arg_names,
unroll,
unroll_full,
)
# Ensure bounds are compile-time constants for constexpr execution
if is_range_constexpr:
if any_dynamic_expression:
raise DSLRuntimeError(
"Loop bounds must be constexpr (compile-time constants)"
)
return self.for_constexpr(func, start, stop, step, used_args, iter_args)
# MLIR generation
return self.for_dynamic(
func,
start,
stop,
step,
used_args,
iter_args,
iter_arg_names,
unroll,
unroll_full,
)
def if_dynamic(
self,
pred,
then_block: Callable,
else_block: Optional[Callable] = None,
used_args=[],
yield_args=[],
yield_arg_names=[],
):
return self._if_dynamic(
pred, then_block, else_block, used_args, yield_args, yield_arg_names
)
@staticmethod
def if_constexpr(
pred,
then_block: Callable,
else_block: Optional[Callable] = None,
used_args=[],
yield_args=[],
):
if pred:
log().debug(" running then block [%s]", yield_args)
res = then_block(*used_args, *yield_args)
log().debug("result [%s]", res)
return Executor.converge_ret_val(res)
elif else_block is not None:
log().debug("running else [%s]", yield_args)
res = else_block(*used_args, *yield_args)
log().debug("result [%s]", res)
return Executor.converge_ret_val(res)
def if_execute(
self,
pred,
then_block: Callable,
else_block: Optional[Callable] = None,
used_args=[],
yield_args=[],
yield_arg_names=[],
if_constexpr=None,
):
assert (
self._if_dynamic and self._is_dynamic_expression
), "Functions must be set before execution."
is_if_constexpr = not self._is_dynamic_expression(pred)
if if_constexpr is None:
if is_if_constexpr:
return self.if_constexpr(
pred, then_block, else_block, used_args, yield_args
)
else:
return self.if_dynamic(
pred, then_block, else_block, used_args, yield_args, yield_arg_names
)
# Ensure the predicate is a compile-time constant for constexpr execution
if if_constexpr:
if not is_if_constexpr:
raise DSLRuntimeError(
"If predicate must be constexpr (compile-time constants)"
)
return self.if_constexpr(
pred, then_block, else_block, used_args, yield_args
)
# MLIR generation
return self.if_dynamic(
pred, then_block, else_block, used_args, yield_args, yield_arg_names
)
def while_dynamic(
self,
while_before_block: Callable,
while_after_block: Callable,
used_args=[],
yield_args=[],
yield_arg_names=[],
):
return self._while_dynamic(
while_before_block,
while_after_block,
used_args,
yield_args,
yield_arg_names,
)
@staticmethod
def while_constexpr(
while_before_block,
while_after_block,
used_args=[],
yield_args=[],
):
log().debug(
"while_constexpr begin %s", while_before_block.__qualname__
)
cond, loop_results = while_before_block(*used_args, *yield_args)
while cond:
loop_results = Executor.convert_to_list(loop_results)
log().debug(
"calling while_after [%s], [%s]",
used_args,
loop_results,
)
loop_results = while_after_block(*used_args, *loop_results)
log().debug(
"while after [%s]", loop_results
)
loop_results = Executor.convert_to_list(loop_results)
log().debug(
"calling while_before [%s], [%s]",
used_args,
loop_results,
)
cond, loop_results = while_before_block(*used_args, *loop_results)
log().debug(
"while_before cond, results [%s], [%s]",
cond,
loop_results,
)
log().debug(
"while_constexpr results %s", loop_results
)
return Executor.converge_ret_val(loop_results)
def while_execute(
self,
pred,
while_before_block: Callable,
while_after_block: Callable,
used_args=[],
yield_args=[],
yield_arg_names=[],
while_constexpr=None,
):
assert (
self._while_dynamic and self._is_dynamic_expression
), "Functions must be set before execution."
is_while_constexpr = not self._is_dynamic_expression(pred)
# Ensure the predicate is a compile-time constant for constexpr execution
if while_constexpr:
if not is_while_constexpr:
raise DSLRuntimeError(
"While predicate must be constexpr (compile-time constants)"
)
return self.while_constexpr(
while_before_block, while_after_block, used_args, yield_args
)
# MLIR generation
return self.while_dynamic(
while_before_block,
while_after_block,
used_args,
yield_args,
yield_arg_names,
)
# =============================================================================
# Decorator
# =============================================================================
executor = Executor()
def loop_selector(
start,
stop,
step,
used_args=[],
iter_args=[],
iter_arg_names=[],
unroll=-1,
unroll_full=False,
constexpr=None,
):
log().info(
"start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] constexpr [%s]",
start,
stop,
step,
used_args,
iter_args,
unroll,
unroll_full,
constexpr,
)
from .typing import Integer, Numeric
def _maybe_upcast(value):
if isinstance(value, Integer):
value = value.ir_value()
return value
start = _maybe_upcast(start)
stop = _maybe_upcast(stop)
step = _maybe_upcast(step)
def ir_loop(func):
return executor.for_execute(
func,
start,
stop,
step,
used_args,
iter_args,
iter_arg_names,
unroll,
unroll_full,
constexpr,
)
return ir_loop
def if_selector(pred, used_args=[], yield_args=[]):
log().info("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
    # Unwrap Numeric predicates to their underlying value before passing them through
from .typing import Numeric
if isinstance(pred, Numeric):
pred = pred.value
def ir_loop(func):
return func(pred, *used_args, *yield_args)
return ir_loop
def while_selector(pred, used_args=[], yield_args=[]):
def ir_while_loop(func):
return func(pred, *used_args, *yield_args)
return ir_while_loop
def while_executor(
pred,
while_before_block: Callable,
while_after_block: Callable,
used_args=[],
yield_args=[],
yield_arg_names=[],
constexpr=None,
):
return executor.while_execute(
pred,
while_before_block,
while_after_block,
used_args,
yield_args,
yield_arg_names,
constexpr,
)
def if_executor(
pred,
then_block: Callable,
else_block: Optional[Callable] = None,
used_args=[],
yield_args=[],
yield_arg_names=[],
constexpr=None,
):
return executor.if_execute(
pred, then_block, else_block, used_args, yield_args, yield_arg_names, constexpr
)
# =============================================================================
# Range
# =============================================================================
class range_dynamic:
@overload
def __new__(cls, stop, unroll=0, unroll_full=False):
pass
@overload
def __new__(cls, start, stop, step, unroll=0, unroll_full=False):
pass
def __new__(cls, *args, **kwargs):
        raise DSLRuntimeError("range_dynamic should always be preprocessed to IR")
class range_constexpr:
def __init__(self, *args):
if len(args) == 1:
self.start = 0
self.stop = args[0]
self.step = 1
elif len(args) == 2:
self.start, self.stop = args
self.step = 1
elif len(args) == 3:
self.start, self.stop, self.step = args
else:
raise DSLRuntimeError(
"range_constexpr supports up to 3 arguments (start, stop, step)"
)
# Ensure the arguments are compile-time constants (if required)
for arg_name, arg_value in [
("step", self.step),
("start", self.start),
("stop", self.stop),
]:
if executor._is_dynamic_expression(arg_value):
raise DSLRuntimeError(
f"`range_constexpr` requires `constexpr` (non-IR Values) for all arguments, "
f"but `{arg_name}` is not. If the arguments are dynamic, use `range`; the DSL "
f"will handle them during runtime. ",
suggestion="Use `range` instead of `range_constexpr`.",
)
def __iter__(self) -> Iterator[int]:
current = self.start
while current < self.stop:
yield current
current += self.step
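# A minimal usage sketch (hypothetical helper name; assumes the DSL has already configured
# the executor's helper functions): range_constexpr behaves like a plain Python range whose
# bounds must be compile-time constants, so the loop is fully evaluated during tracing.
def _example_range_constexpr_usage():
    total = 0
    for i in range_constexpr(0, 8, 2):  # yields 0, 2, 4, 6
        total += i
    return total  # 12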
# =============================================================================
# If expressions
# =============================================================================
def const_expr(expression):
if executor._is_dynamic_expression(expression):
raise DSLRuntimeError(
f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
context={
"const_expr": "Accepts only constexpr (compile-time constant)",
"If your expression depends on dynamic values": "Avoid marking it as `const_expr()`",
"If the expression could be either dynamic or constexpr": "Omit explicit `const_expr()` marker; the DSL will infer the correct handling automatically",
},
)
return expression
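# A minimal usage sketch (hypothetical helper and values): const_expr is a pass-through
# marker for compile-time constants; it raises DSLRuntimeError if handed a dynamic IR value.
def _example_const_expr_usage(tile_size=128):
    # tile_size is a plain Python int here, so const_expr simply returns it unchanged.
    if const_expr(tile_size % 32 == 0):
        return tile_size // 32
    return 1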
def dynamic_expr(expression):
    raise DSLRuntimeError("dynamic_expr should always be preprocessed to IR")
# =============================================================================
# Assertion & casting
# =============================================================================
def assert_executor(test, msg=None):
from .typing import Numeric
fail = False
    # Implicitly converting a dynamic expression to bool is not allowed,
    # so explicitly check for None here
if test is not None and executor._is_dynamic_expression(test):
if isinstance(test, Numeric):
try:
test = test.to(bool)
except:
fail = True
else:
fail = True
if not fail:
assert test, msg
else:
        raise DSLRuntimeError(
            "Only constexpr (Python values) are allowed here, but got a non-constexpr (IR value) expression.",
            suggestion="Please replace this with a runtime assert.",
)
def bool_cast(value):
if executor._is_dynamic_expression(value):
        raise DSLRuntimeError(
            "Only constexpr (Python values) are allowed here, but got a non-constexpr (IR value) expression.",
            suggestion="Please explicitly convert to a boolean, e.g. with a comparison expression.",
)
return bool(value)
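# A minimal usage sketch (hypothetical helper): both utilities only accept compile-time
# Python values; dynamic IR values raise DSLRuntimeError with a suggestion.
def _example_constexpr_checks(num_stages=4):
    assert_executor(num_stages > 0, "num_stages must be positive")
    return bool_cast(num_stages)  # True for any non-zero constexpr value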


@@ -0,0 +1,154 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides helper functions for loading and dumping the JIT cache.
"""
import os
import uuid
import random
import tempfile
import pwd
import time
from pathlib import Path
import hashlib
from .utils.logger import log
from .jit_executor import JitExecutor
from .._mlir import ir
# =============================================================================
# Jit Cache Helper functions
# =============================================================================
def get_current_user():
# Try to get the user from the environment variable first
user = os.getenv("USER") or os.getenv("USERNAME")
if not user:
# Fallback for Unix-like systems
user = pwd.getpwuid(os.getuid()).pw_name
return user
try:
default_generated_ir_path = f"/tmp/{get_current_user()}/cutlass_python_cache/"
except Exception as e:
# If all else fails, provide a default fallback path
default_generated_ir_path = "/tmp/cutlass_python_cache/"
print(f"Could not determine user, using default path. Error: {e}")
def load_ir(file, asBytecode=False):
"""Load generated IR from a file."""
assert "mlir" in file
func_name = file.split(".mlir")[0].split("dsl_")[-1]
with ir.Context() as ctx:
with open(file, "rb" if asBytecode else "r") as f:
module = ir.Module.parse(f.read())
return func_name, module
def make_unique_filename(fpath: Path, new_ext: str = None) -> Path:
"""Generate a unique filename with an optional new extension."""
random_part = random.randint(0, 999999)
timestamp = time.time()
hash_input = f"{fpath}_{timestamp}_{random_part}".encode()
hash_code = hashlib.md5(hash_input).hexdigest()[:16] # Shorter hash for readability
stem_with_hash = f"{fpath.stem}_{hash_code}"
return fpath.with_name(stem_with_hash).with_suffix(new_ext or fpath.suffix)
def save_ir(
dsl_name: str,
module: object,
fname: str,
isTemp: bool = False,
asBytecode: bool = False,
) -> str:
"""Save generated IR to a file."""
initial_name = f"{dsl_name.lower()}_{fname}.mlir"
save_path = Path(tempfile.gettempdir() if isTemp else os.getcwd())
save_fname = save_path / initial_name
# Random ID to avoid any collisions
rnd_id = str(uuid.uuid4())
pid = os.getpid()
# use temp dir to be robust against program interruptions
temp_dir = os.path.join(save_path, f"tmp.pid_{pid}_{rnd_id}")
    # If the process exits abnormally, it may leave a temporary folder behind that must be removed manually.
os.makedirs(temp_dir, exist_ok=False)
temp_fname = os.path.join(temp_dir, initial_name)
if asBytecode:
with open(temp_fname, "wb") as f:
module.operation.write_bytecode(f)
else:
with open(temp_fname, "w") as f:
print(module, file=f)
    # os.replace is guaranteed to be atomic on POSIX systems if it succeeds,
    # so readers of the target path can never observe a partial write
os.replace(temp_fname, save_fname)
os.removedirs(temp_dir)
log().debug("Generated IR saved into %s", save_fname)
return save_fname
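# A minimal round-trip sketch (hypothetical names; assumes `mlir_text` is a valid textual
# MLIR module): save_ir writes the module next to the current directory and load_ir recovers
# the function name from the "dsl_"-delimited filename.
def _example_ir_roundtrip(mlir_text: str):
    with ir.Context():
        module = ir.Module.parse(mlir_text)
    saved = save_ir("CuTeDSL", module, "my_kernel")   # -> ./cutedsl_my_kernel.mlir
    func_name, reloaded = load_ir(str(saved))         # func_name == "my_kernel"
    return func_name, reloaded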
def check_func_name(jit_cache, func_name):
    if func_name not in jit_cache:
jit_cache[func_name] = JitExecutor(None, None, None, None, None, None)
return jit_cache
def load_cache_from_path(dsl_name, cache_limit, path=default_generated_ir_path):
"""Load cache from a directory path."""
if not os.path.exists(path):
return dict()
files = os.listdir(path)
jit_cache = dict()
try:
for idx, file in enumerate(files):
if idx >= int(cache_limit):
break
# identify dsl prefix
if not file.startswith(f"{dsl_name.lower()}"):
continue
if ".mlir" in file:
func_name, ir_module = load_ir(
os.path.join(path, file), asBytecode=True
)
jit_cache = check_func_name(jit_cache, func_name)
jit_cache[func_name].ir_module = ir_module
except Exception as e:
        print(f"{dsl_name} failed to load the generated IR cache.", e)
jit_cache = dict()
return jit_cache
def dump_cache_to_path(
dsl_name, jit_cache, cache_limit, path=default_generated_ir_path
):
log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache))
if not os.path.exists(path):
os.makedirs(path)
original_path = os.getcwd()
try:
os.chdir(path)
for idx, [key, value] in enumerate(jit_cache.items()):
if idx >= int(cache_limit):
break
save_ir(dsl_name, value.ir_module, key, asBytecode=True)
except Exception as e:
        print(f"{dsl_name} failed to cache the generated IR.", e)
finally:
os.chdir(original_path)


@@ -0,0 +1,268 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides DSL exception classes for use by any dialect.
"""
import os
from typing import Any, Dict, Iterable, Optional, Union
# ANSI color codes used by the error formatting below
class Colors:
"""ANSI color codes for error messages"""
RED = "\033[91m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
GREEN = "\033[92m"
BOLD = "\033[1m"
RESET = "\033[0m"
# =============================================================================
# DSL Exceptions
# =============================================================================
class DSLBaseError(Exception):
"""
Base exception for DSL-related errors.
Provides optional contextual metadata to aid in debugging.
"""
def __init__(
self,
message: str,
line: Optional[int] = None,
snippet: Optional[str] = None,
filename: Optional[str] = None,
error_code: Optional[Union[str, int]] = None,
context: Optional[Union[Dict[str, Any], str]] = None,
suggestion: Optional[str] = None,
cause: Optional[BaseException] = None,
) -> None:
self.message = message
self.line = line
self.filename = filename
self.snippet = snippet
self.error_code = error_code
self.context = context
self.suggestion = suggestion
self.cause = cause
super().__init__(self._format_message())
def _format_message(self):
"""
Formats the complete error message with available metadata.
Override this in subclasses if you want to change formatting logic.
"""
parts = [f"{self.__class__.__name__}: {self.message}"]
if self.error_code is not None:
parts.append(f"{Colors.BOLD}Error Code:{Colors.RESET} {self.error_code}\n")
if self.line is not None:
parts.append(f" Line: {self.line}")
if self.filename is not None:
parts.append(f" File: {self.filename}")
if self.snippet:
# Optionally truncate long snippets for readability
parts.append(f" Snippet: \n {self.snippet}")
if self.cause:
parts.append(f" Caused exception: {self.cause}")
if self.context:
if isinstance(self.context, dict):
parts.append(f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET}\n")
for key, value in self.context.items():
parts.append(f" {key}: {value}")
else:
parts.append(
f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET} {self.context}"
)
if self.suggestion:
parts.append(f"{Colors.GREEN}💡 Suggestions:{Colors.RESET}")
if isinstance(self.suggestion, (list, tuple)):
for suggestion in self.suggestion:
parts.append(f" {Colors.GREEN}{suggestion}{Colors.RESET}")
else:
parts.append(f" {self.suggestion}")
return "\n".join(parts)
class DSLRuntimeError(DSLBaseError):
"""
Raised when an error occurs during JIT-time code generation in the DSL.
"""
# Inherits all logic from DSLBaseError; override methods if you need
# specialized behavior or formatting for runtime errors.
pass
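# A minimal usage sketch (hypothetical caller and messages): the optional keyword arguments
# feed _format_message(), so callers can attach context and suggestions that render with
# the ANSI colors above.
def _example_raise_dsl_error(kernel_name="gemm"):
    raise DSLRuntimeError(
        f"Unsupported layout for kernel `{kernel_name}`.",
        context={"expected": "row-major", "got": "column-major"},
        suggestion="Transpose the input tensor before calling the kernel.",
    )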
def _get_friendly_cuda_error_message(error_code, error_name):
    """Get a user-friendly error message for common CUDA errors."""
    # Import here to avoid a circular dependency
    from .runtime.cuda import get_device_info
# Strip the byte string markers if present
if isinstance(error_name, bytes):
error_name = error_name.decode("utf-8")
elif (
isinstance(error_name, str)
and error_name.startswith("b'")
and error_name.endswith("'")
):
error_name = error_name[2:-1]
# Add target architecture info
target_arch = os.getenv("CUTE_DSL_ARCH", "unknown")
error_messages = {
"CUDA_ERROR_INVALID_SOURCE": (
f"{Colors.RED}❌ Failed to load CUDA kernel - likely architecture mismatch.{Colors.RESET}\n\n"
),
"CUDA_ERROR_NO_BINARY_FOR_GPU": (
f"{Colors.RED}❌ CUDA kernel not compatible with your GPU.{Colors.RESET}\n\n"
),
"CUDA_ERROR_OUT_OF_MEMORY": (
f"{Colors.RED}💾 CUDA out of memory error.{Colors.RESET}\n\n"
),
"CUDA_ERROR_INVALID_DEVICE": (
f"{Colors.RED}❌ Invalid CUDA device.{Colors.RESET}\n\n"
),
"CUDA_ERROR_NOT_INITIALIZED": (
f"{Colors.RED}❌ CUDA context not initialized.{Colors.RESET}\n\n"
),
"CUDA_ERROR_INVALID_VALUE": (
f"{Colors.RED}⚠️ Invalid parameter passed to CUDA operation.{Colors.RESET}\n\n"
f"{Colors.YELLOW}This is likely a bug - please report it with:{Colors.RESET}"
),
}
error_suggestions = {
"CUDA_ERROR_INVALID_SOURCE": (
f"1. Ensure env CUTE_DSL_ARCH matches your GPU architecture",
f"2. Clear the compilation cache and regenerate the kernel",
f"3. Check CUDA toolkit installation",
),
"CUDA_ERROR_NO_BINARY_FOR_GPU": (
f"Set env CUTE_DSL_ARCH to match your GPU architecture",
),
"CUDA_ERROR_OUT_OF_MEMORY": (
f"1. Reduce batch size",
f"2. Reduce model size",
f"3. Free unused GPU memory",
),
"CUDA_ERROR_INVALID_DEVICE": (
f"1. Check if CUDA device is properly initialized",
f"2. Verify GPU is detected: nvidia-smi",
f"3. Check CUDA_VISIBLE_DEVICES environment variable",
),
"CUDA_ERROR_NOT_INITIALIZED": (
f"1. Check CUDA driver installation",
f"2. call `cuda.cuInit(0)` before any other CUDA operation",
f"3. Run nvidia-smi to confirm GPU status",
),
"CUDA_ERROR_INVALID_VALUE": (
f"1. Your GPU model",
f"2. SM ARCH setting",
f"3. Steps to reproduce",
),
}
message = error_messages.get(
error_name, f"{Colors.RED}Unknown CUDA error{Colors.RESET}"
)
# Add debug information
debug_info = f"\n- {Colors.BOLD}Error name: {error_name}\n"
debug_info += f"- CUDA_TOOLKIT_PATH: {os.getenv('CUDA_TOOLKIT_PATH', 'not set')}\n"
debug_info += (
f"- Target SM ARCH: {os.getenv('CUTE_DSL_ARCH', 'not set')}{Colors.RESET}\n"
)
try:
# Get GPU information using CUDA Python API
debug_info += f"\n{Colors.BLUE}📊 GPU Information:{Colors.RESET}\n"
gpu_info = get_device_info()
debug_info += gpu_info.pretty_str()
if target_arch and gpu_info.compatible_archs:
debug_info += f"\n{Colors.BOLD}Compatibility Check:{Colors.RESET}\n"
if target_arch not in gpu_info.compatible_archs:
debug_info += (
f"{Colors.RED}❌ Error: Target SM ARCH {target_arch} is not compatible\n"
f"💡 Please use one of SM ARCHs: "
f"{Colors.GREEN}{', '.join(gpu_info.compatible_archs or [])}{Colors.RESET}\n"
)
elif target_arch != gpu_info.sm_arch:
debug_info += (
f"{Colors.YELLOW}⚠️ Warning: Using compatible but non-optimal architecture\n"
f"• Current: {target_arch}\n"
f"• Recommended: {Colors.GREEN}{gpu_info.sm_arch}{Colors.RESET} (native)\n"
)
else:
debug_info += f"{Colors.GREEN}✓ Using optimal architecture: {gpu_info.sm_arch}{Colors.RESET}\n"
except Exception as e:
debug_info += (
f"\n{Colors.YELLOW} Could not retrieve GPU info: {str(e)}{Colors.RESET}"
)
return message, debug_info, error_suggestions.get(error_name, "")
class DSLCudaRuntimeError(DSLBaseError):
"""
Raised when an error occurs during CUDA runtime code generation in the DSL.
"""
    # Inherits all logic from DSLBaseError; override methods if you need
    # specialized behavior or formatting for CUDA runtime errors.
def __init__(self, error_code, error_name) -> None:
self._error_code = error_code
self._error_name = error_name
message, debug_info, suggestion = _get_friendly_cuda_error_message(
error_code, error_name
)
super().__init__(
message, error_code=error_code, context=debug_info, suggestion=suggestion
)
class DSLAstPreprocessorError(DSLBaseError):
"""
Raised when an error occurs during AST preprocessing or visiting in the DSL.
"""
# Same approach: You could override _format_message if you want
# to emphasize AST node details or anything specific to preprocessing.
pass
class DSLNotImplemented(DSLBaseError):
"""
Raised when a feature of the DSL is not implemented yet.
"""
# Useful for stubs in your DSL that you plan to implement in the future.
pass


@@ -0,0 +1,221 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides a class that compiles generated IR using MLIR's PassManager
and executes it using MLIR's ExecutionEngine.
"""
from typing import Sequence, Optional, Tuple
import os
import sys
import inspect
from .common import DSLRuntimeError
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from .._mlir import ir
# =============================================================================
# Compiler Class
# =============================================================================
class CompilationError(RuntimeError):
"""Custom error class for compilation failures"""
# Add ANSI color codes
RED = "\033[91m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
GREEN = "\033[92m"
BOLD = "\033[1m"
RESET = "\033[0m"
def __init__(
self,
message: str,
nvvm_error: Optional[str] = None,
ir_context: Optional[str] = None,
cuda_toolkit: Optional[str] = None,
arch: Optional[str] = None,
):
self.nvvm_error = nvvm_error
self.ir_context = ir_context
self.cuda_toolkit = cuda_toolkit
self.arch = arch
# Call parent with formatted error to avoid showing class name
super().__init__("") # Empty string to avoid class name
# Store formatted error for str() representation
self._formatted_error = self._format_error()
def __str__(self) -> str:
"""Override string representation to avoid showing class name"""
return self._formatted_error
def __repr__(self) -> str:
"""Override repr representation to avoid showing class name"""
return self._formatted_error
def _format_error(self) -> str:
if not self.nvvm_error:
return str(self.args[0])
return f"""NVVM Compilation Error:
----------------------
{self.BLUE}⚙️ Current Settings:{self.RESET}
{self.BOLD}- CUDA Toolkit Path: {self.cuda_toolkit or "Not Set"}
- Target Architecture: {self.arch}{self.RESET}
IR Context (truncated):
{self.ir_context}
{self.YELLOW}💡 Possible Solutions:{self.RESET}
{self.GREEN}1. Check if CUDA_TOOLKIT_PATH is set correctly
2. Verify target architecture ({self.arch}) is supported by your CUDA toolkit
3. Make sure CUDA toolkit version matches the target architecture requirements{self.RESET}"""
class Compiler:
"""Compiler class for compiling and building MLIR modules."""
def __init__(self, passmanager, execution_engine):
self.passmanager = passmanager
self.execution_engine = execution_engine
def __call__(self, module):
"""Convenience application method."""
self.compile(module)
def _process_error(self, error_msg: str) -> Tuple[Optional[str], Optional[str]]:
"""Process error message to extract NVVM error and IR context"""
nvvm_error = None
ir_msg = ""
if "NVVM_ERROR" in error_msg:
# Extract the specific NVVM error
nvvm_error = (
error_msg.split("libNVVM extra log:")[1].strip()
if "libNVVM extra log:" in error_msg
else error_msg
)
# Extract IR context
if "see current operation:" in error_msg:
# Get the IR section
ir_section = error_msg.split("see current operation:")[1].strip()
# Remove duplicate IR section
ir_section = ir_section.split("error: unknown: Failed translating")[
0
].strip()
# Get first few lines and last few lines of the IR
ir_lines = ir_section.split("\n")
if len(ir_lines) > 10:
ir_msg = "\n".join(ir_lines[:5] + [" ..."] + ir_lines[-5:])
else:
ir_msg = ir_section
return nvvm_error, ir_msg
def compile(
self,
module,
pipeline: str,
cuda_toolkit: str = "",
arch: str = "",
enable_verifier=False,
):
"""Compiles the module by invoking the pipeline."""
try:
pm = self.passmanager.PassManager.parse(pipeline)
pm.enable_verifier(enable_verifier)
pm.run(module.operation)
except Exception as e:
error_msg = str(e)
nvvm_error, ir_msg = self._process_error(error_msg)
if nvvm_error:
raise CompilationError(
error_msg,
nvvm_error=nvvm_error,
ir_context=ir_msg,
cuda_toolkit=cuda_toolkit,
arch=arch,
) from e
raise e
def jit(self, module, opt_level: int = 2, shared_libs: Sequence[str] = ()):
"""Wraps the module in a JIT execution engine."""
return self.execution_engine.ExecutionEngine(
module, opt_level=opt_level, shared_libs=shared_libs
)
def compile_and_jit(
self,
module,
pipeline: str,
shared_libs: Sequence[str] = (),
opt_level: int = 2,
cuda_toolkit: str = "",
arch: str = "",
):
"""Compiles and jits the module."""
self.compile(
module,
pipeline,
cuda_toolkit,
arch,
)
return self.jit(module, opt_level, shared_libs)
def compile(func, *args, **kwargs):
if func is None:
raise DSLRuntimeError("Function is not set or invalid.")
if not callable(func):
raise DSLRuntimeError("Object is not callable.")
kwargs["compile_only"] = True
kwargs["no_cache"] = True
if inspect.isfunction(func):
# regular function
pass
elif inspect.ismethod(func):
# if it's a method, add the instance to the first argument
args = [func.__self__] + list(args)
func = func.__func__
elif inspect.isclass(type(func)) and hasattr(func, "__call__"):
# If it's a class instance, get the class's __call__ method
args = [func] + list(args)
# Get the actual function from the class definition
func = func.__call__.__func__
else:
        raise DSLRuntimeError(
            f"Invalid function type: only functions, methods, and callable class "
            f"instances are supported, but got {func}."
        )
# If it's a wrapped function created by jit decorator, get the original function
if hasattr(func, "__wrapped__"):
func = func.__wrapped__
if not hasattr(func, "_dsl_object"):
raise DSLRuntimeError("Function is not decorated with jit decorator.")
fcn_ptr = func._dsl_object._preprocess_and_execute(func)
return func._dsl_object._func(fcn_ptr, *args, **kwargs)


@@ -0,0 +1,303 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides utilities for the environment variables setup.
It provides an EnvironmentVarManager, which reads environment variables for the DSL
and caches them for efficient access.
It also provides utilities to automatically set up a subset of environment variables
based on heuristics.
"""
import os
import sys
import shutil
import glob
from pathlib import Path
from functools import lru_cache
from typing import Any
from ..base_dsl.runtime.cuda import get_compute_capability_major_minor
from .utils.logger import log
IS_WINDOWS = sys.platform == "win32"
CLIB_EXT = ".dll" if IS_WINDOWS else ".so"
# =============================================================================
# Environment Variable Helpers
# =============================================================================
@lru_cache(maxsize=None)
def get_str_env_var(var_name, default_value=None):
value = os.getenv(var_name)
return value if value is not None else default_value
@lru_cache(maxsize=None)
def get_bool_env_var(var_name, default_value=False):
value = get_str_env_var(var_name)
if value is None:
return default_value
return value not in {"False", "0", ""}
@lru_cache(maxsize=None)
def get_int_env_var(var_name, default_value=0):
value = get_str_env_var(var_name)
return int(value) if value and value.isdigit() else default_value
def detect_gpu_arch(prefix):
"""
Attempts to detect the machine's GPU architecture.
Returns:
        A string representing the GPU architecture (e.g. "sm_70" for compute capability 7.0),
        or a default value (e.g. "sm_100") if the GPU architecture cannot be determined.
"""
arch = (None, None)
try:
arch = get_compute_capability_major_minor()
except Exception as e:
log().info(f"Failed to get CUDA compute capability: {e}")
if arch == (None, None):
# default to sm_100
arch = (10, 0)
major, minor = arch
suffix = ""
if major >= 9 and minor >= 0:
suffix = "a"
elif minor != 0:
        # e.g. sm_86 belongs to the sm_80 family
minor = 0
return f"sm_{major}{minor}{suffix}"
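# A minimal sketch restating the naming rule above for fixed inputs (hypothetical helper
# that mirrors the branches in detect_gpu_arch without touching the driver):
def _example_arch_name(major: int, minor: int) -> str:
    suffix = ""
    if major >= 9 and minor >= 0:
        suffix = "a"        # e.g. (9, 0) -> "sm_90a", (10, 0) -> "sm_100a"
    elif minor != 0:
        minor = 0           # e.g. (8, 6) -> "sm_80"
    return f"sm_{major}{minor}{suffix}"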
def find_libs_in_ancestors(start, target_libs, lib_folder_guesses):
"""
Search ancestor directories for a candidate library folder containing all required libraries.
Starting from the given path, this function traverses up through each parent directory.
For every ancestor, it checks candidate subdirectories (specified by lib_folder_guesses)
for files that match the required library extension (CLIB_EXT). Library file names are
canonicalized by removing the "lib" prefix from their stem. If a candidate directory contains
all of the required libraries (as specified in target_libs), the function returns a list of
absolute paths to these library files.
Parameters:
start (str or Path): The starting directory from which to begin the search.
target_libs (iterable of str): A collection of required library names (without the "lib" prefix).
lib_folder_guesses (iterable of str): Relative paths from an ancestor directory that may contain the libraries.
Returns:
list[str] or None: A list of resolved paths to the required library files if found; otherwise, None.
"""
# Traverse through all parent directories of the resolved starting path.
for ancestor in Path(start).resolve().parents:
# Iterate over each candidate relative directory path.
for rel_path in lib_folder_guesses:
target_dir = ancestor / rel_path
# Skip if the candidate directory does not exist.
if not target_dir.is_dir():
continue
# Initialize a list to hold the resolved paths of matching library files.
libs_cand = []
# Create a set of the remaining libraries we need to find.
remaining_libs = set(target_libs)
# Iterate over all items in the candidate directory.
for p in target_dir.iterdir():
# Consider only files with the expected library extension.
if p.suffix == CLIB_EXT:
# Canonicalize the library name by removing the "lib" prefix.
lib_name = p.stem.removeprefix("lib")
# If this library is required, add its resolved path and mark it as found.
if lib_name in remaining_libs:
libs_cand.append(str(p.resolve()))
remaining_libs.remove(lib_name)
# If all required libraries have been found, return the list of library paths.
if len(remaining_libs) == 0:
return libs_cand
# Return None if no candidate directory contains all required libraries.
return None
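# A minimal usage sketch (hypothetical folder guesses): searching upward from this file for
# a directory containing all three MLIR runner libraries, as get_prefix_dsl_libs does below.
def _example_find_mlir_libs():
    libs = find_libs_in_ancestors(
        __file__,
        target_libs={"mlir_c_runner_utils", "mlir_runner_utils", "mlir_cuda_runtime"},
        lib_folder_guesses=["lib"],
    )
    return ":".join(libs) if libs else None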
def _find_cuda_home():
"""Find the CUDA installation path using a series of heuristic methods.
Methods below are checked in order, and the function returns on first match:
1. Checking the environment variables CUDA_HOME and CUDA_PATH.
2. Searching for the 'nvcc' compiler in the system PATH and deriving the path of cuda.
3. Scanning common installation directories based on the operating system.
- On Windows systems (when IS_WINDOWS is True), it searches in:
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*
- On Unix-like systems, it searches in:
/usr/local/cuda*
Returns:
Optional[str]: The absolute CUDA installation path if found; otherwise, None.
Note:
The variable IS_WINDOWS is defined in the module scope.
"""
# Guess #1
cuda_home = get_str_env_var("CUDA_HOME") or get_str_env_var("CUDA_PATH")
if cuda_home is None:
# Guess #2
nvcc_path = shutil.which("nvcc")
if nvcc_path is not None:
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
else:
# Guess #3
if IS_WINDOWS:
glob_pat = "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*"
else:
glob_pat = "/usr/local/cuda*"
cuda_homes = glob.glob(glob_pat)
if len(cuda_homes) == 0:
cuda_home = ""
else:
cuda_home = cuda_homes[0]
if not os.path.exists(cuda_home):
cuda_home = None
return cuda_home
def get_cuda_toolkit_path():
"""
    Get the CUDA toolkit path. Returns get_str_env_var('CUDA_TOOLKIT_PATH') if it is
    set. Otherwise, attempts to discover a valid CUDA toolkit location and returns
    it. If not found, returns None.
"""
# Check if the environment variable is already set, if so, return it immediately.
try:
cuda_toolkit_path_existing = get_str_env_var("CUDA_TOOLKIT_PATH")
if cuda_toolkit_path_existing:
return cuda_toolkit_path_existing
found_cuda_home = _find_cuda_home()
if found_cuda_home:
return found_cuda_home
except Exception as e:
        log().info("default_env: exception on get_cuda_toolkit_path: %s", e)
return None
def get_prefix_dsl_libs(prefix: str):
"""
    Returns get_str_env_var('{prefix}_LIBS') if set.
    Otherwise, attempts to discover the libs based on heuristics and returns them.
    If not found, returns None.
"""
# Check if the environment variable is already set, if so, return it immediately.
try:
prefix_libs_existing = get_str_env_var(f"{prefix}_LIBS")
if prefix_libs_existing:
return prefix_libs_existing
def get_libs_cand(start):
target_libs = {
"mlir_c_runner_utils",
"mlir_runner_utils",
"mlir_cuda_runtime",
}
lib_folder_guesses = [
"lib",
]
libs_cand = find_libs_in_ancestors(start, target_libs, lib_folder_guesses)
if libs_cand:
dsl_libs = ":".join(libs_cand)
return dsl_libs
return None
# find from install folder
dsl_libs = get_libs_cand(__file__)
if not dsl_libs:
# try to find from build folder structure
dsl_libs = get_libs_cand(Path(__file__).parent.parent.resolve())
return dsl_libs
except Exception as e:
        log().info("default_env: exception on get_prefix_dsl_libs: %s", e)
return None
class EnvironmentVarManager:
"""Manages environment variables for configuration options.
Printing options:
- [DSL_NAME]_LOG_TO_CONSOLE: Print logging to stderr (default: False)
- [DSL_NAME]_PRINT_AFTER_PREPROCESSOR: Print after preprocess (default: False)
- [DSL_NAME]_PRINT_IR: Print generated IR (default: False)
- [DSL_NAME]_FILTER_STACKTRACE: Filter internal stacktrace (default: True)
File options:
- [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False)
- [DSL_NAME]_LOG_TO_FILE: Store all logging into a file, excluding COMPILE_LOGS (default: False)
Other options:
- [DSL_NAME]_LOG_LEVEL: Logging level to set, for LOG_TO_CONSOLE or LOG_TO_FILE (default: 1).
- [DSL_NAME]_DRYRUN: Generates IR only (default: False)
- [DSL_NAME]_ARCH: GPU architecture (default: "sm_100")
- [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False)
- [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False)
- [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False)
- [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False)
- [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000)
- [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None)
- [DSL_NAME]_NO_SOURCE_LOCATION: Generate source location (default: False)
"""
def __init__(self, prefix="DSL"):
self.prefix = prefix # change if needed
# Printing options
self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False)
self.print_after_preprocessor = get_bool_env_var(
f"{prefix}_PRINT_AFTER_PREPROCESSOR", False
)
self.printIR = get_bool_env_var(f"{prefix}_PRINT_IR", False)
self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True)
# File options
self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False)
self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False)
# Other options
self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1)
self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False)
self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix))
self.warnings_as_errors = get_bool_env_var(
f"{prefix}_WARNINGS_AS_ERRORS", False
)
self.warnings_ignore = get_bool_env_var(f"{prefix}_WARNINGS_IGNORE", False)
self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False)
self.disable_file_caching = get_bool_env_var(
f"{prefix}_DISABLE_FILE_CACHING", False
)
self.file_caching_capacity = get_int_env_var(
f"{prefix}_FILE_CACHING_CAPACITY", 1000
)
self.generate_source_location = not get_bool_env_var(
f"{prefix}_NO_SOURCE_LOCATION", False
)
# set cuda
self.cuda_toolkit = get_cuda_toolkit_path()
# set mlir shared libraries
self.shared_libs = get_prefix_dsl_libs(prefix)
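# A minimal usage sketch (hypothetical prefix; actual values come from the environment):
# with prefix "CUTE_DSL" this reads variables such as CUTE_DSL_ARCH, CUTE_DSL_PRINT_IR,
# and CUTE_DSL_KEEP_IR once and caches the results on the manager instance.
def _example_env_setup():
    env = EnvironmentVarManager(prefix="CUTE_DSL")
    if env.printIR:
        print(f"IR printing enabled; targeting {env.arch}, CUDA at {env.cuda_toolkit}")
    return env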


@@ -0,0 +1,301 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides JIT executor related classes
"""
import io
import inspect
import ctypes
import numpy as np
from typing import get_origin
# Local modules imports
from .utils.timer import timer
from .utils.logger import log
from .common import DSLRuntimeError
from .runtime import cuda as cuda_helpers
from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr
from .typing import get_c_pointers
from . import typing as t
# MLIR modules imports
from .._mlir import ir
class CudaSingleModule:
def __init__(self, cuda_module, kernel_ptr):
self.cuda_module = cuda_module
self.kernel_ptr = kernel_ptr
class CudaModules:
def __init__(self, modules, args):
# list of CudaSingleModule
self.modules = modules
# extra kernel ptr arguments for launch
self.args = args
class JitExecutor:
def __init__(
self,
dsl,
engine,
capi_func,
ir_module,
args_spec,
function_name,
cuda_modules: CudaModules = None,
jit_time_profiling=False,
):
self.dsl = dsl
self.engine = engine
self.capi_func = capi_func
self.ir_module = ir_module
self.args_spec = args_spec
self.function_name = function_name
if args_spec is not None:
self.args_spec = self.filter_runtime_arg_spec(args_spec)
# cuda kernels
self.cuda_modules = cuda_modules
self.jit_time_profiling = jit_time_profiling
def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec):
runtime_args = []
runtime_annotations = {}
runtime_defaults = []
# Calculate the offset where defaults start in the original args
if arg_spec.defaults:
defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults)
else:
defaults_start_idx = len(arg_spec.args)
# Filter arguments and maintain their properties
for i, arg_name in enumerate(arg_spec.args):
arg_type = arg_spec.annotations.get(arg_name, None)
# Skip compile-time arguments
if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
continue
# Keep runtime arguments
runtime_args.append(arg_name)
if arg_name in arg_spec.annotations:
runtime_annotations[arg_name] = arg_type
# Keep corresponding default if it exists
if i >= defaults_start_idx:
default_idx = i - defaults_start_idx
runtime_defaults.append(arg_spec.defaults[default_idx])
# Filter kwonlyargs and their defaults
runtime_kwonlyargs = []
runtime_kwonlydefaults = {}
if arg_spec.kwonlyargs:
for kwarg in arg_spec.kwonlyargs:
arg_type = arg_spec.annotations.get(kwarg, None)
# Apply same filtering logic
if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name):
continue
runtime_kwonlyargs.append(kwarg)
if kwarg in arg_spec.annotations:
runtime_annotations[kwarg] = arg_type
if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults:
runtime_kwonlydefaults[kwarg] = arg_spec.kwonlydefaults[kwarg]
# Convert runtime_defaults to tuple if not empty (as expected by FullArgSpec)
runtime_defaults = tuple(runtime_defaults) if runtime_defaults else None
return inspect.FullArgSpec(
args=runtime_args,
varargs=arg_spec.varargs, # Keep original varargs
varkw=arg_spec.varkw, # Keep original varkw
defaults=runtime_defaults,
kwonlyargs=runtime_kwonlyargs,
kwonlydefaults=runtime_kwonlydefaults if runtime_kwonlydefaults else None,
annotations=runtime_annotations,
)
def __del__(self):
if self.cuda_modules:
cuda_modules = [module.cuda_module for module in self.cuda_modules.modules]
for module in set(cuda_modules):
cuda_helpers.unload_cubin_module(module)
def generate_execution_args(self, args, kwargs, args_spec: inspect.FullArgSpec):
"""
        This is a pruned version of `generate_mlir_function_types` that only generates the
        execution args, so it does not require an MLIR context.
"""
# args/kwargs must match arg_specs
# No canonicalization of args/kwargs to avoid extra latency
if len(args) != len(args_spec.args) or len(kwargs) != len(args_spec.kwonlyargs):
raise DSLRuntimeError(
"input args/kwargs length does not match runtime function signature!",
context={
"input args length": len(args),
"input kwargs length": len(kwargs),
"function signature args length": len(args_spec.args),
"function signature kwonlyargs length": len(args_spec.kwonlyargs),
},
)
exe_args = []
input_args = [*args, *kwargs.values()]
input_arg_names = [*args_spec.args, *args_spec.kwonlyargs]
for i, arg in enumerate(input_args):
arg_type = args_spec.annotations.get(input_arg_names[i], None)
# Implicit cast to NumericMeta
if isinstance(arg_type, t.NumericMeta):
arg = t.cast(arg, arg_type)
# If not any known type, try registered adapter to do the conversion
adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
adapted_arg = adapter(arg) if adapter else arg
exe_args.extend(get_c_pointers(adapted_arg))
return exe_args
def __call__(self, *args, **kwargs):
exe_args = self.generate_execution_args(args, kwargs, self.args_spec)
self.run_compiled_program(exe_args)
    # Assume each execution arg has type `c_void_p` to reduce the overhead of `ctypes.cast`.
def get_invoke_packed_args(self, exe_args):
if self.cuda_modules:
exe_args += self.cuda_modules.args
packed_args = (ctypes.c_void_p * len(exe_args))()
for argNum in range(len(exe_args)):
packed_args[argNum] = exe_args[argNum]
return packed_args
def run_compiled_program(self, exe_args):
if self.jit_time_profiling:
profiler = timer(enable=True)
try:
packed_args = profiler(self.get_invoke_packed_args)(exe_args)
profiler(self.capi_func)(packed_args)
except Exception as e:
raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e)
else:
try:
packed_args = self.get_invoke_packed_args(exe_args)
self.capi_func(packed_args)
except Exception as e:
raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e)
def update_jit_cuda_modules(self, kernel_symbols):
        # Preload CUDA modules from the compiled cubins embedded in the IR and store them on this executor.
if len(kernel_symbols) > 0:
extra_args = []
module = self.ir_module
cuda_kernel_cache = dict()
cuda_driver_version = cuda_helpers.get_driver_version()
for sym in kernel_symbols:
if sym not in cuda_kernel_cache:
log().debug(f"Loading CUDA module for symbol: {sym}")
# load cuda module/get function pointer from module and cache
def walk_callback(sym, func_sym, cubin_data):
cubin_module = cuda_helpers.load_cubin_module_data(cubin_data)
kernel_ptr = cuda_helpers.get_kernel_function(
cubin_module, func_sym
)
# Enable non-portable cluster size for CUDA version 11.8 or higher.
if cuda_driver_version >= 11080:
cuda_helpers.set_kernel_attribute(
kernel_ptr,
cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED,
1,
)
cuda_kernel_cache[sym] = CudaSingleModule(
cubin_module, kernel_ptr
)
self.walk_module_and_get_cubin_data(module, sym, walk_callback)
else:
log().debug(f"Symbol {sym} already in cache")
                # Skip if no CUDA module was found for this symbol.
if sym in cuda_kernel_cache:
extra_args.append(
ctypes.c_void_p(cuda_kernel_cache[sym].kernel_ptr.getPtr())
)
# store to the jit result if jit result is cached.
self.cuda_modules = CudaModules(cuda_kernel_cache.values(), extra_args)
return self
    def _get_escaped_cubin_bytes(self, cubin_data):
        """Unescape cubin data extracted from MLIR raw bytecode into executable binary bytes."""
def ishex(inp):
return (
inp in range(0x30, 0x3A)
or inp in range(0x61, 0x67)
or inp in range(0x41, 0x47)
)
converted = bytearray()
idx = 0
while idx < len(cubin_data):
# escape the original bytes
if cubin_data[idx] == 0x5C:
# if data of idx is b'\\'
if ishex(cubin_data[idx + 1]) and ishex(cubin_data[idx + 2]):
converted += bytearray.fromhex(
cubin_data[idx + 1 : idx + 3].decode()
)
idx += 3
elif cubin_data[idx + 1] == 0x5C:
converted.append(cubin_data[idx])
idx += 2
else:
# no escape, directly write
converted.append(cubin_data[idx])
idx += 1
return bytes(converted)
def walk_module_and_get_cubin_data(self, module, sym, callback):
"""This function is used to walk gpu binary op, extract the cubin inside, and process cubin data with callback."""
def walk_gpu_binary_op(op):
if op.name != "gpu.binary":
return ir.WalkResult.ADVANCE
s = io.BytesIO()
op.write_bytecode(s)
cubin_data = s.getvalue()
if sym.encode() not in cubin_data:
return ir.WalkResult.ADVANCE
if (
"kernels" != op.opview.sym_name.value
and sym != op.opview.sym_name.value
):
return ir.WalkResult.ADVANCE
# function symbol of kernel(gpu.launch_func) is equal to sym name in mlir
func_sym = sym
if sym == op.opview.sym_name.value and not sym.endswith("_kernel"):
func_sym = sym.rsplit("_", 1)[0]
cubin_data = cubin_data.split(b'bin = "')[1].split(b'">')[0]
cubin_data = self._get_escaped_cubin_bytes(cubin_data)
callback(sym, func_sym, cubin_data)
return ir.WalkResult.ADVANCE
module.operation.walk(walk_gpu_binary_op)


@@ -0,0 +1,29 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides runtime utility functions that are needed for
the DSL.
"""
from . import device_tensor
from . import dlpack_types
from . import cuda
from . import tensor_descriptor
from . import jit_arg_adapters
__all__ = [
"device_tensor",
"dlpack_types",
"cuda",
"tensor_descriptor",
"jit_arg_adapters",
]


@@ -0,0 +1,470 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides CUDA Python helper functions
"""
from functools import lru_cache
from dataclasses import dataclass
from typing import List, Optional
import numpy as np
import os
import ctypes
import cuda.bindings.driver as cuda
import cuda.bindings.nvrtc as nvrtc
# MLIR imports
from ..._mlir import ir
from ..._mlir.dialects import gpu
# Local module imports
from ..utils.logger import log as _log
from ..common import *
from .jit_arg_adapters import JitArgAdapterRegistry
# =============================================================================
# Utils
# =============================================================================
def _cudaGetErrorEnum(error):
if isinstance(error, cuda.CUresult):
err, name = cuda.cuGetErrorName(error)
return name if err == cuda.CUresult.CUDA_SUCCESS else "<unknown>"
elif isinstance(error, nvrtc.nvrtcResult):
return nvrtc.nvrtcGetErrorString(error)[1]
else:
raise DSLRuntimeError("Unknown error type: {}".format(error))
def _get_gpu_arch_info(major, minor):
"""Get GPU architecture information and compatibility details."""
gpu_arch_map = {
(7, 0): ("Volta", "sm_70", ["sm_70"]), # V100
(7, 5): ("Turing", "sm_75", ["sm_75"]), # RTX 20 Series, Quadro RTX
(8, 0): ("Ampere", "sm_80", ["sm_80"]), # A100
(8, 6): ("Ampere", "sm_86", ["sm_86", "sm_80"]), # RTX 30 Series
(8, 9): ("Ada", "sm_89", ["sm_89", "sm_86"]), # RTX 40 Series
(8, 7): ("Ampere", "sm_87", ["sm_87", "sm_86", "sm_80"]), # A10, A40
(9, 0): ("Hopper", "sm_90a", ["sm_90a"]), # H100
(10, 0): ("Blackwell", "sm_100a", ["sm_100a"]), # B200
}
return gpu_arch_map.get(
(major, minor), ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"])
)
def get_compute_capability_major_minor(device_id: int = 0):
"""
Returns the compute capability of the CUDA device as a tuple of (major, minor).
For example: (8, 0) for Ampere, (9, 0) for Hopper, (10, 0) for Blackwell.
    Returns (None, None) on failure.
"""
try:
checkCudaErrors(cuda.cuInit(0))
device = checkCudaErrors(cuda.cuDeviceGet(device_id))
major = checkCudaErrors(
cuda.cuDeviceGetAttribute(
cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
device,
)
)
minor = checkCudaErrors(
cuda.cuDeviceGetAttribute(
cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
device,
)
)
return major, minor
except RuntimeError as e:
_log().info(f"Failed to get CUDA compute capability: {e}")
return None, None
@dataclass
class DeviceInfo:
"""Data class to store CUDA device information."""
device_count: int = 0
current_device: int = 0
device_name: Optional[str] = None
major_version: Optional[int] = None
minor_version: Optional[int] = None
arch_name: Optional[str] = None
sm_arch: Optional[str] = None
compatible_archs: Optional[List[str]] = None
memory_gb: Optional[float] = None
target_arch: Optional[str] = None
error_message: Optional[str] = None
initialization_failed: bool = False
def pretty_str(self) -> str:
"""
Convert DeviceInfo to a formatted string for display.
"""
info = ""
if self.initialization_failed:
return f"{Colors.BOLD}- CUDA initialization failed{Colors.RESET}"
if self.error_message:
return f"{Colors.BOLD}- Failed to get GPU info: {self.error_message}{Colors.RESET}"
if self.device_count > 0:
info += f"{Colors.BOLD}- CUDA devices available: {self.device_count} (current: {self.current_device})\n"
if self.major_version is not None and self.minor_version is not None:
info += f"- Architecture: {Colors.BLUE}{self.arch_name}{Colors.RESET} ({Colors.GREEN}{self.sm_arch}{Colors.RESET})\n"
info += f"- Compatible SM archs: {Colors.GREEN}{', '.join(self.compatible_archs or [])}{Colors.RESET}\n"
if self.memory_gb is not None:
info += f"- Total Memory: {Colors.BLUE}{self.memory_gb:.2f} GB{Colors.RESET}\n"
else:
info += f"- Compute capability: unknown\n"
info += f"- SM arch: unknown{Colors.RESET}\n"
else:
info += f"- No devices available\n"
return info
def get_device_info() -> DeviceInfo:
"""
Get detailed information about CUDA devices.
Returns a DeviceInfo dataclass with device information.
"""
device_info = DeviceInfo()
# Initialize CUDA if not already initialized
try:
result = cuda.cuInit(0)
if result[0].value: # Check for error
device_info.initialization_failed = True
return device_info
except:
pass
try:
# Get device count
result = cuda.cuDeviceGetCount()
device_info.device_count = result[1] if result[0].value == 0 else 0
if device_info.device_count > 0:
# Get current device
try:
result = cuda.cuCtxGetDevice()
if result[0].value == 0:
device_info.current_device = result[1]
except:
pass
# Get device name
try:
name_result = cuda.cuDeviceGetName(100, device_info.current_device)
if name_result[0].value == 0:
device_info.device_name = name_result[1]
except:
pass
# Get compute capability and architecture info
try:
major, minor = get_compute_capability_major_minor(
device_info.current_device
)
# Check if we successfully got the compute capability
if major is not None and minor is not None:
device_info.major_version = major
device_info.minor_version = minor
arch_name, sm_arch, compatible_archs = _get_gpu_arch_info(
device_info.major_version, device_info.minor_version
)
device_info.arch_name = arch_name
device_info.sm_arch = sm_arch
device_info.compatible_archs = compatible_archs
# Get memory info
try:
total_mem = cuda.cuDeviceGetAttribute(
cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_TOTAL_MEMORY,
device_info.current_device,
)
if total_mem[0].value == 0:
device_info.memory_gb = total_mem[1] / (
1024 * 1024 * 1024
) # Convert to GB
except:
pass
except Exception as e:
pass # Compute capability info will remain None
except Exception as e:
device_info.error_message = str(e)
return device_info
def checkCudaErrors(result):
"""Check CUDA errors and provide detailed error messages."""
if result[0].value:
error_code = result[0].value
error_name = _cudaGetErrorEnum(result[0])
raise DSLCudaRuntimeError(error_code, error_name)
if len(result) == 1:
return None
elif len(result) == 2:
return result[1]
else:
return result[1:]
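# A minimal usage sketch: checkCudaErrors unwraps the (CUresult, value...) tuples returned
# by the cuda-python bindings and raises DSLCudaRuntimeError on failure.
def _example_query_device(device_id: int = 0):
    checkCudaErrors(cuda.cuInit(0))
    device = checkCudaErrors(cuda.cuDeviceGet(device_id))
    return checkCudaErrors(cuda.cuDeviceGetName(100, device))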
# =============================================================================
# Driver Helpers
# =============================================================================
@lru_cache(maxsize=1)
def initialize_cuda_context(device_id: int = 0, flags: int = 0):
"""
Initializes the CUDA context for a specified device.
"""
# Initialize CUDA Driver API
_log().info(f"cuInit {flags}")
checkCudaErrors(cuda.cuInit(flags))
# Retrieve handle for device
_log().info(f"cuDeviceGet {device_id}")
cuDevice = checkCudaErrors(cuda.cuDeviceGet(device_id))
_log().info(f"{cuDevice} <-- cuDeviceGet")
# Create context
_log().info(f"cuCtxCreate {0} {cuDevice}")
context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice))
_log().info(f"{context} <-- cuCtxCreate")
return context
def load_cubin_module(cubin_file):
"""
Loads a CUBIN file and returns the module.
"""
# Load CUBIN file as binary data
_log().info(f"read cubin {cubin_file}")
with open(cubin_file, "rb") as f:
cubin_data = f.read()
# Load module data
_log().info(f"cuModuleLoadData {np.char.array(cubin_data).ctypes.data}")
module = checkCudaErrors(
cuda.cuModuleLoadData(np.char.array(cubin_data).ctypes.data)
)
return module
def unload_cubin_module(module):
"""
Unloads a CUBIN module.
"""
_log().info(f"cuModuleUnload {module}")
checkCudaErrors(cuda.cuModuleUnload(module))
def load_cubin_module_data(cubin_data):
"""
Loads a CUBIN from data and returns the module.
"""
# Load module data
_log().info(f"cuModuleLoadData {np.char.array(cubin_data).ctypes.data}")
module = checkCudaErrors(
cuda.cuModuleLoadData(np.char.array(cubin_data).ctypes.data)
)
return module
def get_kernel_function(module, kernel_name):
"""
Retrieves the kernel function from the module.
"""
_log().info(f"cuModuleGetFunction {module} {kernel_name}")
kernel = checkCudaErrors(
cuda.cuModuleGetFunction(module, bytes(kernel_name, "utf-8"))
)
_log().info(f"{kernel} <-- cuModuleGetFunction")
return kernel
def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size=0, kernel_args=None):
"""
Launches the CUDA kernel.
"""
_log().info(
f"cuLaunchKernel {kernel} grid={grid_dims} blocks={block_dims} smem_size={smem_size} stream={stream} {kernel_args}"
)
checkCudaErrors(
cuda.cuLaunchKernel(
kernel,
grid_dims[0],
grid_dims[1],
grid_dims[2],
block_dims[0],
block_dims[1],
block_dims[2],
smem_size, # Shared memory size
stream,
kernel_args,
0, # Extra parameters
)
)
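# An end-to-end sketch (hypothetical file and kernel names; `kernel_args` must be prepared
# in the form cuLaunchKernel expects, which is omitted here): load a cubin, fetch the kernel,
# launch it on a fresh stream, then release all resources.
def _example_launch(cubin_path="kernel.cubin", kernel_name="my_kernel", kernel_args=None):
    initialize_cuda_context()
    module = load_cubin_module(cubin_path)
    kernel = get_kernel_function(module, kernel_name)
    stream = stream_create()
    launch_kernel(kernel, (1, 1, 1), (128, 1, 1), stream, smem_size=0, kernel_args=kernel_args)
    stream_sync(stream)
    stream_destroy(stream)
    unload_cubin_module(module)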
def stream_sync(stream):
"""
Synchronizes the CUDA stream.
"""
_log().info(f"cuStreamSynchronize {stream}")
checkCudaErrors(cuda.cuStreamSynchronize(stream))
def stream_create(id=0):
"""
Creates the CUDA stream.
"""
_log().info(f"cuStreamCreate {id}")
stream = checkCudaErrors(cuda.cuStreamCreate(id))
_log().info(f"{stream} <-- cuStreamCreate")
return stream
def stream_destroy(stream):
"""
Destroys the CUDA stream.
"""
_log().info(f"cuStreamDestroy {stream}")
checkCudaErrors(cuda.cuStreamDestroy(stream))
def context_destroy(context):
"""
Destroys the CUDA context.
"""
_log().info(f"cuCtxDestroy {context}")
checkCudaErrors(cuda.cuCtxDestroy(context))
def allocate(size_in_bytes: int, stream=None):
"""
Allocate device memory based on numpy host array size.
"""
_log().info("Allocate size_in_bytes=[%s] stream=[%s]", size_in_bytes, stream)
if stream is None:
device_memory = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes))
else:
device_memory = checkCudaErrors(cuda.cuMemAllocAsync(size_in_bytes, stream))
_log().info("Allocated [%s]", device_memory)
return device_memory
def deallocate(device_pointer, stream=None):
"""
Deallocate the specified device memory pointer.
"""
_log().info(
"Deallocate device_pointer=[%s] stream=[%s]", hex(int(device_pointer)), stream
)
if stream is None:
checkCudaErrors(cuda.cuMemFree(device_pointer))
else:
checkCudaErrors(cuda.cuMemFreeAsync(device_pointer, stream))
def memcpy_h2d(host_pointer, device_pointer, size_in_bytes, stream=None):
"""
Copy data from host to device memory.
"""
_log().info(
"Copy host-to-device host_pointer[%s] device_ptr=[%s] size_in_bytes=[%s] stream=[%s]",
hex(host_pointer),
hex(int(device_pointer)),
size_in_bytes,
stream,
)
if stream is None:
checkCudaErrors(cuda.cuMemcpyHtoD(device_pointer, host_pointer, size_in_bytes))
else:
checkCudaErrors(
cuda.cuMemcpyHtoDAsync(device_pointer, host_pointer, size_in_bytes, stream)
)
def memcpy_d2h(host_pointer, device_pointer, size_in_bytes, stream=None):
"""
Copy data from device to host memory.
"""
    _log().info(
        "Copy device-to-host device_pointer=[%s] host_pointer=[%s] size_in_bytes=[%s] stream=[%s]",
hex(int(device_pointer)),
hex(host_pointer),
size_in_bytes,
stream,
)
if stream is None:
checkCudaErrors(cuda.cuMemcpyDtoH(host_pointer, device_pointer, size_in_bytes))
else:
checkCudaErrors(
cuda.cuMemcpyDtoHAsync(host_pointer, device_pointer, size_in_bytes, stream)
)
def default_stream():
return cuda.CUstream(0)
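# A minimal sketch (assumes a contiguous numpy host array): round-tripping a host buffer
# through device memory with the allocation and memcpy helpers above.
def _example_memcpy_roundtrip(host_array: np.ndarray):
    nbytes = host_array.nbytes
    dptr = allocate(nbytes)
    memcpy_h2d(host_array.ctypes.data, dptr, nbytes)
    out = np.empty_like(host_array)
    memcpy_d2h(out.ctypes.data, dptr, nbytes)
    deallocate(dptr)
    return out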
def get_driver_version():
"""
Returns the CUDA driver version.
"""
return checkCudaErrors(cuda.cuDriverGetVersion())
def set_kernel_attribute(kernel, attribute, value):
"""
Sets a CUDA kernel attribute.
"""
return checkCudaErrors(cuda.cuFuncSetAttribute(kernel, attribute, value))
@JitArgAdapterRegistry.register_jit_arg_adapter(cuda.CUstream)
class StreamAdapter:
"""
Convert a CUDA stream to a stream representation for JIT arg generation.
"""
def __init__(self, arg):
self._arg = arg
self._c_pointer = ctypes.cast(self._arg.getPtr(), ctypes.c_void_p)
def __new_from_mlir_values__(self, values):
assert len(values) == 1
return values[0]
def __c_pointers__(self):
return [self._c_pointer]
def __get_mlir_types__(self):
return [gpu.AsyncTokenType.get()]


@@ -0,0 +1,121 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import copy
from . import cuda as cuda_helpers
from .tensor_descriptor import *
from ..common import *
def allocate(tensor: TensorDescriptor, stream=None):
"""
Allocates GPU memory
"""
if tensor._check_is_managed_by_framework():
raise DSLRuntimeError(
"GPU tensors are managed by the framework and cannot be modified."
)
    if tensor.device_pointer is not None:
raise DSLRuntimeError("Tensor is already allocated on the device.")
tensor.device_pointer = cuda_helpers.allocate(tensor.size_in_bytes, stream)
log().info("Allocate done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
def deallocate(tensor: TensorDescriptor, stream=None):
"""
Deallocates GPU memory
"""
if tensor._check_is_managed_by_framework():
raise DSLRuntimeError(
"GPU tensors are managed by the framework and cannot be modified."
)
if tensor.device_pointer is None:
raise DSLRuntimeError("Tensor is not allocated on the device.")
log().info(
"Deallocating done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer
)
cuda_helpers.deallocate(tensor.device_pointer, stream)
tensor.device_pointer = None
def copy_to_gpu(tensor: TensorDescriptor, do_allocate=True, stream=None):
"""
Copies data from host memory to the GPU memory.
If do_allocate is True, it first calls allocate
"""
log().info("copyin tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
if do_allocate:
allocate(tensor, stream)
cuda_helpers.memcpy_h2d(
tensor.data_ptr, tensor.device_pointer, tensor.size_in_bytes, stream
)
log().info("copyin done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
return tensor
def copy_from_gpu(tensor: TensorDescriptor, do_deallocate=True, stream=None):
"""
Copies data from GPU memory back to the host.
If do_deallocate is True, it calls deallocate
"""
log().info("copyout tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
if tensor._check_is_managed_by_framework():
raise DSLRuntimeError(
"GPU tensors are managed by the framework and cannot be modified."
)
if tensor.device_pointer is None:
raise DSLRuntimeError("Tensor is not allocated on the device.")
cuda_helpers.memcpy_d2h(
tensor.data_ptr, tensor.device_pointer, tensor.size_in_bytes, stream
)
if do_deallocate:
deallocate(tensor, stream)
log().info("copyout done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
def to_gpu(tensor, stream=None) -> TensorDescriptor:
"""
    Copies the tensor from host memory to GPU memory
"""
if isinstance(tensor, TensorDescriptor):
new_tensor = copy.copy(tensor)
copy_to_gpu(new_tensor, stream=stream)
return new_tensor
if TensorDescriptor.can_transformed_to_dlpack(tensor):
new_tensor = TensorDescriptor(tensor)
copy_to_gpu(new_tensor, stream=stream)
return new_tensor
raise DSLRuntimeError("Unsupported type")
def from_gpu(tensor, stream=None) -> TensorDescriptor:
"""
    Copies the tensor from GPU memory back to host memory
"""
if isinstance(tensor, TensorDescriptor):
new_tensor = copy.copy(tensor)
copy_from_gpu(new_tensor, stream=stream)
return new_tensor
if TensorDescriptor.can_transformed_to_dlpack(tensor):
new_tensor = TensorDescriptor(tensor)
copy_from_gpu(new_tensor, stream=stream)
return new_tensor
raise DSLRuntimeError("Unsupported type")
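A host-side sketch of the typical round trip (assumptions only for illustration: a CUDA context is current and the installed numpy implements the DLPack protocol; shapes are arbitrary):
import numpy as np
host = np.arange(8, dtype=np.float32)
desc = to_gpu(host)            # wraps the array, allocates device memory, copies H2D
# ... launch work that consumes desc.device_pointer ...
copy_from_gpu(desc)            # copies D2H back into the numpy buffer and frees the allocation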

View File

@@ -0,0 +1,76 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides helper structs for dlpack.
DLPack is an open standard for in-memory tensor structures, enabling
seamless sharing of tensors across different frameworks.
Learn more at: https://github.com/dmlc/dlpack
"""
import ctypes
import enum
class DLDeviceType(enum.IntEnum):
"""Enums for device types based on the DLPack specification."""
kDLCPU = 1
kDLGPU = 2
kDLCPUPinned = 3
class DLDataTypeCode:
"""Enums for data type codes based on the DLPack specification.
see https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h
"""
kDLInt = 0
kDLUInt = 1
kDLFloat = 2
kDLOpaqueHandle = 3
kDLBfloat = 4
kDLComplex = 5
kDLBool = 6
class DLDevice(ctypes.Structure):
"""Structure representing the device information in DLPack."""
_fields_ = [
("device_type", ctypes.c_int), # kDLCPU, kDLGPU, etc.
("device_id", ctypes.c_int), # Device ID (e.g., GPU ID)
]
class DLDataType(ctypes.Structure):
"""Structure representing the data type in DLPack."""
_fields_ = [
("code", ctypes.c_uint8), # Data type code (e.g., kDLFloat)
("bits", ctypes.c_uint8), # Number of bits per value
("lanes", ctypes.c_uint16), # Number of lanes
]
class DLTensor(ctypes.Structure):
"""Structure representing the DLTensor in DLPack."""
_fields_ = [
("data", ctypes.c_void_p), # Pointer to tensor data
("device", DLDevice), # Device info
("ndim", ctypes.c_int), # Number of dimensions
("dtype", DLDataType), # Data type
("shape", ctypes.POINTER(ctypes.c_int64)), # Shape of tensor
("strides", ctypes.POINTER(ctypes.c_int64)), # Strides of tensor
("byte_offset", ctypes.c_uint64), # Byte offset to tensor data
]
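For illustration, a DLTensor describing a small row-major numpy array can be filled in directly with these ctypes structures (a sketch only; the shape array must be kept alive for as long as the DLTensor is used):
import numpy as np
arr = np.zeros((2, 3), dtype=np.float32)
shape = (ctypes.c_int64 * arr.ndim)(*arr.shape)   # keep a reference to this array alive
dl = DLTensor(
    data=ctypes.c_void_p(arr.ctypes.data),
    device=DLDevice(device_type=int(DLDeviceType.kDLCPU), device_id=0),
    ndim=arr.ndim,
    dtype=DLDataType(code=DLDataTypeCode.kDLFloat, bits=32, lanes=1),
    shape=ctypes.cast(shape, ctypes.POINTER(ctypes.c_int64)),
    strides=None,                                 # NULL strides mean a compact row-major layout
    byte_offset=0,
)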

View File

@@ -0,0 +1,188 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides runtime utilities for JIT argument conversion in DSL.
"""
from functools import wraps
from typing import get_origin
# Local modules imports
from ..common import DSLRuntimeError
from ..typing import (
Constexpr,
Int32,
Float32,
Boolean,
)
def is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func):
"""
Check if the argument spec is a constexpr.
"""
def _is_reserved_python_func_arg(arg_index, arg_name, func):
"""
Check if the argument is a reserved python function argument.
"""
if arg_index != 0:
return False
if arg_name == "self":
return True
is_classmethod = isinstance(func, classmethod) or (
hasattr(func, "__func__") and isinstance(func.__func__, classmethod)
)
return arg_name == "cls" and is_classmethod
return (
_is_reserved_python_func_arg(arg_index, arg_name, owning_func)
or (isinstance(arg_spec, type) and issubclass(arg_spec, Constexpr))
or (get_origin(arg_spec) is Constexpr)
)
def is_argument_constexpr(arg, arg_spec, arg_name, arg_index, owning_func):
"""
Check if the argument is a constexpr.
"""
def _is_type_argument(arg, arg_annotation):
"""
Check if the argument is a type argument like Type[X]
"""
return isinstance(arg, type) and (
arg_annotation is None or get_origin(arg_annotation) is type
)
return (
is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func)
or _is_type_argument(arg, arg_spec)
or arg is None
)
class JitArgAdapterRegistry:
"""
A registry to keep track of the JIT argument adapters.
    An adapter is a callable that converts a Python type to a type supporting the following protocols:
- JitArgument
- DynamicExpression
The converted type can then be further processed by DSL to generate arguments for JIT functions.
"""
# A dictionary with key=type and value=callable
jit_arg_adapter_registry = {}
@classmethod
def register_jit_arg_adapter(cls, *dargs, **dkwargs):
"""
Register a JIT argument adapter callable
This can be used as a decorator on any callable like:
@register_jit_arg_adapter(my_py_type)
def my_adapter_for_my_py_type(arg):
...
@register_jit_arg_adapter(my_py_type)
class MyAdapterForMyPythonType:
...
        The adapters are registered per type. If a type is already registered, an error will be raised.
"""
def decorator(*dargs, **dkwargs):
darg_python_ty = dargs[0]
@wraps(darg_python_ty)
def wrapper(*args, **kwargs):
if len(args) != 1 or not callable(args[0]):
raise DSLRuntimeError(
"a callable must be provided for registering JIT argument adapter"
)
adapter = args[0]
if darg_python_ty in cls.jit_arg_adapter_registry:
raise DSLRuntimeError(
f"JIT argument adapter for {darg_python_ty} is already registered!",
context={
"Registered adapter": cls.jit_arg_adapter_registry[
darg_python_ty
],
"Adapter to be registered": adapter,
},
)
cls.jit_arg_adapter_registry[darg_python_ty] = adapter
return adapter
return wrapper
if len(dargs) > 0:
return decorator(*dargs, **dkwargs)
else:
raise DSLRuntimeError(
"a Python type must be provided for registering JIT argument adapter"
)
@classmethod
def get_registered_adapter(cls, ty):
"""
Get the registered JIT argument adapter for the given type.
"""
return cls.jit_arg_adapter_registry.get(ty, None)
# =============================================================================
# JIT Argument Adapters
# =============================================================================
@JitArgAdapterRegistry.register_jit_arg_adapter(int)
@JitArgAdapterRegistry.register_jit_arg_adapter(float)
@JitArgAdapterRegistry.register_jit_arg_adapter(bool)
def _convert_python_scalar(arg):
"""
Convert a Python scalar to a DSL type.
"""
conversion_map = {
int: Int32,
float: Float32,
bool: Boolean,
}
return conversion_map.get(type(arg))(arg)
@JitArgAdapterRegistry.register_jit_arg_adapter(tuple)
@JitArgAdapterRegistry.register_jit_arg_adapter(list)
def _convert_python_sequence(arg):
"""
Go through each element in the sequence and convert it to a type that can be
further processed by DSL to generate the corresponding JIT argument(s).
"""
adapted_arg = []
for elem in arg:
adapter = JitArgAdapterRegistry.get_registered_adapter(type(elem))
if adapter is not None:
converted_elem = adapter(elem)
adapted_arg.append(converted_elem)
else:
# If no registered adapter is found, just return the original element
adapted_arg.append(elem)
assert len(adapted_arg) == len(arg)
return type(arg)(adapted_arg)
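As a hypothetical example of extending the registry, a small user-defined wrapper type can be adapted to an existing DSL numeric type (MyScalar is illustrative and not part of the DSL):
class MyScalar:
    def __init__(self, value):
        self.value = value
@JitArgAdapterRegistry.register_jit_arg_adapter(MyScalar)
def _convert_my_scalar(arg):
    # Reuse the DSL's Int32 so a MyScalar argument behaves like a plain int.
    return Int32(arg.value)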

View File

@@ -0,0 +1,201 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
# Helpers
import itertools, operator
import ctypes
from . import dlpack_types as _dpack
from .dlpack_runtime import (
dlpack_to_tensor_desc,
get_tensor_desc_data_ptr,
get_tensor_desc_is_in_device,
get_tensor_desc_element_type,
get_tensor_desc_shape,
get_tensor_desc_stride,
get_tensor_desc_element_size_in_bytes,
get_tensor_desc_ndim,
get_tensor_desc_dtype_code,
get_tensor_desc_dtype_bits,
get_tensor_desc_device_type,
get_tensor_desc_device_id,
)
from ..utils.logger import log
from ..common import *
from ..typing import (
Boolean,
Float8E5M2,
Int64,
Int32,
Int16,
Int8,
Uint64,
Uint32,
Uint16,
Uint8,
Float64,
Float32,
Float16,
BFloat16,
)
class TensorDescriptor:
def __init__(self, tensor):
"""Initialize with a tensor that supports the DLPack protocol.
Args:
tensor: Any tensor object that implements __dlpack__ and __dlpack_device__
"""
self.tensor = tensor
self._capsule = dlpack_to_tensor_desc(tensor)
self.data_ptr = get_tensor_desc_data_ptr(self._capsule)
self.device_type = get_tensor_desc_device_type(self._capsule)
self.device_type = _dpack.DLDeviceType(self.device_type)
if self.device_type == _dpack.DLDeviceType.kDLGPU:
self.device_pointer = self.data_ptr
elif self.device_type == _dpack.DLDeviceType.kDLCPU:
self.device_pointer = None
else:
raise DSLRuntimeError(
f"DLPack device type is not supported {self.dl_tensor.device.device_type}"
)
log().info("TensorDescriptor is created = [%s]", self)
@staticmethod
def can_transformed_to_dlpack(dl_tensor):
if not hasattr(dl_tensor, "__dlpack__") or not hasattr(
dl_tensor, "__dlpack_device__"
):
return False
return True
@property
def is_in_device(self):
"""Check if the tensor is stored on a device."""
        return self.device_pointer is not None
@property
def device_id(self):
"""Return device id where tensor resides."""
if self.is_in_device:
return get_tensor_desc_device_id(self._capsule)
return -1
@property
def element_type(self):
"""Return the corresponding Python type based on DLPack dtype metadata."""
str_element_type = get_tensor_desc_element_type(self._capsule)
dtype_map = {
# bool is 8bit from numpy and torch
"Bool": Boolean,
"Int64": Int64,
"Int32": Int32,
"Int16": Int16,
"Int8": Int8,
"UInt64": Uint64,
"UInt32": Uint32,
"UInt16": Uint16,
"UInt8": Uint8,
"Float64": Float64,
"Float32": Float32,
"Float16": Float16,
"BFloat16": BFloat16,
"Float8E5M2": Float8E5M2,
}
if str_element_type not in dtype_map:
raise KeyError(
f"Unsupported element type in dlpack: '{str_element_type}'. Supported types are: {list(dtype_map.keys())}"
)
return dtype_map[str_element_type]
@property
def shape(self):
"""Return the shape of the tensor."""
return get_tensor_desc_shape(self._capsule)
@property
def rank(self):
"""Return the rank of the tensor."""
return get_tensor_desc_ndim(self._capsule)
@property
def strides(self):
"""Return the rank of the tensor."""
return get_tensor_desc_stride(self._capsule)
@property
def element_size_in_bytes(self):
"""Calculate the element size in bytes of the DLPack tensor."""
return get_tensor_desc_element_size_in_bytes(self._capsule)
@property
def size_in_bytes(self):
"""Calculate the total size in bytes of the DLPack tensor."""
# Calculate the number of elements using the shape
ndim = get_tensor_desc_ndim(self._capsule)
shape = get_tensor_desc_shape(self._capsule)
num_elements = 1
for i in range(ndim):
num_elements *= shape[i]
# Total bytes
total_bytes = self.element_size_in_bytes * num_elements
return total_bytes
def __str__(self):
"""Return a compact string representation of the device_tensor with a tensor prefix."""
# Extract shape
shape = "x".join(map(str, self.shape))
# Extract dtype
dtype_code = get_tensor_desc_dtype_code(self._capsule)
dtype_bits = get_tensor_desc_dtype_bits(self._capsule)
dtype = (
f"i{dtype_bits}"
if dtype_code == _dpack.DLDataTypeCode.kDLInt
else f"f{dtype_bits}"
)
# Extract device
device_type = "cpu" if not self.is_in_device else "gpu"
return f"tensor<{shape}x{dtype}>_{device_type}"
def _check_is_managed_by_framework(self):
"""
        Return True if the tensor is managed by the framework (i.e., resides in GPU memory).
"""
return self.device_type == _dpack.DLDeviceType.kDLGPU
def from_tensor(tensor) -> TensorDescriptor:
"""Create a TensorDescriptor from a tensor object."""
return TensorDescriptor(tensor)
def to_tensor(tensor_descriptor: TensorDescriptor):
"""Return tensor object from tensor descriptor."""
return tensor_descriptor.tensor
def is_tensor_descriptor(maybe_tensor_descriptor) -> bool:
"""Check if the object is a TensorDescriptor."""
return isinstance(
maybe_tensor_descriptor, TensorDescriptor
) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
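A quick host-side sketch of wrapping a numpy array (assuming a numpy version with DLPack support; the printed values are indicative only):
import numpy as np
arr = np.ones((4, 8), dtype=np.float16)
desc = from_tensor(arr)
print(desc)                   # e.g. tensor<4x8xf16>_cpu
print(desc.element_type)      # Float16
print(desc.size_in_bytes)     # 4 * 8 * 2 = 64
print(desc.is_in_device)      # False for a host-resident array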

File diff suppressed because it is too large

View File

@@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from . import stacktrace
from . import logger
from . import timer
__all__ = [
"logger",
"timer",
"stacktrace",
]

View File

@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides logging helper functions
"""
import logging
logger = None
def log():
return logger
def setup_log(
name, log_to_console=False, log_to_file=False, log_file_path=None, log_level=1
):
"""Set up and configure a logger with console and/or file handlers.
:param name: Name of the logger to create
:type name: str
:param log_to_console: Whether to enable logging to console, defaults to False
:type log_to_console: bool, optional
:param log_to_file: Whether to enable logging to file, defaults to False
:type log_to_file: bool, optional
:param log_file_path: Path to the log file, required if log_to_file is True
:type log_file_path: str, optional
:param log_level: Logging level to set, defaults to 1
:type log_level: int, optional
:raises ValueError: If log_to_file is True but log_file_path is not provided
:return: Configured logger instance
:rtype: logging.Logger
"""
# Create a custom logger
global logger
logger = logging.getLogger(name)
if log_to_console or log_to_file:
logger.setLevel(log_level)
else:
logger.setLevel(logging.NOTSET)
# Clear existing handlers to prevent duplicate logs
if logger.hasHandlers():
logger.handlers.clear()
# Define formatter
formatter = logging.Formatter(
f"%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s] - %(message)s"
)
# Add console handler if enabled
if log_to_console:
console_handler = logging.StreamHandler()
console_handler.setLevel(log_level)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# Add file handler if enabled
if log_to_file:
if not log_file_path:
raise ValueError("log_file_path must be provided when enable_file is True")
file_handler = logging.FileHandler(log_file_path)
file_handler.setLevel(log_level)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
return logger
logger = setup_log("generic")

View File

@@ -0,0 +1,165 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides stacktrace helper functions
"""
import os
import re
def walk_to_top_module(start_path):
"""
Walk up from the start_path to find the top-level Python module.
:param start_path: The path to start from.
:return: The path of the top-level module.
"""
current_path = start_path
while True:
# Check if we are at the root directory
if os.path.dirname(current_path) == current_path:
break
# Check for __init__.py
init_file_path = os.path.join(current_path, "__init__.py")
if os.path.isfile(init_file_path):
# If __init__.py exists, move up one level
current_path = os.path.dirname(current_path)
else:
# If no __init__.py, we are not in a module; stop
break
# If we reached the root without finding a module, return None
if os.path.dirname(current_path) == current_path and not os.path.isfile(
os.path.join(current_path, "__init__.py")
):
return None
# Return the path of the top-level module
return current_path
def _filter_internal_frames(traceback, internal_path):
"""
Filter out stack frames from the traceback that belong to the specified module path.
This function removes stack frames from the traceback whose file paths start with
the given prefix_path, effectively hiding internal implementation details from
the error traceback shown to users.
"""
iter_prev = None
iter_tb = traceback
while iter_tb is not None:
if os.path.abspath(iter_tb.tb_frame.f_code.co_filename).startswith(
internal_path
):
if iter_tb.tb_next:
if iter_prev:
iter_prev.tb_next = iter_tb.tb_next
else:
traceback = iter_tb.tb_next
else:
iter_prev = iter_tb
iter_tb = iter_tb.tb_next
return traceback
_generated_function_names = re.compile(
r"^(loop_body|while_region|while_before_block|while_after_block|if_region|then_block|else_block|elif_region)_\d+$"
)
def _filter_duplicated_frames(traceback):
"""
Filter out duplicated stack frames from the traceback.
The function filters out consecutive frames that are in the same file and have the same line number.
In a sequence of consecutive frames, the logic prefers to keep the non-generated frame or the last frame.
"""
iter_prev = None
iter_tb = traceback
while iter_tb is not None:
skip_current = False
skip_next = False
if iter_tb.tb_next:
current_filename = os.path.abspath(iter_tb.tb_frame.f_code.co_filename)
next_filename = os.path.abspath(iter_tb.tb_next.tb_frame.f_code.co_filename)
# if in the same file, check if the line number is the same
if current_filename == next_filename:
current_lineno = iter_tb.tb_lineno
next_lineno = iter_tb.tb_next.tb_lineno
if current_lineno == next_lineno:
# Same file and line number, check name, if current is generated, skip current, otherwise skip next
name = iter_tb.tb_frame.f_code.co_name
is_generated = bool(_generated_function_names.match(name))
if is_generated:
# Skip current
skip_current = True
else:
# Skip next if it's generated, otherwise keep both
next_name = iter_tb.tb_next.tb_frame.f_code.co_name
skip_next = bool(_generated_function_names.match(next_name))
if skip_current:
if iter_prev:
iter_prev.tb_next = iter_tb.tb_next
else:
traceback = iter_tb.tb_next
elif skip_next:
# if next is last frame, don't skip
if iter_tb.tb_next.tb_next:
iter_tb.tb_next = iter_tb.tb_next.tb_next
iter_prev = iter_tb
else:
iter_prev = iter_tb
iter_tb = iter_tb.tb_next
return traceback
def filter_stackframe(traceback, prefix_path):
"""
Filter out stack frames from the traceback that belong to the specified module path.
This function removes stack frames from the traceback whose file paths start with
the given prefix_path, effectively hiding internal implementation details from
the error traceback shown to users.
:param traceback: The traceback object to filter.
:param prefix_path: The path prefix to filter out from the traceback.
:return: The filtered traceback with internal frames removed.
"""
# Step 1: filter internal frames
traceback = _filter_internal_frames(traceback, prefix_path)
# Step 2: consolidate duplicated frames
return _filter_duplicated_frames(traceback)
def filter_exception(value, module_dir):
"""
Filter out internal implementation details from exception traceback.
This function recursively processes an exception and its cause chain,
removing stack frames that belong to the specified module directory.
This helps to present cleaner error messages to users by hiding
implementation details.
:param value: The exception object to filter.
:param module_dir: The module directory path to filter out from tracebacks.
:return: The filtered exception with internal frames removed.
"""
if hasattr(value, "__cause__") and value.__cause__:
filter_exception(value.__cause__, module_dir)
if hasattr(value, "__traceback__"):
filter_stackframe(value.__traceback__, module_dir)
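A host-side sketch of how these helpers are typically combined (run_user_code is a hypothetical entry point used only for illustration):
import os
def run_user_code():
    raise ValueError("illustrative failure raised through DSL internals")
try:
    run_user_code()
except Exception as exc:
    internal_dir = walk_to_top_module(os.path.dirname(os.path.abspath(__file__)))
    if internal_dir is not None:
        filter_exception(exc, internal_dir)   # drop frames under the package directory
    raise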

View File

@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides timing helper functions
"""
from functools import wraps
from .logger import log
# TODO: revisit this part when mlir timing manager is ready for pybind.
def timer(*dargs, **kwargs):
enable = kwargs.get("enable", True)
def decorator(func):
@wraps(func)
def func_wrapper(*args, **kwargs):
if not enable:
return func(*args, **kwargs)
from time import time
start = time()
result = func(*args, **kwargs)
end = time()
# Convert time from seconds to us
spend_us = (end - start) * 1e6
# Determine the function type and format the log message
if hasattr(func, "__name__"):
func_name = func.__name__
log_message = f"[JIT-TIMER] Function: {func_name} | Execution Time: {spend_us:.2f} µs"
elif "CFunctionType" in str(type(func)):
log_message = f"[JIT-TIMER] C API Function: {str(func)} | Execution Time: {spend_us:.2f} µs"
else:
log_message = f"[JIT-TIMER] Anonymous Function | Execution Time: {spend_us:.2f} µs"
log().info(log_message)
return result
return func_wrapper
if len(dargs) == 1 and callable(dargs[0]):
return decorator(dargs[0])
else:
return decorator
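A usage sketch (function names are illustrative):
@timer
def compile_kernel():
    return sum(range(1000))      # stand-in for expensive work
@timer(enable=False)
def helper():
    return 0
compile_kernel()   # would log "[JIT-TIMER] Function: compile_kernel | Execution Time: ... µs" once logging is configured
helper()           # timing disabled, no log emitted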

View File

@@ -0,0 +1,57 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from .cutlass_dsl import (
Constexpr,
as_numeric,
min,
max,
and_,
or_,
all_,
any_,
not_,
select_,
# Control-flow without AST pre-processor
if_generate,
for_generate,
LoopUnroll,
while_generate,
yield_out,
# Control-flow with AST pre-processor
range_constexpr,
range_dynamic,
const_expr,
dynamic_expr,
# Data types
dtype, # Provides conversions to types inheriting from NumericType
DSLRuntimeError,
JitArgAdapterRegistry,
# Construction utilities for user-defined classes
extract_mlir_values,
new_from_mlir_values,
)
from .cute.typing import *
# Utilities not belonging to CuTe
from . import utils as utils
# Used as internal symbol
from . import cutlass_dsl as _dsl
# Aliases
LaunchConfig = _dsl.BaseDSL.LaunchConfig
register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter
gpu = _dsl.cutlass_gpu
cuda = _dsl.cuda_helpers

View File

@@ -0,0 +1,310 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
# Use the auto-generated enum AddressSpace
from cutlass._mlir.dialects.cute import AddressSpace
# Explicitly import types that might be directly used by other modules.
# This is a fix for using Sphinx to generate documentation
# Because Sphinx processes each module in isolation, it won't be able to rely
# on re-exported symbols via wildcard imports (from .typing import *) in the
# same way that Python does at runtime.
from .typing import (
Shape,
Stride,
IntTuple,
Coord,
Tile,
XTuple,
Tiler,
Layout,
Pointer,
Tensor,
)
# Import everything else
from .typing import *
from .core import (
assume,
is_integer,
is_int_tuple,
is_static,
size,
has_underscore,
slice_,
make_ptr,
make_layout,
recast_layout,
make_fragment_like,
depth,
rank,
flatten_to_tuple,
flatten,
unflatten,
product,
product_like,
shape,
size_in_bytes,
make_identity_layout,
make_ordered_layout,
make_composed_layout,
make_layout_tv,
make_swizzle,
recast_ptr,
make_tensor,
make_identity_tensor,
make_fragment,
recast_tensor,
get,
select,
front,
is_major,
find,
coalesce,
group_modes,
cosize,
dice,
product_each,
prepend,
append,
prepend_ones,
append_ones,
ceil_div,
slice_and_offset,
crd2idx,
domain_offset,
elem_less,
transform_leaf,
filter_zeros,
filter,
tile_to_shape,
shape_div,
composition,
complement,
right_inverse,
left_inverse,
max_common_layout,
max_common_vector,
logical_product,
zipped_product,
tiled_product,
flat_product,
raked_product,
blocked_product,
flat_divide,
logical_divide,
zipped_divide,
tiled_divide,
local_partition,
local_tile,
printf,
print_tensor,
# tiled mma/tiled copy
make_mma_atom,
make_tiled_mma,
make_copy_atom,
make_tiled_copy_tv,
make_tiled_copy,
make_tiled_copy_S,
make_tiled_copy_D,
make_tiled_copy_C_atom,
basic_copy,
basic_copy_if,
autovec_copy,
copy,
gemm,
# Wrapper classes
ComposedLayout,
Swizzle,
E,
Atom,
MmaAtom,
CopyAtom,
TiledCopy,
TiledMma,
TensorSSA,
ReductionOp,
full,
full_like,
empty_like,
ones_like,
zeros_like,
where,
any_,
all_,
# User defined struct
struct,
pretty_str,
make_layout_image_mask,
repeat_like,
round_up,
is_congruent,
is_weakly_congruent,
ScaledBasis,
get_divisibility,
Ratio,
)
from . import arch
from . import nvgpu
from . import testing
from . import runtime
# Export all math ops without "math."
from .math import *
# Used as internal symbol
from .. import cutlass_dsl as _dsl
# Aliases
jit = _dsl.CuTeDSL.jit
kernel = _dsl.CuTeDSL.kernel
register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter
compile = _dsl.compile
# Explicitly export all symbols for documentation generation
__all__ = [
# Core types
"AddressSpace",
"Tensor",
"Layout",
"ComposedLayout",
"Swizzle",
"E",
"Atom",
"MmaAtom",
"CopyAtom",
"TiledCopy",
"TiledMma",
"TensorSSA",
# Basic utility functions
"assume",
"is_integer",
"is_int_tuple",
"is_static",
"size",
"has_underscore",
"slice_",
"depth",
"rank",
"shape",
"printf",
"print_tensor",
"pretty_str",
# Layout functions
"make_layout",
"recast_layout",
"make_identity_layout",
"make_ordered_layout",
"make_composed_layout",
"make_layout_tv",
"make_layout_image_mask",
# Tensor functions
"make_ptr",
"make_tensor",
"make_identity_tensor",
"make_fragment",
"make_fragment_like",
"recast_ptr",
"recast_tensor",
# Tensor manipulation
"get",
"select",
"front",
"is_major",
"find",
"coalesce",
"group_modes",
"cosize",
"size_in_bytes",
# Tuple operations
"flatten_to_tuple",
"flatten",
"product",
"product_like",
"product_each",
"prepend",
"append",
"prepend_ones",
"append_ones",
# Math operations
"ceil_div",
"round_up",
# Layout operations
"slice_and_offset",
"crd2idx",
"domain_offset",
"elem_less",
"filter_zeros",
"filter",
"tile_to_shape",
"shape_div",
"dice",
# Layout algebra
"composition",
"complement",
"right_inverse",
"left_inverse",
"max_common_layout",
"max_common_vector",
"is_congruent",
"is_weakly_congruent",
# Product operations
"logical_product",
"zipped_product",
"tiled_product",
"flat_product",
"raked_product",
"blocked_product",
# Division operations
"flat_divide",
"logical_divide",
"zipped_divide",
"tiled_divide",
"local_partition",
"local_tile",
# MMA and Copy operations
"make_mma_atom",
"make_tiled_mma",
"make_copy_atom",
"make_tiled_copy_tv",
"make_tiled_copy",
"make_tiled_copy_C_atom",
"basic_copy",
"basic_copy_if",
"autovec_copy",
"copy",
"gemm",
# Tensor creation
"full",
"full_like",
"empty_like",
"ones_like",
"zeros_like",
"where",
"any_",
"all_",
"repeat_like",
"ScaledBasis",
# User defined struct
"struct",
# Modules
"arch",
"nvgpu",
"testing",
"runtime",
# Decorators and code generation
"jit",
"kernel",
"register_jit_arg_adapter",
"compile",
]

View File

@@ -0,0 +1,98 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from .elect import *
from .mbar import *
from .nvvm_wrappers import *
from .smem import *
from .tmem import *
# __all__ is required here for documentation generation
__all__ = [
#
# elect.py
#
"make_warp_uniform",
"elect_one",
#
# mbar.py
#
"mbarrier_init_arrive_cnt",
"mbarrier_init_fence",
"mbarrier_init_tx_bytes",
"mbarrier_wait",
"mbarrier_try_wait",
"conditional_mbarrier_try_wait",
"mbarrier_arrive",
#
# nvvm_wrappers.py
#
"lane_idx",
"warp_idx",
"thread_idx",
"block_dim",
"block_idx",
"grid_dim",
"cluster_idx",
"cluster_dim",
"block_in_cluster_idx",
"block_in_cluster_dim",
"block_idx_in_cluster",
"shuffle_sync",
"shuffle_sync_up",
"shuffle_sync_down",
"shuffle_sync_bfly",
"barrier",
"sync_threads",
"sync_warp",
"fence_acq_rel_cta",
"fence_acq_rel_cluster",
"fence_acq_rel_gpu",
"fence_acq_rel_sys",
"cp_async_commit_group",
"cp_async_wait_group",
"cp_async_bulk_commit_group",
"cp_async_bulk_wait_group",
"cluster_wait",
"cluster_arrive",
"cluster_arrive_relaxed",
"fence_proxy",
"vote_ballot_sync",
"popc",
"fence_view_async_tmem_load",
"fence_view_async_tmem_store",
"warpgroup_reg_alloc",
"warpgroup_reg_dealloc",
"fma_packed_f32x2",
"mul_packed_f32x2",
"add_packed_f32x2",
"fmax",
"rcp_approx",
"exp2",
# Constants
"WARP_SIZE",
# Forward from auto-generated nvvm python
"ProxyKind",
"SharedSpace",
"RoundingModeKind",
#
# smem.py
#
"alloc_smem",
"get_dyn_smem",
#
# tmem.py
#
"retrieve_tmem_ptr",
"alloc_tmem",
"relinquish_tmem_alloc_permit",
"dealloc_tmem",
]

View File

@@ -0,0 +1,75 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from cutlass.cutlass_dsl import CuTeDSL, T, dsl_user_op
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir.dialects import nvvm, scf
from cutlass._mlir import ir
from ..typing import Int, Int32
from ...impl_utils import check_value_in
@dsl_user_op
def make_warp_uniform(value: Int, *, loc=None, ip=None) -> Int32:
"""
Creates a warp-uniform value from the given integer input.
:param value: The integer to make warp uniform.
:type value: Int
:return: The warp-uniform value equal to the input.
:rtype: Int32
"""
return Int32(
_cute_nvgpu_ir.arch_make_warp_uniform(
Int32(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
)
)
class IfOpRegion:
"""
A context manager for if Op.
Automatically inserts `scf.yield([])` when exiting the context.
"""
def __init__(self, block, *, loc=None, ip=None):
self.block = block
self.insert_point = ir.InsertionPoint(self.block)
self.loc = loc
self.ip = ip
def __enter__(self):
self.insert_point.__enter__()
return self.block.arguments
def __exit__(self, exc_type, exc_value, traceback):
scf.yield_([], loc=self.loc, ip=self.ip)
self.insert_point.__exit__(exc_type, exc_value, traceback)
@dsl_user_op
def elect_one(*, loc=None, ip=None) -> IfOpRegion:
"""
Elects one thread within a warp.
.. code-block:: python
with elect_one():
# Only one thread in the warp executes the code in this context
pass
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
is_thread_leader = nvvm.elect_sync(T.bool())
if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip)
return IfOpRegion(if_op.then_block, loc=loc, ip=ip)

View File

@@ -0,0 +1,208 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op
from cutlass._mlir.dialects import nvvm
from cutlass._mlir import ir
from ..typing import Pointer, Int, Boolean, Int32
from ...impl_utils import check_value_in
####################################################################################################
#
# Mbarrier management utilities
#
####################################################################################################
@dsl_user_op
def mbarrier_init_arrive_cnt(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None:
"""
Initializes a mbarrier with the specified thread arrival count.
:param mbar_ptr: A pointer to the mbarrier in SMEM
:type mbar_ptr: Pointer
:param cnt: The arrival count of the mbarrier
:type cnt: Int
"""
nvvm.mbarrier_init_shared(
mbar_ptr.llvm_ptr, Int32(cnt).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
)
@dsl_user_op
def mbarrier_init_fence(*, loc=None, ip=None) -> None:
"""
A fence operation that applies to the mbarrier initializations.
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
nvvm.fence_mbarrier_init(loc=loc, ip=ip)
@dsl_user_op
def mbarrier_init_tx_bytes(
mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None
) -> None:
"""
Initializes a mbarrier with the specified number of transaction bytes.
:param mbar_ptr: A pointer to the mbarrier in SMEM
:type mbar_ptr: Pointer
:param bytes: The number of transaction bytes
:type bytes: Int
:param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to
the mbarrier is converted to a remote address in the peer CTA's
SMEM.
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
mbar_llvm_ptr = mbar_ptr.llvm_ptr
if peer_cta_rank_in_cluster is not None:
mbar_llvm_ptr = nvvm.mapa_shared_cluster(
mbar_llvm_ptr.type,
mbar_llvm_ptr,
Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
space = nvvm.MBarrierSpaceKind.CLUSTER
else:
space = nvvm.MBarrierSpaceKind.CTA
nvvm.mbarrier_txn(
mbar_llvm_ptr,
Int32(bytes).ir_value(loc=loc, ip=ip),
kind=nvvm.MBarrierTxnKind.ARRIVE_EXPECT_TX,
space=space,
loc=loc,
ip=ip,
)
@dsl_user_op
def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None:
"""
Waits on a mbarrier with a specified phase.
:param mbar_ptr: A pointer to the mbarrier in SMEM
:type mbar_ptr: Pointer
:param phase: The phase to wait for (either 0 or 1)
:type phase: Int
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
timeout_ns = 10000000
# This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX
# The timeout in ns only applies to the latter and this call is truly blocking
nvvm.mbarrier_try_wait_parity_shared(
mbar_ptr.llvm_ptr,
Int32(phase).ir_value(loc=loc, ip=ip),
Int32(timeout_ns).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
@dsl_user_op
def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Boolean:
"""
Attempts to wait on a mbarrier with a specified phase in a non-blocking fashion.
:param mbar_ptr: A pointer to the mbarrier in SMEM
:type mbar_ptr: Pointer
:param phase: The phase to wait for (either 0 or 1)
:type phase: Int
:return: A boolean value indicating whether the wait operation was successful
:rtype: Boolean
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
return Boolean(
nvvm.mbarrier_wait_parity(
T.bool(),
mbar_ptr.llvm_ptr,
Int32(phase).ir_value(loc=loc, ip=ip),
nvvm.MBarrierWaitKind.TRY,
loc=loc,
ip=ip,
)
)
@dsl_user_op
def conditional_mbarrier_try_wait(
cond, mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None
) -> Boolean:
"""
Conditionally attempts to wait on a mbarrier with a specified phase in a non-blocking fashion.
:param cond: A boolean predicate
:param mbar_ptr: A pointer to the mbarrier in SMEM
:type mbar_ptr: Pointer
:param phase: The phase to wait for (either 0 or 1)
:type phase: Int
:return: A boolean value indicating whether the wait operation was successful
:rtype: Boolean
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
return if_generate(
cond,
lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip),
lambda: Boolean(True).ir_value(loc=loc, ip=ip),
None,
[Boolean],
)
@dsl_user_op
def mbarrier_arrive(
mbar_ptr: Pointer, peer_cta_rank_in_cluster: Int = None, *, loc=None, ip=None
) -> None:
"""
Arrives on an mbarrier.
:param mbar_ptr: A pointer to the mbarrier in SMEM
:type mbar_ptr: Pointer
:param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to
the mbarrier is converted to a remote address in the peer CTA's
SMEM.
"""
mbar_llvm_ptr = mbar_ptr.llvm_ptr
if peer_cta_rank_in_cluster is not None:
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
mbar_llvm_ptr = nvvm.mapa_shared_cluster(
mbar_llvm_ptr.type,
mbar_llvm_ptr,
Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
space = nvvm.MBarrierSpaceKind.CLUSTER
else:
space = nvvm.MBarrierSpaceKind.CTA
nvvm.mbarrier_txn(
mbar_llvm_ptr,
Int32(1).ir_value(loc=loc, ip=ip),
kind=nvvm.MBarrierTxnKind.ARRIVE,
space=space,
loc=loc,
ip=ip,
)
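A device-side sketch of the usual init/arrive/wait sequence (illustrative only: it would live inside a CuTe DSL kernel body targeting sm_90 or newer, with alloc_smem/elect_one/sync_threads taken from the sibling smem.py, elect.py and nvvm_wrappers.py modules and Int64 from the package's typing module):
mbar = alloc_smem(Int64, 1)                # one 64-bit mbarrier word in SMEM
with elect_one():
    mbarrier_init_arrive_cnt(mbar, 1)      # a single arrival completes each phase
mbarrier_init_fence()
sync_threads()
with elect_one():
    mbarrier_arrive(mbar)
mbarrier_wait(mbar, 0)                     # all threads wait on phase 0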

View File

@@ -0,0 +1,547 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from functools import partial
from typing import Optional, Tuple, Union, Callable
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm, nvvm, vector
# Forward nvvm enums
from cutlass._mlir.dialects.nvvm import (
ProxyKind,
SharedSpace,
Tcgen05WaitKind,
SetMaxRegisterAction,
RoundingModeKind,
)
from ..typing import Int, Boolean, Int32, Float32, Numeric, as_numeric
WARP_SIZE = 32
FULL_MASK = 0xFFFFFFFF
@dsl_user_op
def lane_idx(*, loc=None, ip=None) -> Int32:
"""
Returns the lane index of the current thread within the warp.
"""
return Int32(nvvm.read_ptx_sreg_laneid(T.i32(), loc=loc, ip=ip))
@dsl_user_op
def warp_idx(*, loc=None, ip=None) -> Int32:
"""
Returns the warp index within a CTA.
"""
warp_size = 32
tid_x = Int32(nvvm.read_ptx_sreg_tid_x(T.i32(), loc=loc, ip=ip))
tid_y = Int32(nvvm.read_ptx_sreg_tid_y(T.i32(), loc=loc, ip=ip))
tid_z = Int32(nvvm.read_ptx_sreg_tid_z(T.i32(), loc=loc, ip=ip))
ntid_x = Int32(nvvm.read_ptx_sreg_ntid_x(T.i32(), loc=loc, ip=ip))
ntid_y = Int32(nvvm.read_ptx_sreg_ntid_y(T.i32(), loc=loc, ip=ip))
tid = tid_x + tid_y * ntid_x + tid_z * ntid_x * ntid_y
return tid // warp_size
@dsl_user_op
def thread_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
"""
Returns the thread index within a CTA.
"""
return (
Int32(nvvm.read_ptx_sreg_tid_x(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_tid_y(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_tid_z(T.i32(), loc=loc, ip=ip)),
)
@dsl_user_op
def block_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
"""
Returns the number of threads in each dimension of the CTA.
"""
return (
Int32(nvvm.read_ptx_sreg_ntid_x(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_ntid_y(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_ntid_z(T.i32(), loc=loc, ip=ip)),
)
@dsl_user_op
def block_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
"""
Returns the CTA identifier within a grid.
"""
return (
Int32(nvvm.read_ptx_sreg_ctaid_x(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_ctaid_y(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_ctaid_z(T.i32(), loc=loc, ip=ip)),
)
@dsl_user_op
def grid_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
"""
Returns the number of CTAs in each dimension of the grid.
"""
return (
Int32(nvvm.read_ptx_sreg_nctaid_x(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_nctaid_y(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_nctaid_z(T.i32(), loc=loc, ip=ip)),
)
@dsl_user_op
def cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
"""
Returns the cluster identifier within a grid.
"""
return (
Int32(nvvm.read_ptx_sreg_clusterid_x(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_clusterid_y(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_clusterid_z(T.i32(), loc=loc, ip=ip)),
)
@dsl_user_op
def cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
"""
Returns the number of clusters in each dimension of the grid.
"""
return (
Int32(nvvm.read_ptx_sreg_nclusterid_x(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_nclusterid_y(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_nclusterid_z(T.i32(), loc=loc, ip=ip)),
)
@dsl_user_op
def block_in_cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
"""
Returns the CTA index within a cluster across all dimensions.
"""
return (
Int32(nvvm.read_ptx_sreg_cluster_ctaid_x(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_cluster_ctaid_y(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_cluster_ctaid_z(T.i32(), loc=loc, ip=ip)),
)
@dsl_user_op
def block_in_cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
"""
Returns the dimensions of the cluster.
"""
return (
Int32(nvvm.read_ptx_sreg_cluster_nctaid_x(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_cluster_nctaid_y(T.i32(), loc=loc, ip=ip)),
Int32(nvvm.read_ptx_sreg_cluster_nctaid_z(T.i32(), loc=loc, ip=ip)),
)
@dsl_user_op
def block_idx_in_cluster(*, loc=None, ip=None) -> Int32:
"""
Returns the linearized identifier of the CTA within the cluster.
"""
return Int32(nvvm.read_ptx_sreg_cluster_ctarank(T.i32(), loc=loc, ip=ip))
@dsl_user_op
def shuffle_sync_op(
value: Numeric,
offset: Int,
mask: Int = FULL_MASK,
mask_and_clamp: Int = WARP_SIZE - 1,
kind: nvvm.ShflKind = nvvm.ShflKind.idx,
*,
loc=None,
ip=None,
) -> Numeric:
"""
Shuffles a value within the threads of a warp.
:param value: The value to shuffle
:type value: Numeric
    :param offset: A source lane or a source lane offset depending on kind
    :type offset: Int
    :param mask: A mask describing the threads participating in this operation
    :type mask: Int
:param mask_and_clamp: An integer containing two packed values specifying a mask for logically
splitting warps into sub-segments and an upper bound for clamping the
source lane index.
:type mask_and_clamp: Int
:param kind: The kind of shuffle, can be idx, up, down, or bfly
:type kind: ShflKind
:return: The shuffled value
:rtype: Numeric
"""
if not isinstance(value, Numeric):
value = as_numeric(value)
return type(value)(
nvvm.shfl_sync(
type(value).mlir_type,
Int32(mask).ir_value(loc=loc, ip=ip),
value.ir_value(loc=loc, ip=ip),
Int32(offset).ir_value(loc=loc, ip=ip),
Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
kind,
loc=loc,
ip=ip,
)
)
shuffle_sync = partial(shuffle_sync_op, kind=nvvm.ShflKind.idx)
shuffle_sync_up = partial(shuffle_sync_op, kind=nvvm.ShflKind.up)
shuffle_sync_down = partial(shuffle_sync_op, kind=nvvm.ShflKind.down)
shuffle_sync_bfly = partial(shuffle_sync_op, kind=nvvm.ShflKind.bfly)
@dsl_user_op
def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> None:
"""
    Performs a CTA-level barrier, optionally using a named barrier id and thread count.
"""
if barrier_id is not None:
barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip)
if number_of_threads is not None:
number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip)
nvvm.barrier(
barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip
)
@dsl_user_op
def sync_threads(*, loc=None, ip=None) -> None:
"""
Synchronizes all threads within a CTA.
"""
nvvm.barrier(loc=loc, ip=ip)
@dsl_user_op
def sync_warp(mask: Int = FULL_MASK, *, loc=None, ip=None) -> None:
"""
Performs a warp-wide sync with an optional mask.
"""
nvvm.bar_warp_sync(Int32(mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
@dsl_user_op
def fence_acq_rel_cta(*, loc=None, ip=None) -> None:
"""
Fence operation with acquire-release semantics.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
"""
nvvm.fence_acq_rel_cta(loc=loc, ip=ip)
@dsl_user_op
def fence_acq_rel_cluster(*, loc=None, ip=None) -> None:
"""
Fence operation with acquire-release semantics.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
"""
nvvm.fence_acq_rel_cluster(loc=loc, ip=ip)
@dsl_user_op
def fence_acq_rel_gpu(*, loc=None, ip=None) -> None:
"""
Fence operation with acquire-release semantics.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
"""
nvvm.fence_acq_rel_gpu(loc=loc, ip=ip)
@dsl_user_op
def fence_acq_rel_sys(*, loc=None, ip=None) -> None:
"""
Fence operation with acquire-release semantics.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
"""
nvvm.fence_acq_rel_sys(loc=loc, ip=ip)
@dsl_user_op
def cp_async_commit_group(*, loc=None, ip=None) -> None:
"""
Commits all prior initiated but uncommitted cp.async instructions.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-commit-group>`__.
"""
nvvm.cp_async_commit_group(loc=loc, ip=ip)
@dsl_user_op
def cp_async_wait_group(n, *, loc=None, ip=None) -> None:
"""
    Waits until at most the specified number of cp.async groups is pending.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-wait-group-cp-async-wait-all>`__.
"""
nvvm.cp_async_wait_group(n, loc=loc, ip=ip)
@dsl_user_op
def cp_async_bulk_commit_group(*, loc=None, ip=None) -> None:
"""
Commits all prior initiated but uncommitted cp.async.bulk instructions.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-commit-group>`__.
"""
nvvm.cp_async_bulk_commit_group(loc=loc, ip=ip)
@dsl_user_op
def cp_async_bulk_wait_group(group, *, read=None, loc=None, ip=None) -> None:
"""
    Waits until at most the specified number of cp.async.bulk groups is pending.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-wait-group>`__.
"""
nvvm.cp_async_bulk_wait_group(group, read=read, loc=loc, ip=ip)
@dsl_user_op
def cluster_wait(*, loc=None, ip=None) -> None:
"""
A cluster-wide wait operation.
"""
nvvm.cluster_wait(loc=loc, ip=ip)
@dsl_user_op
def cluster_arrive(*, aligned=None, loc=None, ip=None) -> None:
"""
A cluster-wide arrive operation.
"""
nvvm.cluster_arrive(aligned=aligned, loc=loc, ip=ip)
@dsl_user_op
def cluster_arrive_relaxed(*, aligned=None, loc=None, ip=None) -> None:
"""
A cluster-wide arrive operation with relaxed semantics.
"""
nvvm.cluster_arrive_relaxed(aligned=aligned, loc=loc, ip=ip)
@dsl_user_op
def fence_proxy(
kind: ProxyKind,
*,
space: Optional[SharedSpace] = None,
use_intrinsic=None,
loc=None,
ip=None,
) -> None:
nvvm.fence_proxy(
kind=kind, space=space, use_intrinsic=use_intrinsic, loc=loc, ip=ip
)
@dsl_user_op
def vote_ballot_sync(
pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None
) -> Int32:
"""
Performs a ballot operation across the warp.
"""
return Int32(
nvvm.vote_ballot_sync(
T.i32(),
Int32(mask).ir_value(loc=loc, ip=ip),
Boolean(pred).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
)
@dsl_user_op
def popc(value: Numeric, *, loc=None, ip=None) -> Numeric:
"""
Performs a population count operation.
"""
if not isinstance(value, Numeric):
value = as_numeric(value)
return type(value)(llvm.intr_ctpop(value.ir_value(), loc=loc, ip=ip))
@dsl_user_op
def fence_view_async_tmem_op(
kind: Tcgen05WaitKind,
*,
loc=None,
ip=None,
) -> None:
"""
Perform a fence operation on the async TMEM load or store.
.. note::
This function is only available on sm_100a and above.
The fence is required to synchronize the TMEM load/store
and let the pipeline release or commit the buffer.
Take a mma2acc pipeline as an example of LOAD fence, the ACC tensor is from TMEM.
```
# Start to copy ACC from TMEM to register
cute.copy(tmem_load, tACC, rACC)
fence_view_async_tmem_load()
# After fence, we can ensure the TMEM buffer is consumed totally.
# Release the buffer to let the MMA know it can overwrite the buffer.
mma2accum_pipeline.consumer_release(curr_consumer_state)
```
Take a TS GEMM kernel as an example of STORE fence, the A tensor is from TMEM.
```
# Start to copy A from register to TMEM
cute.copy(tmem_store, rA, tA)
fence_view_async_tmem_store()
# After fence, we can ensure the TMEM buffer is ready.
# Commit the buffer to let the MMA know it can start to load A.
tmem_mma_pipeline.producer_commit(curr_producer_state)
```
:param kind: The kind of fence operation to perform including LOAD and STORE.
:type kind: Tcgen05WaitKind
"""
nvvm.tcgen05_wait(kind, loc=loc, ip=ip)
fence_view_async_tmem_load = partial(
fence_view_async_tmem_op, kind=Tcgen05WaitKind.LOAD
)
fence_view_async_tmem_store = partial(
fence_view_async_tmem_op, kind=Tcgen05WaitKind.STORE
)
@dsl_user_op
def warpgroup_reg_realloc_op(
reg_count: int,
kind: SetMaxRegisterAction,
*,
loc=None,
ip=None,
) -> None:
nvvm.setmaxregister(reg_count, kind, loc=loc, ip=ip)
warpgroup_reg_alloc = partial(
warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.increase
)
warpgroup_reg_dealloc = partial(
warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.decrease
)
@dsl_user_op
def calc_packed_f32x2_op(
src_a: Tuple[Float32, Float32],
src_b: Tuple[Float32, Float32],
src_c: Tuple[Float32, Float32] | None,
calc_func: Callable,
*,
rnd=RoundingModeKind.RZ,
ftz=True,
loc=None,
ip=None,
) -> Tuple[Float32, Float32]:
vec_type = ir.VectorType.get([2], Float32.mlir_type, loc=loc)
vec_src_a = vector.from_elements(
vec_type, tuple(as_numeric(a).ir_value() for a in src_a), loc=loc, ip=ip
)
vec_src_b = vector.from_elements(
vec_type, tuple(as_numeric(b).ir_value() for b in src_b), loc=loc, ip=ip
)
if src_c is not None:
vec_src_c = vector.from_elements(
vec_type, tuple(as_numeric(c).ir_value() for c in src_c), loc=loc, ip=ip
)
vec_res = calc_func(
vec_type, vec_src_a, vec_src_b, vec_src_c, rnd=rnd, ftz=ftz, loc=loc, ip=ip
)
else:
vec_res = calc_func(
vec_type, vec_src_a, vec_src_b, rnd=rnd, ftz=ftz, loc=loc, ip=ip
)
res0 = Float32(
vector.extract(
vec_res, dynamic_position=[], static_position=[0], loc=loc, ip=ip
)
)
res1 = Float32(
vector.extract(
vec_res, dynamic_position=[], static_position=[1], loc=loc, ip=ip
)
)
return res0, res1
fma_packed_f32x2 = partial(calc_packed_f32x2_op, calc_func=nvvm.fma_packed_f32x2)
mul_packed_f32x2 = partial(
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.mul_packed_f32x2
)
add_packed_f32x2 = partial(
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.add_packed_f32x2
)
@dsl_user_op
def fmax(
a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None
) -> Float32:
return Float32(
nvvm.fmax(
T.f32(),
Float32(a).ir_value(loc=loc, ip=ip),
Float32(b).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
)
@dsl_user_op
def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None):
return Float32(
nvvm.rcp_approx_ftz_f(
T.f32(), Float32(a).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
)
)
@dsl_user_op
def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
return Float32(
llvm.inline_asm(
T.f32(),
[Float32(a).ir_value(loc=loc, ip=ip)],
"ex2.approx.ftz.f32 $0, $1;",
"=f,f",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
)
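A device-side sketch combining a few of these wrappers (illustrative only; it would live inside a CuTe DSL kernel body):
tidx, _, _ = thread_idx()
bidx, _, _ = block_idx()
bdim, _, _ = block_dim()
global_tid = bidx * bdim + tidx            # linear thread id along the x dimension
lane0_tid = shuffle_sync(global_tid, 0)    # broadcast lane 0's value across the warp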

View File

@@ -0,0 +1,96 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Optional, Type
from cutlass.cutlass_dsl import T, dsl_user_op
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
from ..typing import Pointer, Numeric, NumericMeta
@dsl_user_op
def alloc_smem(
element_type: Type[Numeric],
size_in_elems: int,
alignment: Optional[int] = None,
*,
loc=None,
ip=None,
) -> Pointer:
"""
Statically allocates SMEM.
:param element_type: The pointee type of the pointer.
:type element_type: Type[Numeric]
:param size_in_elems: The size of the allocation in terms of number of elements of the
pointee type
:type size_in_elems: int
:param alignment: An optional pointer alignment for the allocation
:type alignment: int
:return: A pointer to the start of the allocation
:rtype: Pointer
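Example (an illustrative sketch; ``Float32`` is assumed to be imported from ``..typing`` by the caller):
.. code-block:: python
smem_buf = alloc_smem(Float32, 128, alignment=16)  # 128 Float32 elements, 16-byte aligned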
"""
if not isinstance(element_type, NumericMeta):
raise TypeError(
f"element_type must be a type of Numeric, but got {element_type}"
)
if alignment is None:
# Default alignment based on the element type's width
alignment = element_type.width // 8
ptr_ty = _cute_ir.PtrType.get(
element_type.mlir_type, _cute_ir.AddressSpace.smem, alignment
)
return _cute_nvgpu_ir.arch_alloc_smem(
ptr=ptr_ty,
input=ir.IntegerAttr.get(T.i32(), size_in_elems),
loc=loc,
ip=ip,
)
@dsl_user_op
def get_dyn_smem(
element_type: Type[Numeric],
alignment: Optional[int] = None,
*,
loc=None,
ip=None,
) -> Pointer:
"""
Retrieves a pointer to a dynamic SMEM allocation.
:param element_type: The pointee type of the pointer.
:type element_type: Type[Numeric]
:param alignment: An optional pointer alignment, the result pointer is offset appropriately
:type alignment: int
:return: A pointer to the start of the dynamic SMEM allocation with a correct
alignment
:rtype: Pointer
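Example (an illustrative sketch; ``Float16`` stands for the caller's element type, and the dynamic SMEM size itself is configured at kernel launch):
.. code-block:: python
smem_base = get_dyn_smem(Float16, alignment=128)  # aligned base of the dynamic SMEM region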
"""
if not isinstance(element_type, NumericMeta):
raise TypeError(
f"element_type must be a type of Numeric, but got {element_type}"
)
if alignment is None:
# Default alignment based on the element type's width
alignment = element_type.width // 8
ptr_ty = _cute_ir.PtrType.get(
element_type.mlir_type,
_cute_ir.AddressSpace.smem,
alignment,
)
return _cute_nvgpu_ir.arch_get_dyn_smem(ptr=ptr_ty, loc=loc, ip=ip)

View File

@@ -0,0 +1,142 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Type
from cutlass.cutlass_dsl import dsl_user_op
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from ..typing import Pointer, Int, Int32, Numeric, NumericMeta
SM100_TMEM_CAPACITY_COLUMNS = 512
SM100_TMEM_MIN_ALLOC_COLUMNS = 32
@dsl_user_op
def retrieve_tmem_ptr(
element_type: Type[Numeric],
alignment: int,
ptr_to_buffer_holding_addr: Pointer,
*,
loc=None,
ip=None,
) -> Pointer:
"""
Retrieves a pointer to TMEM with the provided element type and alignment.
:param element_type: The pointee type of the pointer.
:type element_type: Type[Numeric]
:param alignment: The alignment of the result pointer
:type alignment: int
:param ptr_to_buffer_holding_addr: A pointer to a SMEM buffer holding the TMEM address of the
start of the allocation
:type ptr_to_buffer_holding_addr: Pointer
:return: A pointer to TMEM
:rtype: Pointer
"""
if not isinstance(element_type, NumericMeta):
raise TypeError(
f"element_type must be a type of Numeric, but got {element_type}"
)
res_ty = _cute_ir.PtrType.get(
element_type.mlir_type, _cute_ir.AddressSpace.tmem, alignment
)
return _cute_nvgpu_ir.arch_sm100_retrieve_tmem_ptr(
res_ty, ptr_to_buffer_holding_addr.value, loc=loc, ip=ip
)
@dsl_user_op
def alloc_tmem(
num_columns: Int,
smem_ptr_to_write_address: Pointer,
is_two_cta=None,
*,
loc=None,
ip=None,
) -> None:
"""
Allocates TMEM.
:param num_columns: The number of TMEM columns to allocate
:type num_columns: Int
:param smem_ptr_to_write_address: A pointer to a SMEM buffer where the TMEM address is written
to
:type smem_ptr_to_write_address: Pointer
:param is_two_cta: Optional boolean parameter for 2-CTA MMAs
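Example (an illustrative sketch; ``holding_buf`` is a caller-provided SMEM buffer and ``Float32`` is assumed to be available from ``..typing``):
.. code-block:: python
alloc_tmem(512, holding_buf)                            # the TMEM address is written into SMEM
tmem_ptr = retrieve_tmem_ptr(Float32, 16, holding_buf)  # read it back as a typed TMEM pointer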
"""
if isinstance(num_columns, int):
if (
num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS
or num_columns > SM100_TMEM_CAPACITY_COLUMNS
or not (num_columns & (num_columns - 1) == 0)
):
raise ValueError(
f"num_columns must be between 32 and 512, and must be pow of 2, but got {num_columns}"
)
_cute_nvgpu_ir.arch_sm100_alloc_tmem(
Int32(num_columns).ir_value(loc=loc, ip=ip),
smem_ptr_to_write_address.value,
is_two_cta=is_two_cta,
loc=loc,
ip=ip,
)
@dsl_user_op
def relinquish_tmem_alloc_permit(is_two_cta=None, *, loc=None, ip=None) -> None:
"""
Relinquishes the right to allocate TMEM so that other CTAs potentially in a different grid can
allocate.
"""
_cute_nvgpu_ir.arch_sm100_relinquish_tmem_alloc_permit(
is_two_cta=is_two_cta, loc=loc, ip=ip
)
@dsl_user_op
def dealloc_tmem(
tmem_ptr: Pointer,
num_columns: Int,
is_two_cta=None,
*,
loc=None,
ip=None,
) -> None:
"""
Deallocates TMEM using the provided pointer and number of columns.
:param tmem_ptr: A pointer to the TMEM allocation to de-allocate
:type tmem_ptr: Pointer
:param num_columns: The number of columns in the TMEM allocation
:type num_columns: Int
:param is_two_cta: Optional boolean parameter for 2-CTA MMAs
"""
if isinstance(num_columns, int):
if (
num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS
or num_columns > SM100_TMEM_CAPACITY_COLUMNS
or not (num_columns & (num_columns - 1) == 0)
):
raise ValueError(
f"num_columns must be between 32 and 512, and must be pow of 2, but got {num_columns}"
)
_cute_nvgpu_ir.arch_sm100_dealloc_tmem(
tmem_ptr.value,
Int32(num_columns).ir_value(loc=loc, ip=ip),
is_two_cta=is_two_cta,
loc=loc,
ip=ip,
)

File diff suppressed because it is too large

View File

@@ -0,0 +1,354 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from .core import TensorSSA
from cutlass._mlir.dialects import math, arith
def acos(a: TensorSSA) -> TensorSSA:
"""Compute element-wise arc cosine of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:return: Tensor containing the arc cosine of each element in input tensor
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = acos(y) # Compute arc cosine
"""
return TensorSSA(math.acos(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def asin(a: TensorSSA) -> TensorSSA:
"""Compute element-wise arc sine of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:return: Tensor containing the arc sine of each element in input tensor
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = asin(y) # Compute arc sine
"""
return TensorSSA(math.asin(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def atan(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
"""Compute element-wise arc tangent of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the arc tangent of each element in input tensor
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = atan(y) # Compute arc tangent
"""
raise NotImplementedError("atan is not implemented")
return TensorSSA(math.atan(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def atan2(a: TensorSSA, b: TensorSSA, fastmath: bool = False) -> TensorSSA:
"""Compute element-wise arc tangent of two tensors.
Computes atan2(a, b) element-wise. The function atan2(a, b) is the angle in radians
between the positive x-axis and the point given by the coordinates (b, a).
:param a: First input tensor (y-coordinates)
:type a: TensorSSA
:param b: Second input tensor (x-coordinates)
:type b: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the arc tangent of a/b element-wise
:rtype: TensorSSA
Example:
.. code-block::
y = cute.make_fragment(ptr1, layout).load() # y coordinates
x = cute.make_fragment(ptr2, layout).load() # x coordinates
theta = atan2(y, x) # Compute angles
"""
return TensorSSA(
math.atan2(a, b, fastmath=arith.FastMathFlags.none), a.shape, a.dtype
)
def cos(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
"""Compute element-wise cosine of the input tensor.
:param a: Input tensor (in radians)
:type a: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the cosine of each element
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = cos(y) # Compute cosine
"""
return TensorSSA(math.cos(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def erf(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
"""Compute element-wise error function of the input tensor.
The error function is defined as:
erf(x) = 2/√π ∫[0 to x] exp(-t²) dt
:param a: Input tensor
:type a: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the error function value for each element
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = erf(y) # Compute error function
"""
return TensorSSA(math.erf(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def exp2(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
"""Compute element-wise base-2 exponential of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing 2 raised to the power of each element
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = exp2(y) # Compute 2^x
"""
return TensorSSA(math.exp2(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def log(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
"""Compute element-wise natural logarithm of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the natural logarithm of each element
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = log(y) # Compute natural logarithm
"""
return TensorSSA(math.log(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def log2(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
"""Compute element-wise base-2 logarithm of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the base-2 logarithm of each element
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = log2(y) # Compute log base 2
"""
return TensorSSA(math.log2(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def log10(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
"""Compute element-wise base-10 logarithm of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the base-10 logarithm of each element
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = log10(y) # Compute log base 10
"""
return TensorSSA(math.log10(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def rsqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
"""Compute element-wise reciprocal square root of the input tensor.
Computes 1/√x element-wise.
:param a: Input tensor
:type a: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the reciprocal square root of each element
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = rsqrt(y) # Compute 1/√x
"""
return TensorSSA(math.rsqrt(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def sin(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
"""Compute element-wise sine of the input tensor.
:param a: Input tensor (in radians)
:type a: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the sine of each element
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = sin(y) # Compute sine
"""
return TensorSSA(math.sin(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def sqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
"""Compute element-wise square root of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the square root of each element
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = sqrt(y) # Compute square root
"""
return TensorSSA(math.sqrt(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def tan(a: TensorSSA) -> TensorSSA:
"""Compute element-wise tangent of the input tensor.
:param a: Input tensor (in radians)
:type a: TensorSSA
:return: Tensor containing the tangent of each element
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = tan(y) # Compute tangent
"""
return TensorSSA(math.tan(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
def tanh(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
"""Compute element-wise hyperbolic tangent of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the hyperbolic tangent of each element
:rtype: TensorSSA
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = tanh(y) # Compute hyperbolic tangent
"""
return TensorSSA(math.tanh(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
__all__ = [
"acos",
"asin",
"atan",
"atan2",
"cos",
"erf",
"exp2",
"log",
"log10",
"log2",
"rsqrt",
"sin",
"sqrt",
"tan",
"tanh",
]

View File

@@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from . import warp
from . import cpasync
from . import warpgroup
from . import tcgen05
from .common import *
from .helpers import *
# __all__ is required here for documentation generation
__all__ = [
"OpError",
"MmaUniversalOp",
"CopyUniversalOp",
]

View File

@@ -0,0 +1,143 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from dataclasses import dataclass
from typing import Type, Optional
from cutlass.cutlass_dsl import DSLBaseError
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
from .. import core
from ..typing import Float16, Float32, Float64, Numeric
class OpError(DSLBaseError):
"""
An exception class for Op construction errors.
"""
def __init__(
self, op: core.Op, message: str, suggestion: Optional[str] = None
) -> None:
if suggestion is None:
# Default suggestion
suggestion = "Check your Op construction code"
super().__init__(
message,
error_code=f"{op.__class__.__name__} error",
suggestion=suggestion,
)
####################################################################################################
#
# MMA Ops and Traits
#
####################################################################################################
@dataclass(frozen=True)
class MmaUniversalOp(core.MmaOp):
"""
The universal MMA Operation.
This Operation currently expects the A/B operands as well as the accumulator to share the same
data types.
:param abacc_dtype: The data type for the A/B operands and the accumulator
:type abacc_dtype: Type[Numeric]
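Example (illustrative; only the construction-time validation performed in ``__post_init__`` is shown):
.. code-block:: python
op = MmaUniversalOp(Float16)  # Float16, Float32 and Float64 are accepted; other dtypes raise OpError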
"""
abacc_dtype: Type[Numeric]
def __post_init__(self) -> None:
if self.abacc_dtype not in [Float16, Float32, Float64]:
raise OpError(
self,
f"expects the 'abacc_dtype' Op parameter to be one of Float16, Float32, or Float64",
)
def __str__(self) -> str:
return (
"universal MMA Operation using FMA"
f"\n A/B/Accumulator data type = {self.abacc_dtype}"
)
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaUniversalTrait":
shape_mnk_attr = ir.Attribute.parse(f'#cute.shape<"(1,1,1)">')
atom_ty = _cute_nvgpu_ir.UniversalFmaAtomType.get(
shape_mnk_attr,
self.abacc_dtype.mlir_type,
self.abacc_dtype.mlir_type,
self.abacc_dtype.mlir_type,
)
return MmaUniversalTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
class MmaUniversalTrait(core.Trait):
pass
####################################################################################################
#
# Copy Ops and Traits
#
####################################################################################################
@dataclass(frozen=True)
class CopyUniversalOp(core.CopyOp):
"""
The universal Copy Operation.
When creating a Copy Atom out of this operation, the expected usage pattern is
.. code-block:: python
op = cute.nvgpu.CopyUniversalOp()
atom = cute.make_copy_atom(op, tensor_dtype, num_bits_per_copy=64)
- ``tensor_dtype`` is the data type used to build the reference TV Layout (either the source \
or the destination TV Layout) in units of tensor elements and is used for partitioning by \
``TiledCopy`` for example
- ``num_bits_per_copy`` is a kw argument specifying the number of bits to copy per Atom \
execution. This can be larger than the width of the above data type. When not provided, \
the compiler will make a best-effort attempt at auto-vectorization.
"""
def __str__(self) -> str:
return "universal Copy Operation"
def _make_trait(
self,
copy_internal_type: Type[Numeric],
*,
loc=None,
ip=None,
**kwargs,
) -> "CopyUniversalTrait":
num_bits_per_copy = kwargs.get("num_bits_per_copy", 0)
if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0):
raise ValueError(
"expects a 'num_bits_per_copy' kw argument of type int that is non-negative "
f"when creating a copy Atom for {self.__class__.__name__}"
)
ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get(
copy_internal_type.mlir_type, num_bits_per_copy
)
return CopyUniversalTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class CopyUniversalTrait(core.Trait):
pass

View File

@@ -0,0 +1,38 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from .copy import *
from .helpers import *
# __all__ is required here for documentation generation
__all__ = [
#
# copy.py
#
"LoadCacheMode",
"CopyG2SOp",
"CopyBulkTensorTileG2SOp",
"CopyBulkTensorTileG2SMulticastOp",
"CopyBulkTensorTileS2GOp",
#
# helpers.py
#
"make_tma_tile_atom",
"tma_partition",
"create_tma_multicast_mask",
"prefetch_descriptor",
"copy_tensormap",
"update_tma_descriptor",
"fence_tma_desc_acquire",
"cp_fence_tma_desc_release",
"fence_tma_desc_release",
]

View File

@@ -0,0 +1,366 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import enum
from dataclasses import dataclass
from typing import Optional, Type
from cutlass.cutlass_dsl import CuTeDSL, t
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
from ...core import CopyOp, Trait
from ...typing import Int16, Pointer, Integer, Numeric
from ..common import OpError
from ..tcgen05.mma import CtaGroup
####################################################################################################
#
# Asynchronous copies
#
####################################################################################################
class LoadCacheMode(enum.Enum):
"""
An enumeration for the possible cache modes of a non-bulk ``cp.async`` instruction.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#cache-operators>`__.
"""
ALWAYS = _cute_nvgpu_ir.LoadCacheMode.always
GLOBAL = _cute_nvgpu_ir.LoadCacheMode.global_
STREAMING = _cute_nvgpu_ir.LoadCacheMode.streaming
LAST_USE = _cute_nvgpu_ir.LoadCacheMode.last_use
NONE = _cute_nvgpu_ir.LoadCacheMode.none
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
def _to_ir(self) -> _cute_nvgpu_ir.LoadCacheMode:
return self.value
@dataclass(frozen=True)
class CopyG2SOp(CopyOp):
"""
Non-bulk asynchronous GMEM to SMEM Copy Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-non-bulk-copy>`__.
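Example (an illustrative sketch following the ``cute.make_copy_atom`` pattern used for ``CopyUniversalOp``; ``Float16`` stands for the tensor data type):
.. code-block:: python
op = CopyG2SOp(cache_mode=LoadCacheMode.GLOBAL)
atom = cute.make_copy_atom(op, Float16, num_bits_per_copy=128)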
"""
cache_mode: LoadCacheMode = LoadCacheMode.ALWAYS
def __str__(self) -> str:
res = "cp.async GMEM -> SMEM copy Operation"
if self.cache_mode != LoadCacheMode.ALWAYS:
res += f"\n with cache mode = {self.cache_mode}"
return res
def _make_trait(
self,
copy_internal_type: Type[t.Numeric],
*,
loc=None,
ip=None,
**kwargs,
) -> "CopyG2STrait":
num_bits_per_copy = kwargs.get("num_bits_per_copy", None)
if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy <= 0):
raise ValueError(
"expects a 'num_bits_per_copy' kw argument of type int that is positive "
f"when creating a copy Atom for {self.__class__.__name__}"
)
# Verify that the user provided enum values
if not isinstance(self.cache_mode, LoadCacheMode):
raise OpError(
self,
"expects the 'cache_mode' Op parameter to be a LoadCacheMode instance",
)
ty = _cute_nvgpu_ir.CopyAtomSIMTAsyncCopyType.get(
copy_internal_type.mlir_type, self.cache_mode._to_ir(), num_bits_per_copy
)
return CopyG2STrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class CopyG2STrait(Trait):
pass
####################################################################################################
#
# Bulk tensor copies a.k.a TMA copies
#
####################################################################################################
TMA_MBAR_PTR_FIELD_NAME = "tma_bar"
TMA_MASK_FIELD_NAME = "mcast_mask"
TMA_DESC_PTR_FIELD_NAME = "tma_descriptor_ptr"
#
# TMA GMEM -> SMEM copies
#
@dataclass(frozen=True)
class CopyBulkTensorTileG2SOp(CopyOp):
"""
Bulk tensor asynchronous GMEM to SMEM Copy Operation using the TMA unit.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
This Operation uses TMA in the ``.tile`` mode.
"""
cta_group: CtaGroup = CtaGroup.ONE
admissible_archs = ["sm_90", "sm_90a", "sm_100a"]
def __post_init__(self) -> None:
if not isinstance(self.cta_group, CtaGroup):
raise OpError(
self, "expects the 'cta_group' parameter to be a CtaGroup instance"
)
# Arch verification
arch = CuTeDSL._get_dsl().envar.arch
if arch not in self.admissible_archs:
raise OpError(
self,
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90":
raise OpError(
self,
f"CTA group of 2 is tcgen05-specific and is not and is not compatible with {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
def __str__(self) -> str:
res = "cp.async GMEM -> SMEM bulk tensor copy Operation"
if self.cta_group == CtaGroup.TWO:
res += f"\n CTA group = 2"
return res
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "CopyBulkTensorTileG2SNonExecTrait":
raise NotImplementedError(
"Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA"
)
def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum:
if self.cta_group == CtaGroup.ONE:
return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90
elif self.cta_group == CtaGroup.TWO:
return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm
else:
assert False, "unrecognized self.cta_group"
class CopyBulkTensorTileG2SNonExecTrait(Trait):
# We allow kw args to be dropped so that the user can write common code for non-multicast
# and multicast loads.
def unpack(
self,
*,
loc=None,
ip=None,
tma_bar_ptr: Optional[Pointer] = None,
tma_desc_ptr: Optional[Pointer] = None,
**kwargs,
):
"""
Custom implementation of unpack for non-executable TMAs.
The non-multicast TMA load requires a `tma_bar_ptr` keyword argument to be provided when
using `cute.copy`. Any other kw arguments will be ignored instead of triggering an error.
"""
if not isinstance(tma_bar_ptr, Pointer):
raise ValueError(
"expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument"
)
exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_MBAR_PTR_FIELD_NAME}>"
attr = ir.Attribute.parse(attr_str)
exec_value = _cute_nvgpu_ir.atom_set_value(
exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip
)
if isinstance(tma_desc_ptr, Pointer):
attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>"
attr = ir.Attribute.parse(attr_str)
exec_value = _cute_nvgpu_ir.atom_set_value(
exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
)
return exec_value
#
# TMA GMEM -> SMEM multicast copies
#
@dataclass(frozen=True)
class CopyBulkTensorTileG2SMulticastOp(CopyOp):
"""
Bulk tensor asynchronous multicast GMEM to SMEM Copy Operation using the TMA unit.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
This Operation uses TMA in the ``.tile`` mode.
"""
cta_group: CtaGroup = CtaGroup.ONE
admissible_archs = ["sm_90", "sm_90a", "sm_100a"]
def __post_init__(self):
if not isinstance(self.cta_group, CtaGroup):
raise OpError(
self, "expects the 'cta_group' parameter to be a CtaGroup instance"
)
# Arch verification
arch = CuTeDSL._get_dsl().envar.arch
if arch not in self.admissible_archs:
raise OpError(
self,
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90":
raise OpError(
self,
f"CTA group of 2 is tcgen05-specific and is not and is not compatible with {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
def __str__(self) -> str:
res = "cp.async GMEM -> SMEM bulk tensor multicast copy Operation"
if self.cta_group == CtaGroup.TWO:
res += f"\n CTA group = 2"
return res
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "CopyBulkTensorTileG2SMulticastNonExecTrait":
raise NotImplementedError(
"Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA"
)
def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum:
if self.cta_group == CtaGroup.ONE:
return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90_multicast
elif self.cta_group == CtaGroup.TWO:
return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm_multicast
else:
assert False, "unrecognized self.cta_group"
class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait):
def unpack(
self,
*,
loc=None,
ip=None,
tma_bar_ptr: Optional[Pointer] = None,
mcast_mask=None,
tma_desc_ptr=None,
):
"""
Custom implementation of unpack for non-executable TMAs.
The multicast TMA load requires a `tma_bar_ptr` and a `mcast_mask` keyword arguments to be
provided when using `cute.copy`.
"""
if not isinstance(tma_bar_ptr, Pointer):
raise ValueError(
"expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument"
)
if not isinstance(mcast_mask, Integer):
raise ValueError(
"expects a multicast mask to be provided via the mcast_mask kw argument"
)
exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<tma_bar>"
attr = ir.Attribute.parse(attr_str)
exec_value = _cute_nvgpu_ir.atom_set_value(
exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip
)
attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<mcast_mask>"
attr = ir.Attribute.parse(attr_str)
exec_value = _cute_nvgpu_ir.atom_set_value(
exec_value, attr, Int16(mcast_mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
)
if isinstance(tma_desc_ptr, Pointer):
attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>"
attr = ir.Attribute.parse(attr_str)
exec_value = _cute_nvgpu_ir.atom_set_value(
exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
)
return exec_value
#
# TMA SMEM -> GMEM copies
#
@dataclass(frozen=True)
class CopyBulkTensorTileS2GOp(CopyOp):
"""
Bulk tensor asynchronous SMEM to GMEM Copy Operation using the TMA unit.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
This Operation uses TMA in the ``.tile`` mode.
"""
admissible_archs = ["sm_90", "sm_90a", "sm_100a"]
def __post_init__(self):
# Arch verification
arch = CuTeDSL._get_dsl().envar.arch
if arch not in self.admissible_archs:
raise OpError(
self,
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
def __str__(self) -> str:
return "cp.async SMEM -> GMEM bulk tensor copy Operation"
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "CopyBulkTensorTileS2GTrait":
raise NotImplementedError(
"Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA"
)
class CopyBulkTensorTileS2GTrait(Trait):
def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None):
"""
Custom implementation of unpack for non-executable TMAs.
"""
exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
if isinstance(tma_desc_ptr, Pointer):
attr_str = (
f"#cute_nvgpu.atom_copy_field_tmastore<{TMA_DESC_PTR_FIELD_NAME}>"
)
attr = ir.Attribute.parse(attr_str)
exec_value = _cute_nvgpu_ir.atom_set_value(
exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
)
return exec_value

View File

@@ -0,0 +1,327 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Optional, Tuple, Type, Union
from cutlass.cutlass_dsl import dsl_user_op
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir.dialects import llvm
from ...typing import Coord, Layout, Tensor, Tiler, Pointer, Int16, Numeric, NumericMeta
from ... import core
from .copy import (
CopyBulkTensorTileG2SOp,
CopyBulkTensorTileG2SMulticastOp,
CopyBulkTensorTileS2GOp,
CopyBulkTensorTileG2SNonExecTrait,
CopyBulkTensorTileG2SMulticastNonExecTrait,
CopyBulkTensorTileS2GTrait,
)
@dsl_user_op
def make_tma_tile_atom(
op: Union[
CopyBulkTensorTileG2SOp,
CopyBulkTensorTileG2SMulticastOp,
CopyBulkTensorTileS2GOp,
],
gmem_tensor: Tensor,
smem_layout: Layout,
cta_tiler: Tiler,
num_multicast: int = 1,
*,
internal_type: Optional[Type[Numeric]] = None,
loc=None,
ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
"""
Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from an SMEM
buffer with the given Layout.
Given
- a GMEM tensor
- a SMEM layout
- a CTA-level Tiler
this function figures out the bulk tensor asynchronous copy instruction to use with the maximum
"TMA vector length" to copy tiles of the GMEM tensor to/from an SMEM buffer with the provided
layout and consistent with the provided Tiler.
This function returns two results:
1. the Copy Atom
2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates \
that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the \
associated layout produces coordinates rather than linear offsets. Otherwise, TMA tensors can \
be partitioned like any other CuTe tensor using the layout algebra.
:param op: The Copy Operation to construct an Atom for
:type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp]
:param gmem_tensor: The GMEM tensor involved in the Copy
:type gmem_tensor: Tensor
:param smem_layout: The SMEM layout to construct the Copy Atom for
:type smem_layout: Layout
:param cta_tiler: The CTA Tiler to use
:type cta_tiler: Tiler
:param num_multicast: The multicast factor
:type num_multicast: int
:param internal_type: An optional parameter for the internal data type to use when the actual data type is not supported by the TMA unit
:type internal_type: Type[Numeric]
:return: A Copy Atom for this Operation and the associated TMA tensor
:rtype: Tuple[core.CopyAtom, Tensor]
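Example (an illustrative sketch; ``mA``, ``sA_layout`` and ``cta_tiler`` are user-provided values, not symbols defined in this module):
.. code-block:: python
op = CopyBulkTensorTileG2SOp()
tma_atom, tma_tensor = make_tma_tile_atom(op, mA, sA_layout, cta_tiler)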
"""
if internal_type is not None:
if not isinstance(internal_type, NumericMeta):
raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
internal_type = internal_type.mlir_type
cta_v_map = core.composition(
core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip),
cta_tiler,
loc=loc,
ip=ip,
)
if isinstance(op, CopyBulkTensorTileG2SOp):
if num_multicast != 1:
raise ValueError(
f"expects num_multicast to be 1 for non multicast G2S copies, "
f"but got {num_multicast}"
)
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
gmem_tensor.value,
smem_layout,
cta_v_map,
op._to_ir(),
num_multicast=num_multicast,
internal_type=internal_type,
loc=loc,
ip=ip,
)
return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1]
elif isinstance(op, CopyBulkTensorTileG2SMulticastOp):
if num_multicast < 1:
raise ValueError(
f"expects num_multicast to be >= 1 for multicast G2S copies, "
f"but got {num_multicast}"
)
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
gmem_tensor.value,
smem_layout,
cta_v_map,
op._to_ir(),
num_multicast=num_multicast,
internal_type=internal_type,
loc=loc,
ip=ip,
)
return (
core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])),
res[1],
)
elif isinstance(op, CopyBulkTensorTileS2GOp):
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_store(
gmem_tensor.value,
smem_layout,
cta_v_map,
internal_type=internal_type,
loc=loc,
ip=ip,
)
return core.CopyAtom(op, CopyBulkTensorTileS2GTrait(res[0])), res[1]
else:
raise ValueError(f"expects a bulk tensor (TMA) Copy Op, but got {op}")
@dsl_user_op
def tma_partition(
atom: core.CopyAtom,
cta_coord: Coord,
cta_layout: Layout,
smem_tensor: Tensor,
gmem_tensor: Tensor,
*,
loc=None,
ip=None,
) -> Tuple[Tensor, Tensor]:
"""
Tiles the GMEM and SMEM tensors for the provided TMA Copy Atom.
"""
cta_coord_val = core._pack_coord(cta_coord, loc=loc, ip=ip)
s, d = _cute_nvgpu_ir.atom_tma_partition(
atom._trait.value,
cta_coord=cta_coord_val,
cta_layout=cta_layout,
smem_tensor=smem_tensor.value,
gmem_tensor=gmem_tensor.value,
loc=loc,
ip=ip,
)
return s, d
@dsl_user_op
def create_tma_multicast_mask(
cta_layout_vmnk: Layout,
cta_coord_vmnk: Coord,
mcast_mode: int,
*,
loc=None,
ip=None,
) -> Int16:
"""
Computes a multicast mask for a TMA load Copy.
:param cta_layout_vmnk: The VMNK layout of the cluster
:type cta_layout_vmnk: Layout
:param cta_coord_vmnk: The VMNK coordinate of the current CTA
:type cta_coord_vmnk: Coord
:param mcast_mode: The tensor mode in which to multicast
:type mcast_mode: int
:return: The resulting mask
:rtype: Int16
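Example (an illustrative sketch; the cluster layout and CTA coordinate come from the caller's kernel):
.. code-block:: python
# e.g. multicast along the N-mode (mode 2) of a (V,M,N,K) cluster layout
mask = create_tma_multicast_mask(cta_layout_vmnk, cta_coord_vmnk, mcast_mode=2)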
"""
if core.rank(cta_layout_vmnk) != 4:
raise ValueError(
f"cta_layout_vmnk must be rank 4, but got {core.pretty_str(cta_layout_vmnk)}"
)
if core.rank(cta_coord_vmnk) != 4:
raise ValueError(
f"cta_coord_vmnk must be rank 4, but got {core.pretty_str(cta_coord_vmnk)}"
)
return core.make_layout_image_mask(
cta_layout_vmnk, cta_coord_vmnk, mcast_mode, loc=loc, ip=ip
)
@dsl_user_op
def prefetch_descriptor(tma_atom: core.CopyAtom, *, loc=None, ip=None) -> None:
"""
Prefetches the TMA descriptor associated with the TMA Atom.
"""
_cute_nvgpu_ir.prefetch_tma_desc(tma_atom._trait.value, loc=loc, ip=ip)
@dsl_user_op
def copy_tensormap(
tma_atom: core.CopyAtom, tensormap_ptr: Pointer, *, loc=None, ip=None
) -> None:
"""
Copies the tensormap held by a TMA Copy Atom to the memory location pointed to by the provided
pointer.
:param tma_atom: The TMA Copy Atom
:type tma_atom: CopyAtom
:param tensormap_ptr: The pointer to the memory location to copy the tensormap to
:type tensormap_ptr: Pointer
"""
_cute_nvgpu_ir.copy_tma_desc(
tma_atom._trait.value, tensormap_ptr.value, loc=loc, ip=ip
)
@dsl_user_op
def update_tma_descriptor(
tma_atom: core.CopyAtom,
gmem_tensor: Tensor,
tma_desc_ptr: Pointer,
*,
loc=None,
ip=None,
) -> None:
"""
Updates the TMA descriptor in the memory location pointed to by the provided pointer using
information from a TMA Copy Atom and the provided GMEM tensor.
Specifically, the following fields of the TMA descriptor will be updated:
1. the GMEM tensor base address
2. the GMEM tensor shape
3. the GMEM tensor stride
Other fields of the TMA descriptor are left unchanged.
:param tma_atom: The TMA Copy Atom
:type tma_atom: CopyAtom
:param gmem_tensor: The GMEM tensor
:type gmem_tensor: Tensor
:param tma_desc_ptr: The pointer to the memory location of the descriptor to update
:type tma_desc_ptr: Pointer
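Example (an illustrative sketch; ``tma_atom``, ``new_gmem_tensor`` and ``desc_ptr`` are caller-provided values, and the release fence shown after the update is one plausible ordering to check against the linked PTX documentation):
.. code-block:: python
update_tma_descriptor(tma_atom, new_gmem_tensor, desc_ptr)
fence_tma_desc_release()  # make the descriptor update visible before the descriptor is consumed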
"""
_cute_nvgpu_ir.update_tma_desc(
tma_atom._trait.value, gmem_tensor.value, tma_desc_ptr.value, loc=loc, ip=ip
)
@dsl_user_op
def fence_tma_desc_acquire(
tma_desc_ptr: Pointer,
*,
loc=None,
ip=None,
) -> None:
"""
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
"""
tma_desc_ptr_i64 = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value()
llvm.inline_asm(
None,
[tma_desc_ptr_i64],
"fence.proxy.tensormap::generic.acquire.gpu [$0], 128;",
"l",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
@dsl_user_op
def cp_fence_tma_desc_release(
tma_desc_global_ptr: Pointer,
tma_desc_shared_ptr: Pointer,
*,
loc=None,
ip=None,
) -> None:
"""
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-tensormap-cp-fenceproxy>`__.
"""
tma_desc_global_ptr_i64 = tma_desc_global_ptr.toint(loc=loc, ip=ip).ir_value()
tma_desc_shared_ptr_i32 = tma_desc_shared_ptr.toint(loc=loc, ip=ip).ir_value()
llvm.inline_asm(
None,
[tma_desc_global_ptr_i64, tma_desc_shared_ptr_i32],
"tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [$0], [$1], 128;",
"l,r",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
@dsl_user_op
def fence_tma_desc_release(*, loc=None, ip=None) -> None:
"""
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
"""
llvm.inline_asm(
None,
[],
"fence.proxy.tensormap::generic.release.gpu;",
"",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)

View File

@@ -0,0 +1,159 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Optional, Tuple, Type, Union
from cutlass.cutlass_dsl import dsl_user_op
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from .. import core
from ..typing import Shape, Layout, Tensor, Numeric, NumericMeta
from ...impl_utils import check_type_in
from .cpasync.copy import (
CopyBulkTensorTileG2SOp,
CopyBulkTensorTileG2SNonExecTrait,
CopyBulkTensorTileG2SMulticastOp,
CopyBulkTensorTileG2SMulticastNonExecTrait,
)
####################################################################################################
#
# TMA creation helpers for tcgen05 MMAs
#
####################################################################################################
@dsl_user_op
def make_tma_tile_atom_A(
op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
gmem_tensor: Tensor,
smem_layout: Layout,
mma_tiler_mnk: Shape,
tiled_mma: core.TiledMma,
cluster_shape_vmnk: Shape,
*,
internal_type: Optional[Type[Numeric]] = None,
loc=None,
ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
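"""
Makes a TMA Copy Atom in the ``.tile`` mode for copying tiles of the A operand of a tcgen05
MMA from GMEM to SMEM, using the TiledMma's partitioning of A.
For multicast Ops, the copy is multicast across the N-mode of the cluster since CTAs along
that mode share the same tile of A.
:return: A Copy Atom for this Operation and the associated TMA tensor
:rtype: Tuple[core.CopyAtom, Tensor]
"""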
if internal_type is not None:
if not isinstance(internal_type, NumericMeta):
raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
internal_type = internal_type.mlir_type
check_type_in(
op,
[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
"op",
"make_tma_tile_atom_A",
)
ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
mma_tiler_mk = (mma_tiler_mnk[0], *mma_tiler_mnk[2:])
g_tile = core.composition(ident, mma_tiler_mk, loc=loc, ip=ip)
cta_v_map = tiled_mma._thrfrg_A(g_tile)
cta_v_map = core.get(cta_v_map, mode=[1])
cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile)))
if isinstance(op, CopyBulkTensorTileG2SOp):
num_multicast = 1
else:
assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
# multicast across the N-mode since those would share the same tile of A
num_multicast = core.size(cluster_shape_vmnk, mode=[2])
# res[0] = the IR Value for the non-executable atom instance
# res[1] = the IR Value for the associated TMA tensor
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
gmem_tensor.value,
smem_layout,
cta_v_map,
op._to_ir(),
num_multicast=num_multicast,
internal_type=internal_type,
loc=loc,
ip=ip,
)
if isinstance(op, CopyBulkTensorTileG2SOp):
return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1]
else:
assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
return (
core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])),
res[1],
)
@dsl_user_op
def make_tma_tile_atom_B(
op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
gmem_tensor: Tensor,
smem_layout: Layout,
mma_tiler_mnk: Shape,
tiled_mma: core.TiledMma,
cluster_shape_vmnk: Shape,
*,
internal_type: Optional[Type[Numeric]] = None,
loc=None,
ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
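"""
Makes a TMA Copy Atom in the ``.tile`` mode for copying tiles of the B operand of a tcgen05
MMA from GMEM to SMEM, using the TiledMma's partitioning of B.
For multicast Ops, the copy is multicast across the M-mode of the cluster since CTAs along
that mode share the same tile of B.
:return: A Copy Atom for this Operation and the associated TMA tensor
:rtype: Tuple[core.CopyAtom, Tensor]
"""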
if internal_type is not None:
if not isinstance(internal_type, NumericMeta):
raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
internal_type = internal_type.mlir_type
check_type_in(
op,
[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
"op",
"make_tma_tile_atom_B",
)
ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
mma_tiler_nk = (mma_tiler_mnk[1], *mma_tiler_mnk[2:])
g_tile = core.composition(ident, mma_tiler_nk, loc=loc, ip=ip)
cta_v_map = tiled_mma._thrfrg_B(g_tile)
cta_v_map = core.get(cta_v_map, mode=[1])
cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile)))
if isinstance(op, CopyBulkTensorTileG2SOp):
num_multicast = 1
else:
assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
# multicast across the M-mode since those would share the same tile of B
num_multicast = core.size(cluster_shape_vmnk, mode=[1])
# res[0] = the IR Value for the non-executable atom instance
# res[1] = the IR Value for the associated TMA tensor
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
gmem_tensor.value,
smem_layout,
cta_v_map,
op._to_ir(),
num_multicast=num_multicast,
internal_type=internal_type,
loc=loc,
ip=ip,
)
if isinstance(op, CopyBulkTensorTileG2SOp):
return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1]
else:
assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
return (
core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])),
res[1],
)
__all__ = [
"make_tma_tile_atom_A",
"make_tma_tile_atom_B",
]

View File

@@ -0,0 +1,57 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from .copy import *
from .mma import *
from .helpers import *
# __all__ is required here for documentation generation
__all__ = [
#
# copy.py
#
"Repetition",
"Pack",
"Unpack",
"Ld16x64bOp",
"Ld16x128bOp",
"Ld16x256bOp",
"Ld16x32bx2Op",
"Ld32x32bOp",
"St16x64bOp",
"St16x128bOp",
"St16x256bOp",
"St16x32bx2Op",
"St32x32bOp",
#
# mma.py
#
"OperandMajorMode",
"OperandSource",
"CtaGroup",
"Field",
"MmaTF32Op",
"MmaF16BF16Op",
"MmaI8Op",
"MmaFP8Op",
"SmemLayoutAtomKind",
#
# helpers.py
#
"make_smem_layout_atom",
"tile_to_mma_shape",
"commit",
"is_tmem_load",
"is_tmem_store",
"get_tmem_copy_properties",
"find_tmem_tensor_col_offset",
"make_tmem_copy",
]

View File

@@ -0,0 +1,465 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import enum
from dataclasses import dataclass
from typing import Type
from cutlass.cutlass_dsl import CuTeDSL
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
from ..common import OpError
from ...core import CopyOp, Trait
from ...typing import Numeric
class Repetition(enum.Enum):
"""
An enumeration for the number of repetitions of a given TMEM copy within the instruction.
"""
x1 = 1
x2 = 2
x4 = 4
x8 = 8
x16 = 16
x32 = 32
x64 = 64
x128 = 128
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
@classmethod
def _missing_(cls, value):
if isinstance(value, int):
if value == 1:
return Repetition.x1
elif value == 2:
return Repetition.x2
elif value == 4:
return Repetition.x4
elif value == 8:
return Repetition.x8
elif value == 16:
return Repetition.x16
elif value == 32:
return Repetition.x32
elif value == 64:
return Repetition.x64
elif value == 128:
return Repetition.x128
class Pack(enum.Enum):
"""
An enumeration for the possible packing patterns for TMEM to RMEM copies.
"""
NONE = enum.auto()
PACK_16b_IN_32b = enum.auto()
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
class Unpack(enum.Enum):
"""
An enumeration for the possible unpacking patterns for RMEM to TMEM copies.
"""
NONE = enum.auto()
UNPACK_32b_IN_16b = enum.auto()
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
@dataclass(frozen=True)
class _LdBase(CopyOp):
repeat: Repetition = Repetition.x1
pack: Pack = Pack.NONE
admissible_archs = ["sm_100a"]
def __post_init__(self) -> None:
# Arch verification
arch = CuTeDSL._get_dsl().envar.arch
if arch not in self.admissible_archs:
raise OpError(
self,
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
if not isinstance(self.repeat, Repetition):
raise OpError(
self,
"expects the 'repeat' Op parameter to be a tcgen05.Repetition instance",
)
if not isinstance(self.pack, Pack):
raise OpError(
self,
"expects the 'pack' Op parameter to be a tcgen05.Pack instance",
)
def __str__(self) -> str:
res = (
f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation"
+ f"\n number of repetitions = {self.repeat.value}"
)
if self.pack == Pack.PACK_16b_IN_32b:
res += f"\n with 2x 16-bit to 32b packing"
return res
@dataclass(frozen=True)
class Ld16x64bOp(_LdBase):
"""
16x64b TMEM load Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
This Operation corresponds to the ``.16x64b`` qualifier.
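Example (an illustrative sketch; ``Float32`` stands for the copy element type, and ``cute.make_copy_atom`` is the same generic helper used for the other copy Ops):
.. code-block:: python
op = Ld16x64bOp(Repetition.x32, Pack.NONE)
atom = cute.make_copy_atom(op, Float32)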
"""
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "Ld16x64bTrait":
ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
copy_internal_type.mlir_type,
16,
64,
self.repeat.value,
ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
)
return Ld16x64bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class Ld16x64bTrait(Trait):
pass
@dataclass(frozen=True)
class Ld16x128bOp(_LdBase):
"""
16x128b TMEM load Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
This Operation corresponds to the ``.16x128b`` qualifier.
"""
def __post_init__(self) -> None:
super().__post_init__()
if self.repeat == Repetition.x128:
raise OpError(
self,
"x128 repetition is not supported",
suggestion="choose one of x1, x2, x4, x8, x16, x32, x64",
)
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "Ld16x128bTrait":
ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
copy_internal_type.mlir_type,
16,
128,
self.repeat.value,
ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
)
return Ld16x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class Ld16x128bTrait(Trait):
pass
@dataclass(frozen=True)
class Ld16x256bOp(_LdBase):
"""
16x256b TMEM load Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
This Operation corresponds to the ``.16x256b`` qualifier.
"""
def __post_init__(self) -> None:
super().__post_init__()
if self.repeat in (Repetition.x128, Repetition.x64):
raise OpError(
self,
"x64 and x128 repetition is not supported",
suggestion="choose one of x1, x2, x4, x8, x16, x32",
)
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "Ld16x256bTrait":
ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
copy_internal_type.mlir_type,
16,
256,
self.repeat.value,
ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
)
return Ld16x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class Ld16x256bTrait(Trait):
pass
@dataclass(frozen=True)
class Ld16x32bx2Op(_LdBase):
"""
16x32bx2 TMEM load Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
This Operation corresponds to the ``.16x32bx2`` qualifier.
"""
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "Ld16x32bx2Trait":
ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
copy_internal_type.mlir_type,
16,
32,
self.repeat.value,
ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
)
return Ld16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
class Ld16x32bx2Trait(Trait):
pass
@dataclass(frozen=True)
class Ld32x32bOp(_LdBase):
"""
32x32b TMEM load Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
This Operation corresponds to the ``.32x32b`` qualifier.
"""
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "Ld32x32bTrait":
ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
copy_internal_type.mlir_type,
32,
32,
self.repeat.value,
ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
)
return Ld32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class Ld32x32bTrait(Trait):
pass
@dataclass(frozen=True)
class _StBase(CopyOp):
repeat: Repetition
unpack: Unpack = Unpack.NONE
admissible_archs = ["sm_100a"]
def __post_init__(self) -> None:
# Arch verification
arch = CuTeDSL._get_dsl().envar.arch
if arch not in self.admissible_archs:
raise OpError(
self,
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
if not isinstance(self.repeat, Repetition):
raise OpError(
self,
"expects the 'repeat' Op parameter to be a tcgen05.Repetition instance",
)
if not isinstance(self.unpack, Unpack):
raise OpError(
self,
"expects the 'pack' Op parameter to be a tcgen05.Unpack instance",
)
def __str__(self) -> str:
res = (
f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation"
+ f"\n number of repetitions = {self.repeat.value}"
)
if self.unpack == Unpack.UNPACK_32b_IN_16b:
res += f"\n with 32-bit to 2x 16b unpacking"
return res
@dataclass(frozen=True)
class St16x64bOp(_StBase):
"""
16x64b TMEM store Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
This Operation corresponds to the ``.16x64b`` qualifier.
"""
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "St16x64bTrait":
ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
copy_internal_type.mlir_type,
16,
64,
self.repeat.value,
ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
)
return St16x64bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class St16x64bTrait(Trait):
pass
@dataclass(frozen=True)
class St16x128bOp(_StBase):
"""
16x128b TMEM store Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
This Operation corresponds to the ``.16x128b`` qualifier.
"""
def __post_init__(self) -> None:
super().__post_init__()
if self.repeat == Repetition.x128:
raise OpError(
self,
"x128 repetition is not supported",
suggestion="choose one of x1, x2, x4, x8, x16, x32, x64",
)
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "St16x128bTrait":
ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
copy_internal_type.mlir_type,
16,
128,
self.repeat.value,
ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
)
return St16x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class St16x128bTrait(Trait):
pass
@dataclass(frozen=True)
class St16x256bOp(_StBase):
"""
16x256b TMEM store Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
This Operation corresponds to the ``.16x256b`` qualifier.
"""
def __post_init__(self) -> None:
super().__post_init__()
if self.repeat in (Repetition.x128, Repetition.x64):
raise OpError(
self,
"x64 and x128 repetition is not supported",
suggestion="choose one of x1, x2, x4, x8, x16, x32",
)
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "St16x256bTrait":
ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
copy_internal_type.mlir_type,
16,
256,
self.repeat.value,
ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
)
return St16x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class St16x256bTrait(Trait):
pass
@dataclass(frozen=True)
class St16x32bx2Op(_StBase):
"""
16x32bx2 TMEM store Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
This Operation corresponds to the ``.16x32bx2`` qualifier.
"""
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "St16x32bx2Trait":
ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
copy_internal_type.mlir_type,
16,
32,
self.repeat.value,
ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
)
return St16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
class St16x32bx2Trait(Trait):
pass
@dataclass(frozen=True)
class St32x32bOp(_StBase):
"""
32x32b TMEM store Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
This Operation corresponds to the ``.32x32b`` qualifier.
"""
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "St32x32bTrait":
ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
copy_internal_type.mlir_type,
32,
32,
self.repeat.value,
ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
)
return St32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class St32x32bTrait(Trait):
pass
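# Illustrative usage (hedged sketch, not part of the original source): constructing
# TMEM copy Operations from this module. Repetition, Pack, and Unpack are defined
# earlier in this module; the Ops require CUTE_DSL_ARCH=sm_100a at construction time.
#
#     ld_op = Ld32x32bOp(Repetition.x32)        # 32 data paths x 32b load
#     st_op = St32x32bOp(Repetition.x32)        # matching 32x32b store
#     ld_packed = Ld16x256bOp(Repetition.x2, Pack.PACK_16b_IN_32b)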

View File

@@ -0,0 +1,301 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import overload, Type, Tuple, Union
from cutlass.cutlass_dsl import dsl_user_op
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir.dialects import nvvm
from ...typing import (
Shape,
IntTuple,
Layout,
Tensor,
Int,
Numeric,
NumericMeta,
Int16,
Int32,
)
from ... import core
from .mma import SmemLayoutAtomKind, CtaGroup
from .copy import (
Pack,
Unpack,
Ld16x64bOp,
Ld16x128bOp,
Ld16x256bOp,
Ld16x32bx2Op,
Ld32x32bOp,
St16x64bOp,
St16x128bOp,
St16x256bOp,
St16x32bx2Op,
St32x32bOp,
)
####################################################################################################
#
# Helper functions for MMA
#
####################################################################################################
@dsl_user_op
def make_smem_layout_atom(
kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None
) -> core.ComposedLayout:
"""
Makes a SMEM layout Atom.
This function creates a composed layout, in units of elements, consistent with the requested layout
Atom kind and element data type.
:param kind: The kind of layout Atom
:type kind: SmemLayoutAtomKind
:param element_type: The element data type to construct the layout for
:type element_type: Type[Numeric]
:return: The SMEM layout atom
:rtype: core.ComposedLayout
"""
if not isinstance(element_type, NumericMeta):
raise TypeError(f"element_type must be a Numeric, but got {element_type}")
if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER):
num_contiguous_bits = 128
sw = core.make_swizzle(0, 4, 3)
elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32):
num_contiguous_bits = 256
sw = core.make_swizzle(1, 4, 3)
elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64):
num_contiguous_bits = 512
sw = core.make_swizzle(2, 4, 3)
elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128):
num_contiguous_bits = 1024
sw = core.make_swizzle(3, 4, 3)
elif kind == SmemLayoutAtomKind.MN_SW128_32B:
num_contiguous_bits = 1024
sw = core.make_swizzle(2, 5, 2)
else:
raise ValueError("unrecognized SMEM layout atom kind")
num_contiguous_elems = num_contiguous_bits // element_type.width
if kind in (
SmemLayoutAtomKind.MN_INTER,
SmemLayoutAtomKind.MN_SW32,
SmemLayoutAtomKind.MN_SW64,
SmemLayoutAtomKind.MN_SW128,
SmemLayoutAtomKind.MN_SW128_32B,
):
# M/N-major layout
return core.make_composed_layout(
sw,
0,
core.make_layout(
(num_contiguous_elems, 8), stride=(1, num_contiguous_elems)
),
loc=loc,
ip=ip,
)
else:
# K-major layout
return core.make_composed_layout(
sw,
0,
core.make_layout(
(8, num_contiguous_elems), stride=(num_contiguous_elems, 1)
),
loc=loc,
ip=ip,
)
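# Illustrative usage (hedged sketch, not part of the original source): an SMEM layout
# atom for K-major operands with 128B swizzling. `BFloat16` is assumed to be imported
# from the package's typing module alongside the names imported above.
#
#     atom = make_smem_layout_atom(SmemLayoutAtomKind.K_SW128, BFloat16)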
@overload
def tile_to_mma_shape(
atom: Layout, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None
) -> Layout: ...
@overload
def tile_to_mma_shape(
atom: core.ComposedLayout,
mma_tile_shape: Shape,
order: IntTuple = None,
*,
loc=None,
ip=None,
) -> core.ComposedLayout: ...
@dsl_user_op
def tile_to_mma_shape(
atom, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None
):
"""
Tiles a layout to an MMA shape.
"""
# Default order is colexicographical
if order is None:
order = tuple(range(core.rank(mma_tile_shape) - 1))
if core.rank(order) != core.rank(mma_tile_shape) - 1:
raise ValueError(
f"rank(order)={core.rank(order)} must be equal to "
f"rank(mma_tile_shape)-1={core.rank(mma_tile_shape)-1}"
)
order_val = core._pack_int_tuple(order, loc=loc, ip=ip)
mma_tile_shape_val = core._pack_shape(mma_tile_shape, loc=loc, ip=ip)
if not (
core.is_static(atom)
and core.is_static(mma_tile_shape_val)
and core.is_static(order_val)
):
raise ValueError("tile_to_mma_shape only supports static inputs")
res_ty = _cute_nvgpu_ir.tile_to_mma_shape(atom, mma_tile_shape_val, order_val)
return _cute_ir.static(res_ty, loc=loc, ip=ip)
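# Illustrative usage (hedged sketch, not part of the original source): tiling the SMEM
# layout atom from the example above to a static per-CTA MMA tile shape. The shape
# values are chosen purely for illustration.
#
#     a_smem_layout = tile_to_mma_shape(atom, (128, 64, 2))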
@dsl_user_op
def commit(
mbar_ptr: core.Pointer,
mask=None,
cta_group: CtaGroup = CtaGroup.ONE,
*,
loc=None,
ip=None,
) -> None:
"""
Performs an arrive operation on an mbarrier upon completion of previously issued MMA operations.
:param mbar_ptr: A pointer to the mbarrier in SMEM
:type mbar_ptr: Pointer
:param mask: An optional multicast mask for the CTAs in the cluster to signal arrival to
:type mask: Int
"""
if cta_group == CtaGroup.ONE:
group = nvvm.Tcgen05GroupKind.CTA_1
else:
assert cta_group == CtaGroup.TWO
group = nvvm.Tcgen05GroupKind.CTA_2
mbar_ptr = mbar_ptr.llvm_ptr
if mask is not None:
mask = Int16(mask).ir_value(loc=loc, ip=ip)
nvvm.tcgen05_commit_arrive(
mbar_ptr, multicast_mask=mask, group=group, loc=loc, ip=ip
)
else:
nvvm.tcgen05_commit_arrive(mbar_ptr, group=group, loc=loc, ip=ip)
return
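# Illustrative usage (hedged sketch, not part of the original source): arriving on an
# SMEM mbarrier once the previously issued tcgen05 MMAs have completed. `mbar_ptr` is
# assumed to be a cute Pointer to an initialized mbarrier in SMEM.
#
#     commit(mbar_ptr, cta_group=CtaGroup.ONE)
#     commit(mbar_ptr, mask=0b11, cta_group=CtaGroup.TWO)   # multicast to 2 CTAs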
####################################################################################################
#
# Helper functions for Copies
#
####################################################################################################
def is_tmem_load(atom: core.CopyAtom) -> bool:
"""
Returns whether a CopyAtom instance is a TMEM load.
"""
return isinstance(
atom.op,
(
Ld16x64bOp,
Ld16x128bOp,
Ld16x256bOp,
Ld16x32bx2Op,
Ld32x32bOp,
),
)
def is_tmem_store(atom: core.CopyAtom) -> bool:
"""
Returns whether a CopyAtom instance is a TMEM store.
"""
return isinstance(
atom.op,
(
St16x64bOp,
St16x128bOp,
St16x256bOp,
St16x32bx2Op,
St32x32bOp,
),
)
def get_tmem_copy_properties(
atom: core.CopyAtom,
) -> Tuple[int, int, int, Union[Pack, Unpack]]:
"""
Returns the properties of a TMEM copy atom (number of data paths, bits, repetitions,
and whether packing/unpacking is used).
"""
if isinstance(atom.op, (Ld16x64bOp, St16x64bOp)):
num_dp, num_bits = 16, 64
elif isinstance(atom.op, (Ld16x128bOp, St16x128bOp)):
num_dp, num_bits = 16, 128
elif isinstance(atom.op, (Ld16x256bOp, St16x256bOp)):
num_dp, num_bits = 16, 256
elif isinstance(atom.op, (Ld16x32bx2Op, St16x32bx2Op)):
num_dp, num_bits = 16, 32
elif isinstance(atom.op, (Ld32x32bOp, St32x32bOp)):
num_dp, num_bits = 32, 32
else:
raise ValueError(f"expects 'atom' to be a TMEM copy, but got {atom}")
if is_tmem_load(atom):
return num_dp, num_bits, atom.op.repeat.value, atom.op.pack
else:
assert is_tmem_store(atom), "atom must be a TMEM store"
return num_dp, num_bits, atom.op.repeat.value, atom.op.unpack
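# Illustrative usage (hedged sketch, not part of the original source): introspecting a
# TMEM Copy Atom, where `atom` is assumed to be a core.CopyAtom built from one of the
# tcgen05 load/store Ops above.
#
#     if is_tmem_load(atom):
#         num_dp, num_bits, reps, pack = get_tmem_copy_properties(atom)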
@dsl_user_op
def find_tmem_tensor_col_offset(tmem_tensor: Tensor, *, loc=None, ip=None) -> Int:
"""
Computes the TMEM column offset given a TMEM tensor.
:param tmem_tensor: The TMEM tensor used to compute the column offset
:type tmem_tensor: Tensor
:return: The column offset
:rtype: Int
"""
tmem_col_mask = 0x0000FFFF
offset = (
core.cosize(core.recast_tensor(tmem_tensor, Int32).layout, loc=loc, ip=ip)
& tmem_col_mask
)
if isinstance(offset, int):
return offset
return Int32(offset, loc=loc, ip=ip)
@dsl_user_op
def make_tmem_copy(
atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None
) -> core.TiledCopy:
"""
Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor.
"""
tiled_copy_val = _cute_nvgpu_ir.atom_make_tmem_copy(
atom._trait.value, tmem_tensor.value, loc=loc, ip=ip
)
new_trait = type(atom._trait)(tiled_copy_val)
return core.TiledCopy(atom.op, new_trait)
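# Illustrative usage (hedged sketch, not part of the original source): building a tiled
# TMEM copy and locating the next free TMEM column. `tAcc` is assumed to be a
# TMEM-backed tensor and `ld_atom` a TMEM load Copy Atom.
#
#     tiled_copy = make_tmem_copy(ld_atom, tAcc)
#     next_col = find_tmem_tensor_col_offset(tAcc)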

View File

@@ -0,0 +1,603 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import enum
from dataclasses import dataclass
from typing import Type
from cutlass.cutlass_dsl import CuTeDSL, T
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
from ..common import OpError
from ...core import MmaOp, Trait, _pack_shape, rank, depth
from ...typing import (
Shape,
Float8E5M2,
Float8E4M3FN,
Float16,
BFloat16,
Float32,
TFloat32,
Boolean,
Int8,
Uint8,
Int32,
Numeric,
)
####################################################################################################
#
# MMA Ops and Traits
#
####################################################################################################
class OperandMajorMode(enum.Enum):
"""
An enumeration for the majorness of the input operands of the MMA.
"""
MN = _cute_ir.MajorMode.mn
K = _cute_ir.MajorMode.k
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
@classmethod
def _missing_(cls, value):
if isinstance(value, str):
value = value.upper()
if value == "MN":
return OperandMajorMode.MN
elif value == "K":
return OperandMajorMode.K
def _to_ir(self) -> _cute_ir.MajorMode:
return self.value
class OperandSource(enum.Enum):
"""
An enumeration for the source memory location of the A input operand of the MMA.
"""
TMEM = _cute_ir.MmaFragKind.tmem
SMEM = _cute_ir.MmaFragKind.smem_desc
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
def _to_ir(self) -> _cute_ir.MmaFragKind:
return self.value
class CtaGroup(enum.Enum):
"""
An enumeration for the ``cta_group`` qualifier of the MMA.
"""
ONE = 1
TWO = 2
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
class Field(enum.Enum):
"""
An enumeration for the fields of the MMA Atom that can be modified at runtime.
"""
NEGATE_A = "neg_a"
NEGATE_B = "neg_b"
ACCUMULATE = "accum_c"
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
def _to_ir_field_name(self) -> str:
return self.value
# Base class for all tcgen05 MMA Ops used to factor out some internal code
@dataclass(frozen=True)
class MmaOp(MmaOp):
a_dtype: Type[Numeric]
b_dtype: Type[Numeric]
acc_dtype: Type[Numeric]
shape_mnk: Shape
cta_group: CtaGroup
a_src: OperandSource
a_major_mode: OperandMajorMode
b_major_mode: OperandMajorMode
admissible_archs = ["sm_100a"]
def __post_init__(self) -> None:
# Verify arch
arch = CuTeDSL._get_dsl().envar.arch
if arch not in self.admissible_archs:
raise OpError(
self,
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
# Verify that the user provided enum values
if not isinstance(self.cta_group, CtaGroup):
raise OpError(
self,
"expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance",
)
if not isinstance(self.a_src, OperandSource):
raise OpError(
self,
"expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance",
)
if not isinstance(self.a_major_mode, OperandMajorMode):
raise OpError(
self,
"expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
)
if not isinstance(self.b_major_mode, OperandMajorMode):
raise OpError(
self,
"expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
)
# Verify the instruction shape
if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1):
raise OpError(
self,
f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, "
f"but got {self.shape_mnk}",
)
m, n = self.shape_mnk[0], self.shape_mnk[1]
if self.cta_group == CtaGroup.ONE:
if m not in [64, 128]:
raise OpError(self, f"expects the M-mode to be 64 or 128, but got {m}")
if m == 64:
if (n < 8) or (n > 256) or (n % 8 != 0):
raise OpError(
self,
f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}",
)
elif m == 128:
if (n < 16) or (n > 256) or (n % 16 != 0):
raise OpError(
self,
f"expects the N-mode to satisfy 8 <= N <= 256 and N % 16 == 0, but got {n}",
)
else:
if m not in [128, 256]:
raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}")
if (n < 32) or (n > 256) or (n % 32 != 0):
raise OpError(
self,
f"expects the N-mode to satisfy 32 <= N <= 256 and N % 32 == 0, but got {n}",
)
def __str__(self) -> str:
return (
self.__class__.descriptive_name # type: ignore
+ f"\n A data type = {self.a_dtype}"
+ f"\n B data type = {self.b_dtype}"
+ f"\n Accumulator data type = {self.acc_dtype}"
+ f"\n CTA group = {self.cta_group}"
+ f"\n A source location = {self.a_src}"
+ f"\n A major mode = {self.a_major_mode}"
+ f"\n B major mode = {self.b_major_mode}"
+ f"\n Instruction shape MNK = {self.shape_mnk}"
)
class MmaTrait(Trait):
admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]
def set(self, field, value, *, loc=None, ip=None) -> None:
if field not in self.admissible_fields:
raise ValueError(
f"expects field to be one of {self.admissible_fields}, but got {field}"
)
field_name = f"#cute_nvgpu.atom_mma_field_sm100<{field._to_ir_field_name()}>"
attr = ir.Attribute.parse(field_name)
self.value = _cute_nvgpu_ir.atom_set_value(
self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
)
#
# TF32 MMA
#
@dataclass(frozen=True)
class MmaTF32Op(MmaOp):
"""
TF32 tcgen05 MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
This Operation corresponds to the ``.kind::tf32`` qualifier.
"""
descriptive_name = "tcgen05 TF32 MMA Operation"
def __init__(
self,
instruction_shape: Shape,
cta_group: CtaGroup,
a_src: OperandSource,
a_major_mode: OperandMajorMode,
b_major_mode: OperandMajorMode,
) -> None:
super().__init__(
TFloat32,
TFloat32,
Float32,
instruction_shape,
cta_group,
a_src,
a_major_mode,
b_major_mode,
)
self._verify()
def _verify(self) -> None:
# Verify the instruction shape
instruction_k = 8
if rank(self.shape_mnk) == 2:
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
if self.shape_mnk[2] != instruction_k:
raise OpError(
self,
f"expects the instruction extent in the K-mode to be {instruction_k}, "
f"but got {self.shape_mnk[2]}",
)
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaTF32Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
shape_mnk.type.attribute,
self.cta_group.value,
self.a_major_mode._to_ir(),
self.b_major_mode._to_ir(),
self.a_dtype.mlir_type,
self.b_dtype.mlir_type,
self.acc_dtype.mlir_type,
self.a_src._to_ir(),
0,
)
return MmaTF32Trait(
_cute_nvgpu_ir.make_sm100_mma(
ty,
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
)
class MmaTF32Trait(MmaTrait):
pass
#
# F16/BF16 MMA
#
@dataclass(frozen=True)
class MmaF16BF16Op(MmaOp):
"""
F16/BF16 tcgen05 MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
This Operation corresponds to the ``.kind::f16`` qualifier.
"""
descriptive_name = "tcgen05 F16/BF16 MMA Operation"
def __init__(
self,
ab_dtype: Type[Numeric],
acc_dtype: Type[Numeric],
instruction_shape: Shape,
cta_group: CtaGroup,
a_src: OperandSource,
a_major_mode: OperandMajorMode,
b_major_mode: OperandMajorMode,
) -> None:
super().__init__(
ab_dtype,
ab_dtype,
acc_dtype,
instruction_shape,
cta_group,
a_src,
a_major_mode,
b_major_mode,
)
self._verify()
def _verify(self) -> None:
# Input data type verification
if self.a_dtype not in [Float16, BFloat16]:
raise OpError(
self,
"expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16",
)
assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
# Accumulator data type verification
if self.acc_dtype not in [Float16, Float32]:
raise OpError(
self,
"expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
)
# Instruction shape verification
instruction_k = 16
if rank(self.shape_mnk) == 2:
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
if self.shape_mnk[2] != instruction_k:
raise OpError(
self,
f"expects the instruction extent in the K-mode to be {instruction_k}, "
f"but got {self.shape_mnk[2]}",
)
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
shape_mnk.type.attribute,
self.cta_group.value,
self.a_major_mode._to_ir(),
self.b_major_mode._to_ir(),
self.a_dtype.mlir_type,
self.b_dtype.mlir_type,
self.acc_dtype.mlir_type,
self.a_src._to_ir(),
0,
)
return MmaF16BF16Trait(
_cute_nvgpu_ir.make_sm100_mma(
ty,
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
)
class MmaF16BF16Trait(MmaTrait):
pass
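# Illustrative usage (hedged sketch, not part of the original source): a 2-CTA
# 128x256x16 BF16 UMMA Op accumulating in Float32, with A sourced from SMEM, A K-major
# and B MN-major. Requires CUTE_DSL_ARCH=sm_100a.
#
#     op = MmaF16BF16Op(
#         BFloat16,
#         Float32,
#         (128, 256, 16),
#         CtaGroup.TWO,
#         OperandSource.SMEM,
#         OperandMajorMode.K,
#         OperandMajorMode.MN,
#     )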
#
# I8 MMA
#
@dataclass(frozen=True)
class MmaI8Op(MmaOp):
"""
I8 tcgen05 MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
This Operation corresponds to the ``.kind::i8`` qualifier.
"""
descriptive_name = "tcgen05 I8 MMA Operation"
def __init__(
self,
ab_dtype: Type[Numeric],
instruction_shape: Shape,
cta_group: CtaGroup,
a_src: OperandSource,
a_major_mode: OperandMajorMode,
b_major_mode: OperandMajorMode,
) -> None:
super().__init__(
ab_dtype,
ab_dtype,
Int32,
instruction_shape,
cta_group,
a_src,
a_major_mode,
b_major_mode,
)
self._verify()
def _verify(self) -> None:
# Input data type verification
if self.a_dtype not in [Int8, Uint8]:
raise OpError(
self,
"expects the 'ab_dtype' Op parameter to be one of Int8 or Uint8",
)
assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
# Instruction shape verification
instruction_k = 32
if rank(self.shape_mnk) == 2:
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
if self.shape_mnk[2] != instruction_k:
raise OpError(
self,
f"expects the instruction extent in the K-mode to be {instruction_k}, "
f"but got {self.shape_mnk[2]}",
)
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaI8Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
shape_mnk.type.attribute,
self.cta_group.value,
self.a_major_mode._to_ir(),
self.b_major_mode._to_ir(),
(T.si8() if self.a_dtype.signed else T.ui8()),
(T.si8() if self.b_dtype.signed else T.ui8()),
T.si32(),
self.a_src._to_ir(),
0,
)
return MmaI8Trait(
_cute_nvgpu_ir.make_sm100_mma(
ty,
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
)
class MmaI8Trait(MmaTrait):
pass
#
# F8F6F4 MMA
#
@dataclass(frozen=True)
class MmaFP8Op(MmaOp):
"""
F8 tcgen05 MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
"""
descriptive_name = "tcgen05 F8 MMA Operation"
def __init__(
self,
ab_dtype: Type[Numeric],
acc_dtype: Type[Numeric],
instruction_shape: Shape,
cta_group: CtaGroup,
a_src: OperandSource,
a_major_mode: OperandMajorMode,
b_major_mode: OperandMajorMode,
) -> None:
super().__init__(
ab_dtype,
ab_dtype,
acc_dtype,
instruction_shape,
cta_group,
a_src,
a_major_mode,
b_major_mode,
)
self._verify()
def _verify(self) -> None:
# Input data type verification
if self.a_dtype not in [Float8E5M2, Float8E4M3FN]:
raise OpError(
self,
"expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
)
assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
# Accumulator data type verification
if self.acc_dtype not in [Float16, Float32]:
raise OpError(
self,
"expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
)
# Instruction shape verification
instruction_k = 32
if rank(self.shape_mnk) == 2:
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
if self.shape_mnk[2] != instruction_k:
raise OpError(
self,
f"expects the instruction extent in the K-mode to be {instruction_k}, "
f"but got {self.shape_mnk[2]}",
)
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaFP8Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
shape_mnk.type.attribute,
self.cta_group.value,
self.a_major_mode._to_ir(),
self.b_major_mode._to_ir(),
self.a_dtype.mlir_type,
self.b_dtype.mlir_type,
self.acc_dtype.mlir_type,
self.a_src._to_ir(),
0,
)
return MmaFP8Trait(
_cute_nvgpu_ir.make_sm100_mma(
ty,
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
Boolean(False).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
)
class MmaFP8Trait(MmaTrait):
pass
####################################################################################################
#
# SMEM layout atoms
#
####################################################################################################
class SmemLayoutAtomKind(enum.Enum):
"""
Enum class for the kinds of SMEM layout atoms for SM100.
Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can be
used to construct an SMEM layout using blocked product for operand A or B such that the
resulting layout is legal for both TMA and UMMA.
Note that there are other ways of creating legal layouts for operand A and B.
"""
MN_INTER = enum.auto()
MN_SW32 = enum.auto()
MN_SW64 = enum.auto()
MN_SW128 = enum.auto()
MN_SW128_32B = enum.auto()
K_INTER = enum.auto()
K_SW32 = enum.auto()
K_SW64 = enum.auto()
K_SW128 = enum.auto()

View File

@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from .copy import *
from .mma import *
# __all__ is required here for documentation generation
__all__ = [
# mma.py
"MmaF16BF16Op",
# copy.py
"LdMatrix8x8x16bOp",
"LdMatrix16x16x8bOp",
"StMatrix8x8x16bOp",
"StMatrix16x8x8bOp",
]

View File

@@ -0,0 +1,189 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from dataclasses import dataclass
from typing import Type
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
from ..common import OpError
from ...core import CopyOp, Trait, _pack_shape
from ...typing import Numeric
@dataclass(frozen=True)
class BaseOp(CopyOp):
transpose: bool = False
num_matrices: int = 1
def __post_init__(self) -> None:
if not isinstance(self.transpose, bool):
raise OpError(
self,
"expects the 'transpose' Op parameter to be a bool instance",
)
def __str__(self) -> str:
res = (
f"{self.__class__.__name__[:-2]} Copy Operation"
+ f"\n number of matrices = {self.num_matrices}"
)
if self.transpose:
res += f"\n transposed"
return res
@dataclass(frozen=True)
class LdMatrix8x8x16bOp(BaseOp):
"""
8x8 ``ldmatrix`` Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-load-instruction-ldmatrix>`__.
This operation corresponds to the ``.m8n8`` qualifier.
"""
def __post_init__(self) -> None:
super().__post_init__()
if self.num_matrices not in [1, 2, 4]:
raise OpError(
self,
"expects the 'num_matrices' Op parameter to be one of [1,2,4]",
)
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "LdMatrix8x8x16bTrait":
mode = _pack_shape((8, 8), loc=loc, ip=ip)
ty = _cute_nvgpu_ir.CopyAtomLdsmType.get(
copy_internal_type.mlir_type,
mode.type.attribute,
_cute_nvgpu_ir.LdsmSzPattern.u16,
self.num_matrices,
ir.UnitAttr.get() if self.transpose else None,
)
return LdMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class LdMatrix8x8x16bTrait(Trait):
pass
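# Illustrative usage (hedged sketch, not part of the original source): an ldmatrix Op
# loading four 8x8 matrices of 16-bit elements, optionally transposed.
#
#     op = LdMatrix8x8x16bOp(num_matrices=4)
#     op_t = LdMatrix8x8x16bOp(transpose=True, num_matrices=4)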
@dataclass(frozen=True)
class LdMatrix16x16x8bOp(BaseOp):
"""
16x16 8-bit ``ldmatrix`` Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-load-instruction-ldmatrix>`__.
This operation corresponds to the ``.m16n16`` and ``.b8`` qualifiers.
"""
def __init__(self, num_matrices: int) -> None:
super().__init__(transpose=True, num_matrices=num_matrices)
self._verify()
def _verify(self):
assert self.transpose, "transpose must be True"
if self.num_matrices not in [1, 2]:
raise OpError(
self,
"expects the 'num_matrices' Op parameter to be one of [1,2]",
)
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "LdMatrix16x16x8bTrait":
mode = _pack_shape((16, 16), loc=loc, ip=ip)
ty = _cute_nvgpu_ir.CopyAtomLdsmType.get(
copy_internal_type.mlir_type,
mode.type.attribute,
_cute_nvgpu_ir.LdsmSzPattern.u8,
self.num_matrices,
ir.UnitAttr.get(),
)
return LdMatrix16x16x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class LdMatrix16x16x8bTrait(Trait):
pass
@dataclass(frozen=True)
class StMatrix8x8x16bOp(BaseOp):
"""
8x8 ``stmatrix`` Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-stmatrix>`__.
This operation corresponds to the ``.m8n8`` qualifier.
"""
def __post_init__(self) -> None:
super().__post_init__()
if self.num_matrices not in [1, 2, 4]:
raise OpError(
self,
"expects the 'num_matrices' Op parameter to be one of [1,2,4]",
)
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "StMatrix8x8x16bTrait":
mode = _pack_shape((8, 8), loc=loc, ip=ip)
ty = _cute_nvgpu_ir.CopyAtomStsmType.get(
copy_internal_type.mlir_type,
mode.type.attribute,
self.num_matrices,
ir.UnitAttr.get() if self.transpose else None,
)
return StMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class StMatrix8x8x16bTrait(Trait):
pass
@dataclass(frozen=True)
class StMatrix16x8x8bOp(BaseOp):
"""
16x8 ``stmatrix`` Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-stmatrix>`__.
This operation corresponds to the ``.m16n8`` qualifier.
"""
def __init__(self, num_matrices: int) -> None:
super().__init__(transpose=True, num_matrices=num_matrices)
self._verify()
def _verify(self):
assert self.transpose, "transpose must be True"
if self.num_matrices not in [1, 2, 4]:
raise OpError(
self,
"expects the 'num_matrices' Op parameter to be one of [1,2,4]",
)
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "StMatrix16x8x8bTrait":
mode = _pack_shape((16, 8), loc=loc, ip=ip)
ty = _cute_nvgpu_ir.CopyAtomStsmType.get(
copy_internal_type.mlir_type,
mode.type.attribute,
self.num_matrices,
ir.UnitAttr.get(),
)
return StMatrix16x8x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
class StMatrix16x8x8bTrait(Trait):
pass

View File

@@ -0,0 +1,78 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from dataclasses import dataclass
from typing import Type
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from ..common import OpError
from ...core import MmaOp, Trait, _pack_shape
from ...typing import Shape, Float16, BFloat16, Float32, Numeric
@dataclass(frozen=True)
class MmaF16BF16Op(MmaOp):
"""
F16/BF16 warp-level MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma>`__.
This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands.
"""
ab_dtype: Type[Numeric]
acc_dtype: Type[Numeric]
shape_mnk: Shape
def __post_init__(self) -> None:
if self.ab_dtype not in [Float16, BFloat16]:
raise OpError(
self,
"expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16",
)
if self.acc_dtype not in [Float16, Float32]:
raise OpError(
self,
"expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
)
if (self.ab_dtype == BFloat16) and (self.acc_dtype != Float32):
raise OpError(
self,
"expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16",
)
if self.shape_mnk not in [(16, 8, 8), (16, 8, 16)]:
raise OpError(
self,
"expects the 'shape_mnk' Op parameter to be one of (16,8,8) or (16,8,16)",
)
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM80Type.get(
shape_mnk.type.attribute,
self.ab_dtype.mlir_type,
self.ab_dtype.mlir_type,
self.acc_dtype.mlir_type,
)
return MmaF16BF16Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
def __str__(self) -> str:
return (
"warp-level F16/BF16 MMA Operation"
+ f"\n A/B data type = {self.ab_dtype}"
+ f"\n Accumulator data type = {self.acc_dtype}"
+ f"\n Instruction shape MNK = {self.shape_mnk}"
)
class MmaF16BF16Trait(Trait):
pass

View File

@@ -0,0 +1,29 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from .mma import *
from .helpers import *
# __all__ is required here for documentation generation
__all__ = [
# mma.py
"OperandMajorMode",
"OperandSource",
"Field",
"MmaF16BF16Op",
"MmaF8Op",
"SmemLayoutAtomKind",
# helpers.py
"make_smem_layout_atom",
"fence",
"commit_group",
"wait_group",
]

View File

@@ -0,0 +1,109 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Type
from cutlass.cutlass_dsl import dsl_user_op
from cutlass._mlir.dialects import nvvm
from ...typing import Numeric, NumericMeta
from ... import core
from .mma import SmemLayoutAtomKind
@dsl_user_op
def make_smem_layout_atom(
kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None
) -> core.ComposedLayout:
"""
Makes a SMEM layout Atom.
This function creates a composed layout, in units of elements, consistent with the requested layout
Atom kind and element data type.
:param kind: The kind of layout Atom
:type kind: SmemLayoutAtomKind
:param element_type: The element data type to construct the layout for
:type element_type: Type[Numeric]
:return: The SMEM layout atom
:rtype: core.ComposedLayout
"""
if not isinstance(element_type, NumericMeta):
raise TypeError(f"element_type must be a Numeric, but got {element_type}")
if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER):
num_contiguous_bits = 128
sw = core.make_swizzle(0, 4, 3)
elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32):
num_contiguous_bits = 256
sw = core.make_swizzle(1, 4, 3)
elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64):
num_contiguous_bits = 512
sw = core.make_swizzle(2, 4, 3)
elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128):
num_contiguous_bits = 1024
sw = core.make_swizzle(3, 4, 3)
else:
raise ValueError("unrecognized SMEM layout atom kind")
num_contiguous_elems = num_contiguous_bits // element_type.width
if kind in (
SmemLayoutAtomKind.MN_INTER,
SmemLayoutAtomKind.MN_SW32,
SmemLayoutAtomKind.MN_SW64,
SmemLayoutAtomKind.MN_SW128,
):
# M/N-major layout
return core.make_composed_layout(
sw,
0,
core.make_layout(
(num_contiguous_elems, 8), stride=(1, num_contiguous_elems)
),
loc=loc,
ip=ip,
)
else:
# K-major layout
return core.make_composed_layout(
sw,
0,
core.make_layout(
(8, num_contiguous_elems), stride=(num_contiguous_elems, 1)
),
loc=loc,
ip=ip,
)
@dsl_user_op
def fence(*, loc=None, ip=None) -> None:
"""
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-fence>`__.
"""
nvvm.wgmma_fence_aligned(loc=loc, ip=ip)
@dsl_user_op
def commit_group(*, loc=None, ip=None) -> None:
"""
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group>`__.
"""
nvvm.wgmma_commit_group_sync_aligned(loc=loc, ip=ip)
@dsl_user_op
def wait_group(group, *, loc=None, ip=None) -> None:
"""
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-wait-group>`__.
"""
nvvm.wgmma_wait_group_sync_aligned(group, loc=loc, ip=ip)
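# Illustrative usage (hedged sketch, not part of the original source): the usual wgmma
# synchronization pattern around a batch of warpgroup MMA instructions.
#
#     fence()          # make prior register/SMEM writes visible to wgmma
#     ...              # issue the wgmma MMA instructions
#     commit_group()   # close the current wgmma group
#     wait_group(0)    # wait until all committed groups have completed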

View File

@@ -0,0 +1,380 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import enum
from dataclasses import dataclass
from typing import Type
from cutlass.cutlass_dsl import CuTeDSL
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
from ..common import OpError
from ...core import MmaOp, Trait, _pack_shape, rank, depth
from ...typing import (
Shape,
Float16,
BFloat16,
Float32,
Boolean,
Float8E5M2,
Float8E4M3FN,
Numeric,
)
####################################################################################################
#
# MMA Ops and Traits
#
####################################################################################################
class OperandMajorMode(enum.Enum):
"""
An enumeration for the majorness of the input operands of the MMA.
"""
MN = _cute_ir.MajorMode.mn
K = _cute_ir.MajorMode.k
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
@classmethod
def _missing_(cls, value):
if isinstance(value, str):
value = value.upper()
if value == "MN":
return OperandMajorMode.MN
elif value == "K":
return OperandMajorMode.K
def _to_ir(self) -> _cute_ir.MajorMode:
return self.value
class OperandSource(enum.Enum):
"""
An enumeration for the source memory location of the A input operand of the MMA.
"""
RMEM = _cute_ir.MmaFragKind.rmem
SMEM = _cute_ir.MmaFragKind.smem_desc
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
def _to_ir(self) -> _cute_ir.MmaFragKind:
return self.value
class Field(enum.Enum):
"""
An enumeration for the fields of the MMA Atom that can be modified at runtime.
"""
ACCUMULATE = "accum_c"
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
def _to_ir_field_name(self) -> str:
return self.value
@dataclass(frozen=True)
class MmaOp(MmaOp):
a_dtype: Type[Numeric]
b_dtype: Type[Numeric]
acc_dtype: Type[Numeric]
shape_mnk: Shape
a_src: OperandSource
a_major_mode: OperandMajorMode
b_major_mode: OperandMajorMode
admissible_archs = ["sm_90a"]
def __post_init__(self) -> None:
# Verify arch
arch = CuTeDSL._get_dsl().envar.arch
if arch not in self.admissible_archs:
raise OpError(
self,
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
# Verify that the user provided enum values
if not isinstance(self.a_src, OperandSource):
raise OpError(
self,
"expects the 'a_src' Op parameter to be a warpgroup.OperandSource instance",
)
if not isinstance(self.a_major_mode, OperandMajorMode):
raise OpError(
self,
"expects the 'a_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance",
)
if not isinstance(self.b_major_mode, OperandMajorMode):
raise OpError(
self,
"expects the 'b_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance",
)
# Verify instruction shape
if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1):
raise OpError(
self,
f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, "
f"but got {self.shape_mnk}",
)
m, n = self.shape_mnk[0], self.shape_mnk[1]
if m != 64:
raise OpError(self, f"expects the M-mode to be 64, but got {m}")
if (n < 8) or (n > 256) or (n % 8 != 0):
raise OpError(
self,
f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0. but got {n}",
)
def __str__(self) -> str:
return (
self.__class__.descriptive_name # type: ignore
+ f"\n A data type = {self.a_dtype}"
+ f"\n B data type = {self.b_dtype}"
+ f"\n Accumulator data type = {self.acc_dtype}"
+ f"\n A source location = {self.a_src}"
+ f"\n A major mode = {self.a_major_mode}"
+ f"\n B major mode = {self.b_major_mode}"
+ f"\n Instruction shape MNK = {self.shape_mnk}"
)
class MmaTrait(Trait):
admissible_fields = [Field.ACCUMULATE]
def set(self, field, value, *, loc=None, ip=None) -> None:
if field not in self.admissible_fields:
raise ValueError(
f"invalid field, must be {Field.ACCUMULATE}, but got {field}"
)
field_name = f"#cute_nvgpu.atom_mma_field_sm90<{field._to_ir_field_name()}>"
attr = ir.Attribute.parse(field_name)
self.value = _cute_nvgpu_ir.atom_set_value(
self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
)
@dataclass(frozen=True)
class MmaF16BF16Op(MmaOp):
"""
F16/BF16 warpgroup MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-mma-async>`__.
This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands.
"""
descriptive_name = "warpgroup F16/BF16 MMA Operation"
def __init__(
self,
ab_dtype: Type[Numeric],
acc_dtype: Type[Numeric],
instruction_shape: Shape,
a_src: OperandSource,
a_major_mode: OperandMajorMode,
b_major_mode: OperandMajorMode,
) -> None:
super().__init__(
ab_dtype,
ab_dtype,
acc_dtype,
instruction_shape,
a_src,
a_major_mode,
b_major_mode,
)
self._verify()
def _verify(self) -> None:
# Input data type verification
if self.a_dtype not in [Float16, BFloat16]:
raise OpError(
self,
"expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16",
)
assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
# Accumulator data type verification
if self.acc_dtype not in [Float16, Float32]:
raise OpError(
self,
"expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
)
if (self.a_dtype == BFloat16) and (self.acc_dtype != Float32):
raise OpError(
self,
"expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16",
)
# Verify the instruction shape
instruction_k = 16
if rank(self.shape_mnk) == 2:
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
if self.shape_mnk[2] != instruction_k:
raise OpError(
self,
f"expects the instruction extent in the K-mode to be {instruction_k}, "
f"but got {self.shape_mnk[2]}",
)
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM90Type.get(
shape_mnk.type.attribute,
self.a_major_mode._to_ir(),
self.b_major_mode._to_ir(),
self.a_dtype.mlir_type,
self.b_dtype.mlir_type,
self.acc_dtype.mlir_type,
self.a_src._to_ir(),
)
return MmaF16BF16Trait(
_cute_nvgpu_ir.make_sm90_mma(
ty,
Boolean(False).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
)
class MmaF16BF16Trait(MmaTrait):
pass
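# Illustrative usage (hedged sketch, not part of the original source): a 64x128x16 BF16
# wgmma Op accumulating in Float32, with A sourced from SMEM and both operands K-major.
# Requires CUTE_DSL_ARCH=sm_90a.
#
#     op = MmaF16BF16Op(
#         BFloat16,
#         Float32,
#         (64, 128, 16),
#         OperandSource.SMEM,
#         OperandMajorMode.K,
#         OperandMajorMode.K,
#     )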
@dataclass(frozen=True)
class MmaF8Op(MmaOp):
"""
F8 warpgroup MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-mma-async>`__.
This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands.
"""
descriptive_name = "warpgroup F8 MMA Operation"
def __init__(
self,
a_dtype: Type[Numeric],
b_dtype: Type[Numeric],
acc_dtype: Type[Numeric],
instruction_shape: Shape,
a_src: OperandSource,
a_major_mode: OperandMajorMode,
b_major_mode: OperandMajorMode,
) -> None:
super().__init__(
a_dtype,
b_dtype,
acc_dtype,
instruction_shape,
a_src,
a_major_mode,
b_major_mode,
)
self._verify()
def _verify(self):
# Input data type verification
if self.a_dtype not in [Float8E5M2, Float8E4M3FN]:
raise OpError(
self,
"expects the 'a_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
)
if self.b_dtype not in [Float8E5M2, Float8E4M3FN]:
raise OpError(
self,
"expects the 'b_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
)
# Accumulator data type verification
if self.acc_dtype != Float32:
raise OpError(
self,
"expects the 'acc_dtype' Op parameter to be Float32",
)
# Verify the instruction shape
instruction_k = 32
if rank(self.shape_mnk) == 2:
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
if self.shape_mnk[2] != instruction_k:
raise OpError(
self,
f"expects the instruction extent in the K-mode to be {instruction_k}, "
f"but got {self.shape_mnk[2]}",
)
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF8Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM90Type.get(
shape_mnk.type.attribute,
self.a_major_mode._to_ir(),
self.b_major_mode._to_ir(),
self.a_dtype.mlir_type,
self.b_dtype.mlir_type,
self.acc_dtype.mlir_type,
self.a_src._to_ir(),
)
return MmaF8Trait(
_cute_nvgpu_ir.make_sm90_mma(
ty, Boolean(False).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
)
)
class MmaF8Trait(MmaTrait):
pass
####################################################################################################
#
# SMEM layout atoms
#
####################################################################################################
class SmemLayoutAtomKind(enum.Enum):
"""
Enum class for the kinds of SMEM layout atoms for SM90.
Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can
be used to construct an SMEM layout using blocked product for operand A or B such that the
resulting layout is legal for both TMA and WGMMA.
Note that there are other ways of creating legal layouts for operand A and B.
"""
MN_INTER = enum.auto()
MN_SW32 = enum.auto()
MN_SW64 = enum.auto()
MN_SW128 = enum.auto()
K_INTER = enum.auto()
K_SW32 = enum.auto()
K_SW64 = enum.auto()
K_SW128 = enum.auto()

View File

@@ -0,0 +1,515 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import ctypes
from functools import lru_cache
import itertools
import operator
from time import time
from typing import Union
# MLIR modules imports
from cutlass._mlir import ir
import cutlass._mlir.dialects.cute as _cute_ir
from cutlass.cutlass_dsl import TensorFormat, JitArgAdapterRegistry
# Local modules imports
from .typing import (
AddressSpace,
Tensor,
Type,
Pointer,
Boolean,
Numeric,
Float4E2M1FN,
Int64,
Int32,
Int16,
Int8,
Uint64,
Uint32,
Uint16,
Uint8,
Float64,
Float32,
Float16,
BFloat16,
Float8E5M2,
)
from .core import find, _Tensor as CoreTensor
class _Pointer(Pointer):
"""Runtime representation of a pointer that can inter-operate with various data structures,
including numpy arrays and device memory.
:param pointer: The pointer to the data
:type pointer: int or pointer-like object
:param dtype: Data type of the elements pointed to
:type dtype: Type
:param mem_space: Memory space where the pointer resides, defaults to generic
:type mem_space: _cute_ir.AddressSpace, optional
:param assumed_align: Assumed alignment of input pointer in bytes, defaults to None
:type assumed_align: int, optional
:ivar _pointer: The underlying pointer
:ivar _dtype: Data type of the elements
:ivar _addr_space: Memory space of the pointer
:ivar _assumed_align: Alignment of the pointer in bytes
:ivar _desc: C-type descriptor for the pointer
:ivar _c_pointer: C-compatible pointer representation
"""
def __init__(
self,
pointer,
dtype,
mem_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.generic,
assumed_align=None,
):
self._pointer = pointer
self._dtype = dtype
self._addr_space = mem_space
is_in_device = mem_space == _cute_ir.AddressSpace.gmem
if assumed_align is None:
if is_in_device:
self._assumed_align = 32
else:
self._assumed_align = dtype.width // 8
else:
self._assumed_align = assumed_align
class PtrDescriptor(ctypes.Structure):
"""A ctype descriptor for CuTe memref ptr"""
_fields_ = [("ptr", ctypes.c_void_p)]
def __str__(self):
return f"0x{self.ptr:016x}"
self._desc = PtrDescriptor(int(self._pointer))
self._c_pointer = ctypes.cast(ctypes.pointer(self._desc), ctypes.c_void_p)
assert (
self._desc.ptr % self._assumed_align == 0
), f"pointer must be {self._assumed_align} bytes aligned"
def size_in_bytes(self) -> int:
return ctypes.sizeof(self._desc)
def __get_mlir_types__(self):
return [self.mlir_type]
def __c_pointers__(self):
return [self._c_pointer]
def __new_from_mlir_values__(self, values):
assert len(values) == 1
return values[0]
# Move mlir Type out of __init__ to decouple with mlir Context
@property
def mlir_type(self) -> ir.Type:
return _cute_ir.PtrType.get(
self._dtype.mlir_type, self._addr_space, self._assumed_align
)
@property
def element_type(self) -> Type[Numeric]:
return self._dtype
@property
def memspace(self):
return self._addr_space
def verify(self, expected_py_type):
if expected_py_type is Pointer:
return True
elif isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer:
return True
return False
def __str__(self) -> str:
return f"Ptr<0x{self._desc.ptr:016x}@{self._addr_space}>"
def __repr__(self):
return self.__str__()
class _Tensor(Tensor):
def __init__(
self,
tensor,
assumed_align=None,
):
# If tensor is already a DLPack object, use it directly
if hasattr(tensor, "__dlpack_device__") and not hasattr(tensor, "__dlpack__"):
self._dlpack_data = tensor
else:
self._dlpack_data = tensor.__dlpack__()
self._dltensor_wrapper = None
self._assumed_align = assumed_align
self._is_dynamic = False
self._memref_desc = None
self._dtype = None
@property
def __class__(self) -> Type[Tensor]:
# Cheat to let `type(_Tensor())` return cute.Tensor
return Tensor
@staticmethod
def lazily_load_dltensor(func):
"""Decorator to lazily load the DLTensorWrapper.
This decorator loads the DLTensorWrapper when needed,
avoiding overhead in the critical path of calling JIT functions.
"""
def wrapper(self, *args, **kwargs):
if self._dltensor_wrapper is None:
self._dltensor_wrapper = _cute_ir.DLTensorWrapper(self._dlpack_data)
return func(self, *args, **kwargs)
return wrapper
@lazily_load_dltensor
def mark_layout_dynamic(self, leading_dim: int | None = None):
"""Marks the tensor layout as dynamic based on the leading dimension.
:param leading_dim: The leading dimension of the layout, defaults to None
:type leading_dim: int, optional
When ``leading_dim`` is None, automatically deduces the leading dimension from the tensor layout.
The leading dimension can only be deduced when exactly one dimension has a stride of 1. Raises an error
if it cannot be deduced automatically.
When ``leading_dim`` is explicitly specified, marks the layout as dynamic while setting the
stride at ``leading_dim`` to 1. Also validates that the specified ``leading_dim`` is consistent
with the existing layout by checking that the corresponding stride of that dimension is 1.
Limitation: only flat layouts are supported for now; support for nested layouts is planned.
:return: The tensor with dynamic layout
:rtype: _Tensor
"""
self._dltensor_wrapper.mark_layout_dynamic(leading_dim)
return self
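# Illustrative usage (hedged sketch, not part of the original source): marking the
# leading dimension of a row-major torch tensor as dynamic before passing it to a
# JIT-compiled kernel.
#
#     x = torch.randn(1024, 512)                            # strides (512, 1)
#     t = from_dlpack(x).mark_layout_dynamic(leading_dim=1)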
@lazily_load_dltensor
def mark_compact_shape_dynamic(
self,
mode: int,
stride_order: tuple[int, ...] | None = None,
divisibility: int = 1,
):
"""Marks the tensor shape as dynamic and propagates dynamic and divisibility information to the corresponding strides.
:param mode: The mode (dimension) of the compact shape to mark as dynamic
:type mode: int
:param stride_order: Consistent with `torch.Tensor.dim_order`. Defaults to None.
Indicates the order of the modes (dimensions) if the current layout were converted to row-major order.
It starts from the outermost to the innermost dimension.
:type stride_order: tuple[int, ...], optional
:param divisibility: The divisibility constraint for the compact shape, defaults to 1
:type divisibility: int, optional
:return: The tensor with dynamic compact shape
:rtype: _Tensor
If ``stride_order`` is not provided, the stride ordering will be automatically deduced from the layout.
Automatic deduction is only possible when exactly one dimension has a stride of 1 (compact layout).
An error is raised if automatic deduction fails.
If ``stride_order`` is explicitly specified, it is checked for consistency with the layout.
For example:
- Layout: (4,2):(1,4) has stride_order (1,0), indicating the innermost dimension is 0 (`4:1`) and the outermost dimension is 1 (`2:4`)
- Layout: (5,3,2,4):(3,1,15,30) has stride_order (3,2,0,1), indicating the innermost dimension is 1 (`3:1`) and the outermost dimension is 3 (`4:30`).
Use `torch.Tensor.dim_order()` to get the stride order of a torch tensor.
.. code-block:: python
a = torch.empty(3, 4)
t = cute.runtime.from_dlpack(a)
t = t.mark_compact_shape_dynamic(mode=0, stride_order=a.dim_order())
"""
self._dltensor_wrapper.mark_compact_shape_dynamic(
mode, stride_order, divisibility
)
return self
@property
@lazily_load_dltensor
def element_type(self) -> Type[Numeric]:
if self._dtype is None:
self._dtype = self._dltensor_wrapper.dtype
return self._dtype
@element_type.setter
def element_type(self, new_type):
"""Set the element type of the tensor.
:warning: This API is added for narrow precision before we have a clean `recast_tensor` story.
:note: It is only intended for cases where a framework does not natively support narrow precision and the
tensor is provided with a storage type such as uint8.
**Example**:
.. code-block:: python
# Create a tensor from a numpy array
import numpy as np
from cutlass.cute import from_dlpack
# Create a tensor with Float32 elements
a = np.zeros(shape, dtype=np.uint8)
tensor = from_dlpack(a)
# Change the element type to Float4E2M1FN even storage type is uint8
tensor.element_type = cutlass.Float4E2M1FN
src = from_dlpack(... data tensor ...)
# convert and initialize narrow precision tensor
cute.testing.convert(src, tensor)
"""
self._dtype = new_type
@property
@lazily_load_dltensor
def memspace(self):
return self._dltensor_wrapper.address_space
@property
@lazily_load_dltensor
def size_in_bytes(self) -> int:
return self._dltensor_wrapper.size_in_bytes()
@property
@lazily_load_dltensor
def mlir_type(self) -> ir.Type:
return self._dltensor_wrapper.get_type(
self.element_type.mlir_type, self._assumed_align
)
@lazily_load_dltensor
def __str__(self) -> str:
return f"Tensor<0x{self._dltensor_wrapper.str}>"
def __repr__(self):
return self.__str__()
def __setitem__(self, crd, value):
raise TypeError(f"runtime._Tensor is not indexable")
def __getitem__(self, crd):
raise TypeError(f"runtime._Tensor is not indexable")
@property
@lazily_load_dltensor
def iterator(self):
return _Pointer(
self._dltensor_wrapper.data_ptr,
self.element_type,
self.memspace,
self._assumed_align,
)
@property
def layout(self):
raise NotImplementedError(
f"layout property is not supported in runtime, support in future"
)
@property
@lazily_load_dltensor
def shape(self):
return self._dltensor_wrapper.shape
@property
@lazily_load_dltensor
def stride(self):
strides = self._dltensor_wrapper.stride
if strides is None:
strides = itertools.accumulate(
reversed(self.shape), func=operator.mul, initial=1
)
strides = tuple(reversed(list(strides)[:-1]))
return strides
@property
@lru_cache(maxsize=128, typed=True)
def leading_dim(self):
"""Get the leading dimension of this Tensor.
:return: The leading dimension index or indices
:rtype: int or tuple or None
The return value depends on the tensor's stride pattern:
* If a single leading dimension is found, returns an integer index
* If nested leading dimensions are found, returns a tuple of indices
* If no leading dimension is found, returns None
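For example, a row-major tensor with shape ``(100, 100)`` and stride ``(100, 1)`` has
``leading_dim == 1``, while a column-major one with stride ``(1, 100)`` has ``leading_dim == 0``.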
"""
return find(1, self.stride, exclude_when=(1, self.shape))
def fill(self, value: Numeric):
raise TypeError(f"fill function is not supported in runtime")
@property
@lazily_load_dltensor
def data_ptr(self):
return self._dltensor_wrapper.data_ptr
@lazily_load_dltensor
def __c_pointers__(self):
self._memref_desc = self._dltensor_wrapper.build_memref_desc(
self._assumed_align
)
return [_cute_ir.pycapsule_get_pointer(self._memref_desc)]
def __get_mlir_types__(self):
return [self.mlir_type]
def __new_from_mlir_values__(self, values):
assert len(values) == 1
assert isinstance(values[0], CoreTensor)
return CoreTensor(values[0].value, self._dtype)
def from_dlpack(
tensor_dlpack,
assumed_align=None,
) -> Tensor:
"""Convert from tensor object supporting __dlpack__() to a CuTe Tensor.
:param tensor_dlpack: Tensor object that supports the DLPack protocol
:type tensor_dlpack: object
:param assumed_align: Assumed alignment of the tensor in bytes, defaults to None;
if None, the element size in bytes is used as the assumed alignment.
:type assumed_align: int, optional
:return: A CuTe Tensor object
:rtype: Tensor
Examples:
.. code-block:: python
import torch
from cutlass.cute.runtime import from_dlpack
x = torch.randn(100, 100)
y = from_dlpack(x)
y.shape
# (100, 100)
type(y)
# <class 'cutlass.cute.Tensor'>
"""
return _Tensor(
tensor_dlpack,
assumed_align=assumed_align,
)
def make_ptr(
dtype: Type[Numeric],
value: Union[int, ctypes._Pointer],
mem_space: AddressSpace = AddressSpace.generic,
assumed_align=None,
) -> Pointer:
"""Create a pointer from a memory address
:param dtype: Data type of the pointer elements
:type dtype: Type[Numeric]
:param value: Memory address as integer or ctypes pointer
:type value: Union[int, ctypes._Pointer]
:param mem_space: Memory address space, defaults to AddressSpace.generic
:type mem_space: AddressSpace, optional
:param assumed_align: Assumed alignment in bytes, defaults to None
:type assumed_align: int, optional
:return: A pointer object
:rtype: Pointer
.. code-block:: python
import numpy as np
import ctypes
from cutlass import Float32
from cutlass.cute.runtime import make_ptr
# Create a numpy array
a = np.random.randn(16, 32).astype(np.float32)
# Get a ctypes pointer to the tensor data
ptr_address = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
# Create pointer from the ctypes pointer
y = make_ptr(Float32, ptr_address)
# Check properties
print(y.element_type)
print(type(y)) # <class 'cutlass.cute.Pointer'>
"""
# check if value is int or ctypes.POINTER
if isinstance(value, int):
address_value = value
elif isinstance(value, ctypes._Pointer):
# get address value
address_value = ctypes.cast(value, ctypes.c_void_p).value
assert address_value is not None, "Pointer address is None"
else:
raise TypeError(
f"Expect int or ctypes.POINTER for value but got {type(value)=}"
)
return _Pointer(address_value, dtype, mem_space, assumed_align=assumed_align)
class TensorAdapter:
"""
Convert a tensor/array that supports the DLPack protocol to a CuTe tensor.
"""
# Keep references to these capsules so they are not garbage collected
tensor_capsules = []
def __init__(self, arg):
self._arg = from_dlpack(arg).mark_layout_dynamic()
self.tensor_capsules.append(self._arg)
def __new_from_mlir_values__(self, values):
return self._arg.__new_from_mlir_values__(values)
def __c_pointers__(self):
return self._arg.__c_pointers__()
def __get_mlir_types__(self):
return self._arg.__get_mlir_types__()
# -------------------------------------------------------------------------
# Try to register_jit_arg_adapter for TensorAdapter
# -------------------------------------------------------------------------
try: # Register for numpy.ndarray
import numpy
JitArgAdapterRegistry.register_jit_arg_adapter(numpy.ndarray)(TensorAdapter)
except ImportError:
pass # silent attempt, suppress error
try: # Register for torch.Tensor
import torch
JitArgAdapterRegistry.register_jit_arg_adapter(torch.Tensor)(TensorAdapter)
except ImportError:
pass # silent attempt, suppress error

View File

@@ -0,0 +1,285 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import random
import numpy as np
import functools
import hashlib
from cutlass.cutlass_dsl import (
const,
T,
CuTeDSL,
BaseDSL,
t,
Constexpr,
detect_gpu_arch,
)
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.ir as ir
from cutlass._mlir.dialects import nvvm, cf, vector, builtin
from cutlass.cute import core
from cutlass.cute import nvgpu
from typing import Type
from inspect import isclass
def assert_(cond, msg=None):
if isinstance(cond, ir.Value):
if ir.VectorType.isinstance(cond.type):
assert (
cond.type.element_type == T.bool()
), f"only expects vector type with boolean elements, but got {cond.type}"
cond_val = vector.multi_reduction(
vector.CombiningKind.AND, cond, const(True), range(cond.type.rank)
)
else:
cond_val = cond
else:
cond_val = const(cond, t.Boolean)
cf.assert_(cond_val, msg if msg else "")
def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout):
if src.element_type.width == 4:
tv_layout = core.recast_layout(8, 4, tv_layout)
src = core.recast_tensor(src, dtype=t.Int8)
return src, tv_layout
def _maybe_recast_to_f4(input: core.TensorSSA, dtype: Type[core.Numeric]):
"""Conditionally recasts the tensor to 4-bit type if the destination type is 4-bit.
:param input: The input tensor to recast.
:param dtype: The target numeric type to potentially recast to.
:raises TypeError: If dtype is not a subclass of Numeric.
:return: A new tensor recast to 4-bit if dtype is 4-bit, otherwise returns self unchanged.
"""
if not isclass(dtype) or not issubclass(dtype, core.Numeric):
raise TypeError(f"dst_ty must be a type of Numeric, but got {dtype}")
if dtype.width == 4:
recast_shape = core.recast_layout(4, 8, core.make_layout(input.shape)).shape
i4_vec = vector.bitcast(
T.vector(input.type.shape[0] * 2, T.i(4)), input.maybe_downcast()
)
res_vect = builtin.unrealized_conversion_cast(
[T.vector(i4_vec.type.shape[0], dtype.mlir_type)], [i4_vec]
)
return core.TensorSSA(res_vect, recast_shape, dtype)
return input
def _maybe_recast_from_f4(input: core.TensorSSA, src_dtype: Type[core.Numeric]):
"""Conditionally recasts the tensor from 4-bit type if the source type is 4-bit.
:param input: The input tensor to recast.
:param src_dtype: The source numeric type to potentially recast from.
:raises TypeError: If src_dtype is not a subclass of Numeric.
:return: A new tensor recast from 4-bit if src_dtype is 4-bit, otherwise returns self unchanged.
"""
if not isclass(src_dtype) or not issubclass(src_dtype, core.Numeric):
raise TypeError(f"src_ty must be a type of Numeric, but got {src_dtype}")
if src_dtype.width == 4:
recast_shape = core.recast_layout(8, 4, core.make_layout(input.shape)).shape
i4_vec = builtin.unrealized_conversion_cast(
[T.vector(input.type.shape[0], T.i(4))], [input.maybe_downcast()]
)
res_vect = vector.bitcast(T.vector(i4_vec.type.shape[0] // 2, T.i8()), i4_vec)
return core.TensorSSA(res_vect, recast_shape, core.Int8)
return input
@CuTeDSL.kernel
def _convert_kernel(
gSrc: core.Tensor,
gDst: core.Tensor,
cSrc: core.Tensor,
src_tv_layout: core.Layout,
dst_tv_layout: core.Layout,
src_shape: core.Shape,
src_ty,
dst_ty,
):
tidx = nvvm.read_ptx_sreg_tid_x(T.i32())
bidx = nvvm.read_ptx_sreg_ctaid_x(T.i32())
cta_coord = (None, bidx)
# logical idx -> address
ctaSrc = gSrc[cta_coord] # (...,TileV,...)
ctaDst = gDst[cta_coord] # (...,TileV,...)
ctaCSrc = cSrc[cta_coord] # (...,TileV,...)
# print(f"ctaSrc = {ctaSrc.type}")
# compose with CTA TV layout
# tid, vid -> address
tidfrgSrc = core.composition(ctaSrc, src_tv_layout) # (T,V)
tidfrgDst = core.composition(ctaDst, dst_tv_layout) # (T,V)
tidfrgCSrc = core.composition(ctaCSrc, src_tv_layout) # (T,V)
# print(f"tidfrgSrc = {tidfrgSrc.type}")
# slice for threads
thr_coord = (tidx, None)
thrSrc = tidfrgSrc[thr_coord] # (V)
thrDst = tidfrgDst[thr_coord] # (V)
thrCSrc = tidfrgCSrc[thr_coord] # (V)
# print(f"thrSrc = {thrSrc.type}")
# predicate
if core.elem_less(thrCSrc[0], src_shape):
# allocate fragments for gmem->rmem
frgSrc = core.make_fragment(
core.get(src_tv_layout, mode=[1]), gSrc.element_type
) # (V)
frgDst = core.make_fragment(
core.get(dst_tv_layout, mode=[1]), gDst.element_type
) # (V)
# print(f"frgSrc = {frgSrc.type}")
# Move data to reg address space
copy_atom_load = core.make_copy_atom(nvgpu.CopyUniversalOp(), gSrc.element_type)
core.copy(copy_atom_load, thrSrc, frgSrc)
vec_src = frgSrc.load()
vec_src = _maybe_recast_to_f4(vec_src, src_ty)
vec_dst = vec_src.to(dst_ty)
vec_dst = _maybe_recast_from_f4(vec_dst, dst_ty)
frgDst.store(vec_dst)
# Copy the results back to c
copy_atom_stg = core.make_copy_atom(nvgpu.CopyUniversalOp(), gDst.element_type)
core.copy(copy_atom_stg, frgDst, thrDst)
@CuTeDSL.jit(preprocess=False)
def _convert(
src: core.Tensor,
dst: core.Tensor,
leading_mode: Constexpr,
elem_per_copy: Constexpr,
):
# Step 1. figure proper tv_layout
src_ty = src.element_type
dst_ty = dst.element_type
tv_layout = core.make_layout((128, elem_per_copy), stride=(elem_per_copy, 1))
# Step 2. maybe recast from f4 tensor
src, src_tv_layout = _maybe_recast_tensor_from_f4(src, tv_layout)
dst, dst_tv_layout = _maybe_recast_tensor_from_f4(dst, tv_layout)
src_shape = src.shape
# predicate tensor
idA = core.make_identity_tensor(src.shape)
# Step 3. select a proper tiling pattern as (...,TileV, ...)
src_cta_tiler = [
1,
] * core.rank(src.layout)
src_cta_tiler[leading_mode] = core.size(src_tv_layout) # (...,TileV,...)
dst_cta_tiler = [
1,
] * core.rank(dst.layout)
dst_cta_tiler[leading_mode] = core.size(dst_tv_layout) # (...,TileV,...)
# Step 4. partition input and output tensor by cta tiler.
gS = core.zipped_divide(
src, tuple(src_cta_tiler)
) # ((...,TileV,...),(...,RestV,...))
cS = core.zipped_divide(
idA, tuple(src_cta_tiler)
) # ((...,TileV,...),(...,RestV,...))
gD = core.zipped_divide(
dst, tuple(dst_cta_tiler)
) # ((...,TileV,...),(...,RestV,...))
# print(f"{gS.type=}")
_convert_kernel(
gS,
gD,
cS,
src_tv_layout,
dst_tv_layout,
src_shape,
src_ty,
dst_ty,
).launch(
grid=[core.size(gS, mode=[1]), 1, 1],
block=[core.size(src_tv_layout, mode=[0]), 1, 1],
)
# Converts the src tensor to the dst tensor; their logical shapes are required to be the same.
# When the src or dst dtype is a narrow-precision type (Float4E2M1FN/Float8E8M0FNU/Float8E4M3FN), the
# leading dimension must be aligned to 4 (fp8) / 8 (fp4) elements, because nvgpu.cvt_fptrunc/cvt_fpext
# need 32-bit aligned input/output.
def convert(src: core.Tensor, dst: core.Tensor):
assert len(src.shape) == len(
dst.shape
), "Shape of src and dst tensors should be the same rank."
# find leading mode
leading_mode = np.argmin([np.min(s) for s in src.stride])
elem_per_copy = 2
if src.element_type.width == 4 or dst.element_type.width == 4:
elem_per_copy = 8
elif src.element_type.width == 8 or dst.element_type.width == 8:
elem_per_copy = 4
assert (
src.shape[leading_mode] % elem_per_copy == 0
and dst.shape[leading_mode] % elem_per_copy == 0
)
_convert(src, dst, leading_mode, elem_per_copy)
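# A minimal usage sketch (hypothetical shapes; assumes CUDA torch tensors wrapped with
# cutlass.cute.runtime.from_dlpack):
#   src = from_dlpack(torch.randn(128, 64, device="cuda"))
#   dst = from_dlpack(torch.empty(128, 64, device="cuda", dtype=torch.float16))
#   convert(src, dst)  # element-wise copy with dtype conversion on the GPU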
#########################################
# Testing utilities
#########################################
def sample_pytest(rand_cfg=None):
"""
Decorator to randomly sample pytest parametrized tests.
rand_cfg: Tuple[int, float] - (random_seed, sample_ratio)
Sampling is disabled when:
- A specific test is selected (via -k or direct test path)
- Not running under pytest
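A minimal usage sketch (hypothetical test, assuming a pytest parametrization)::

    @sample_pytest(rand_cfg=(0, 0.25))
    @pytest.mark.parametrize("n", range(100))
    def test_example(n):
        assert n >= 0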
"""
import functools
import os
import random
import pytest
import sys
if rand_cfg is not None:
    seed, sample_ratio = rand_cfg
    random.seed(seed)
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if rand_cfg is not None and "PYTEST_CURRENT_TEST" in os.environ:
# Check if test was explicitly selected like ::test_name[param1-param2-...]
if "-k" in sys.argv or any(".py::" in arg for arg in sys.argv):
# Test was explicitly selected, don't skip
return func(*args, **kwargs)
if random.uniform(0.0, 1.0) > sample_ratio:
pytest.skip(f"Randomly skipped (sampling ratio: {sample_ratio})")
return func(*args, **kwargs)
return wrapper
return decorator

View File

@@ -0,0 +1,193 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from abc import ABC, abstractmethod
from typing import ForwardRef, Tuple, Union, Any, Type, List
from cutlass.base_dsl.typing import *
from cutlass._mlir import ir
import cutlass._mlir.extras.types as T
from cutlass._mlir.dialects.cute import AddressSpace
Int = Union[int, Integer]
ScaledBasis = ForwardRef("ScaledBasis")
IntTuple = Union[Int, Tuple["IntTuple", ...]]
Shape = Union[Int, Tuple["Shape", ...]]
Stride = Union[Int, ScaledBasis, Tuple["Stride", ...]]
Coord = Union[Int, None, Tuple["Coord", ...]]
class Layout(ir.Value):
def __init__(self, op_result):
super().__init__(op_result)
def __str__(self): ...
def get_hier_coord(self, idx) -> Coord:
"""Return the (hierarchical) ND logical coordinate corresponding to the linear index"""
...
@property
def shape(self, *, loc=None, ip=None) -> Shape: ...
@property
def stride(self, *, loc=None, ip=None) -> Stride: ...
Tile = Union[Int, None, Layout, Tuple["Tile", ...]]
# XTuple is super set of above types
XTuple = Union[IntTuple, Shape, Stride, Coord, Tile]
Tiler = Union[Shape, Layout, Tile]
class Pointer:
"""
Abstract base class for CuTe jit function and runtime _Pointer
"""
def __extract_mlir_values__(self):
# Doesn't matter just return a value
return [self]
class Tensor(ABC):
"""
Abstract base class for CuTe jit function and runtime _Tensor
A CuTe Tensor is an iterator paired with a layout.
:Examples:
Create tensor from torch.tensor with Host Runtime:
.. code-block:: python
>>> import torch
>>> from cutlass.cute.runtime import from_dlpack
>>> mA = from_dlpack(torch.tensor([1, 3, 5], dtype=torch.int32))
>>> mA.shape
(3,)
>>> mA.stride
(1,)
>>> mA.layout
(3,):(1,)
Define JIT function:
.. code-block:: python
@cute.jit
def add(a: Tensor, b: Tensor, res: Tensor): ...
Call JIT function from python:
.. code-block:: python
>>> import torch
>>> a = torch.tensor([1, 3, 5], dtype=torch.int32)
>>> b = torch.tensor([2, 4, 6], dtype=torch.int32)
>>> c = torch.zeros([3], dtype=torch.int32)
>>> mA = from_dlpack(a)
>>> mB = from_dlpack(b)
>>> mC = from_dlpack(c)
>>> add(mA, mB, mC)
>>> c
tensor([3, 7, 11], dtype=torch.int32)
"""
def __str__(self): ...
@abstractmethod
def __getitem__(self, idx) -> Union["Tensor", ir.Value, IntTuple]: ...
@abstractmethod
def __setitem__(self, idx, value): ...
@property
@abstractmethod
def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: ...
@element_type.setter
def element_type(self, new_type): ...
@property
@abstractmethod
def memspace(self) -> AddressSpace: ...
@property
@abstractmethod
def iterator(self): ...
@property
def layout(self) -> Union[Layout, "ComposedLayout"]: ...
@property
def shape(self) -> Shape: ...
def load(self, *, loc=None, ip=None) -> "TensorSSA": ...
def store(self, data: "TensorSSA", *, loc=None, ip=None): ...
def mark_layout_dynamic(self, leading_dim: int|None = None) -> "Tensor": ...
def mark_compact_shape_dynamic(
self, mode: int, stride_order: tuple[int, ...]|None = None, divisibility: int = 1
) -> "Tensor": ...
@abstractmethod
def fill(self, value: Numeric) -> None: ...
__all__ = [
"Coord",
"Numeric",
"Integer",
"Boolean",
"Int8",
"Int16",
"Int32",
"Int64",
"Uint8",
"Uint16",
"Uint32",
"Uint64",
"Float",
"Float16",
"BFloat16",
"TFloat32",
"Float32",
"Float64",
"Float8E5M2",
"Float8E4M3FN",
"Float8E4M3B11FNUZ",
"Float8E4M3",
"Float8E8M0FNU",
"Float4E2M1FN",
"Float6E2M3FN",
"Float6E3M2FN",
"IntTuple",
"Layout",
"Pointer",
"Shape",
"Stride",
"Tensor",
"Tile",
"Tiler",
"XTuple",
]

View File

@@ -0,0 +1,32 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
def check_value_in(
value, possible_values: list, value_description: str, prefix=""
) -> None:
if value not in possible_values:
err_msg = prefix
if err_msg != "":
err_msg += ": "
err_msg += f"invalid {value_description}, got {value}, must be one of {possible_values}"
raise ValueError(err_msg)
def check_type_in(ty, possible_types: list, type_description: str, prefix="") -> None:
if not isinstance(ty, type):
ty = type(ty)
if ty not in possible_types:
err_msg = prefix
if err_msg != "":
err_msg += ": "
err_msg += f"invalid type for {type_description}, got {ty}, must be one of {possible_types}"
raise TypeError(err_msg)
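# For example (hypothetical arguments):
#   check_value_in(3, [1, 2, 4], "stage count")
#   raises ValueError: invalid stage count, got 3, must be one of [1, 2, 4]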

View File

@@ -0,0 +1,169 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Type, Union
from cutlass.cute.typing import (
Numeric,
Boolean,
Float,
Integer,
TFloat32,
Float8E4M3B11FNUZ,
Float8E4M3FN,
Float8E5M2,
Float8E8M0FNU,
Float4E2M1FN,
Tensor,
)
from cutlass.cute.runtime import from_dlpack
import cutlass.cute as cute
import torch
def dtype(ty: Type[Numeric]):
"""
Return the corresponding torch.dtype per the given DSL type
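For example (a minimal sketch): ``dtype(TFloat32)`` returns ``torch.float32`` and
``dtype(Boolean)`` returns ``torch.bool``.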
"""
torch_dtype = getattr(torch, ty.__name__.lower(), None)
torch_type_map = {
Boolean: torch.bool,
# TFloat32 is just alias of float32
TFloat32: torch.float32,
Float8E5M2: torch.float8_e5m2,
Float8E4M3FN: torch.float8_e4m3fn,
Float8E4M3B11FNUZ: torch.float8_e4m3fnuz,
}
if torch_dtype is None:
torch_dtype = torch_type_map.get(ty)
if torch_dtype is None:
raise TypeError(f"{ty} is not supported by torch")
return torch_dtype
@dataclass
class ScalarInitConfig:
"""Configuration for scalar initialization"""
value: float = 0.0
@dataclass
class RandomInitConfig:
"""Configuration for random initialization"""
min_val: int = -2
max_val: int = 2
@dataclass
class GaussianInitConfig:
"""Configuration for Gaussian initialization"""
mean: float = 0.0
std: float = 1.0
scale: float = 1.0
class TensorInitType(Enum):
"""Enumeration of tensor initialization types"""
SKIP = "skip"
SCALAR = "scalar"
RANDOM = "random"
GAUSSIAN = "gaussian"
def create_and_permute_torch_tensor(
shape,
dtype: "torch.dtype",
permute_order=None,
init_type: TensorInitType = TensorInitType.RANDOM,
init_config: Optional[
Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig]
] = None,
) -> "torch.Tensor":
"""
Create a torch tensor with specified shape and dtype. Optionally permute it and initialize it with specified init type and config
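A minimal usage sketch (hypothetical shape and permutation)::

    t = create_and_permute_torch_tensor((8, 4), torch.float16, permute_order=(1, 0))
    # random integer values in [-2, 2), permuted to shape (4, 8), then cast to float16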
"""
init_dtype = torch.int32 if init_type == TensorInitType.RANDOM else torch.float32
init_torch_tensor = torch.empty(*shape, dtype=init_dtype)
if init_type == TensorInitType.SKIP:
assert init_config is None
f32_torch_tensor = init_torch_tensor
elif init_type == TensorInitType.SCALAR:
if init_config is None:
init_config = ScalarInitConfig()
else:
if not isinstance(init_config, ScalarInitConfig):
raise ValueError("init_config must be ScalarInitConfig()")
f32_torch_tensor = init_torch_tensor.fill_(init_config.value)
elif init_type == TensorInitType.RANDOM:
if init_config is None:
init_config = RandomInitConfig()
else:
if not isinstance(init_config, RandomInitConfig):
raise ValueError("init_config must be RandomInitConfig()")
f32_torch_tensor = init_torch_tensor.random_(
init_config.min_val, init_config.max_val
).to(dtype=torch.float32)
elif init_type == TensorInitType.GAUSSIAN:
if init_config is None:
init_config = GaussianInitConfig()
else:
if not isinstance(init_config, GaussianInitConfig):
raise ValueError("init_config must be GaussianInitConfig()")
f32_torch_tensor = init_torch_tensor.normal_(init_config.mean, init_config.std)
f32_torch_tensor = f32_torch_tensor * (1 << init_config.scale)
else:
raise ValueError(f"Invalid init type: {init_type}")
if permute_order is not None:
f32_torch_tensor = f32_torch_tensor.permute(permute_order)
dtype_torch_tensor = f32_torch_tensor.to(dtype=dtype)
return dtype_torch_tensor
def convert_cute_tensor(
f32_torch_tensor: "torch.Tensor",
cute_tensor: Tensor,
dtype: Type[Numeric],
is_dynamic_layout: bool = True,
) -> Tensor:
"""
Fill the given cute tensor with values converted from an fp32 torch tensor.
Currently used when creating tensors of fp8 and other narrow-precision types.
"""
# if torch_tensor is on cpu, create a gpu copy
if f32_torch_tensor.device.type == "cpu":
f32_torch_tensor = f32_torch_tensor.cuda()
# Fp8 type need explicit type conversion
if dtype in {
Float8E5M2,
Float8E4M3FN,
Float8E8M0FNU,
Float4E2M1FN,
}:
fp32_cute_tensor = from_dlpack(f32_torch_tensor)
if is_dynamic_layout:
fp32_cute_tensor = fp32_cute_tensor.mark_layout_dynamic(
f32_torch_tensor.dim_order()[-1]
)
# Copy and convert from f32 cute tensor to dtype cute tensor
cute.testing.convert(fp32_cute_tensor, cute_tensor)
return cute_tensor

View File

@@ -0,0 +1,9 @@
# Utilities
This folder contains various utilities for kernel authoring. Specifically, the implementations of the
following can be considered experimental and subject to breaking changes:
- static persistent tile scheduler defined in [`static_persistent_tile_scheduler.py`](./static_persistent_tile_scheduler.py)
- pipeline abstractions defined in [`pipeline.py`](./pipeline.py)
- grouped GEMM utilities defined in [`grouped_gemm_tile_scheduler_helper.py`](./grouped_gemm_tile_scheduler_helper.py)
and [`tensormap_manager.py`](./tensormap_manager.py)

View File

@@ -0,0 +1,78 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from .static_persistent_tile_scheduler import (
WorkTileInfo,
PersistentTileSchedulerParams,
StaticPersistentTileScheduler,
)
from .pipeline import (
Agent,
CooperativeGroup,
PipelineUserType,
PipelineState,
make_pipeline_state,
PipelineAsync,
PipelineTmaAsync,
PipelineTmaUmma,
PipelineUmmaAsync,
PipelineTmaStore,
pipeline_init_wait,
)
from .hardware_info import (
HardwareInfo,
)
from .blackwell_helpers import (
compute_epilogue_tile_shape,
get_smem_store_op,
get_tmem_load_op,
get_num_tmem_alloc_cols,
make_smem_layout_a,
make_smem_layout_b,
make_smem_layout_epi,
make_trivial_tiled_mma,
)
from .hopper_helpers import (
sm90_get_smem_store_op,
)
from .grouped_gemm_tile_scheduler_helper import (
GroupSearchResult,
GroupedGemmGroupSearchState,
GroupedGemmTileSchedulerHelper,
create_initial_search_state,
)
from .tensormap_manager import (
TensorMapUpdateMode,
TensorMapManager,
)
from .smem_allocator import SmemAllocator
from .layout import LayoutEnum
__all__ = [
"WorkTileInfo",
"PersistentTileSchedulerParams",
"StaticPersistentTileScheduler",
"TensorMapUpdateMode",
"TensorMapManager",
"GroupSearchResult",
"GroupedGemmGroupSearchState",
"create_initial_search_state",
"GroupedGemmTileSchedulerHelper",
"HardwareInfo",
]

View File

@@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from enum import Enum
class SmemCapacity(Enum):
SM80_SMEM_CAPACITY_BYTES = (164 - 1) * 1024
SM86_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
SM89_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
# Dictionary to map compute capability to SMEM capacity
SMEM_CAPACITY = {
"sm80": SmemCapacity.SM80_SMEM_CAPACITY_BYTES.value,
"sm86": SmemCapacity.SM86_SMEM_CAPACITY_BYTES.value,
"sm89": SmemCapacity.SM89_SMEM_CAPACITY_BYTES.value,
}
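# For example, SMEM_CAPACITY["sm80"] == (164 - 1) * 1024 == 166912 bytes.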

View File

@@ -0,0 +1,910 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from enum import Enum
from math import log2, ceil
from typing import List, Type, Union, Tuple
from cutlass.cutlass_dsl import (
Float16,
BFloat16,
TFloat32,
Float32,
Uint8,
Int8,
Float8E4M3FN,
Float8E5M2,
Numeric,
NumericMeta,
dsl_user_op,
)
import cutlass.cute as cute
from cutlass.cute.nvgpu.common import CopyUniversalOp
from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp, StMatrix16x8x8bOp
from cutlass.cute.nvgpu.tcgen05 import (
MmaF16BF16Op,
MmaTF32Op,
MmaI8Op,
MmaFP8Op,
OperandSource,
OperandMajorMode,
CtaGroup,
Ld16x64bOp,
Ld16x128bOp,
Ld16x256bOp,
Ld16x32bx2Op,
Ld32x32bOp,
Repetition,
Pack,
find_tmem_tensor_col_offset,
SmemLayoutAtomKind,
make_smem_layout_atom,
tile_to_mma_shape,
is_tmem_load,
get_tmem_copy_properties,
)
from cutlass.utils.layout import LayoutEnum
@dsl_user_op
def compute_epilogue_tile_shape(
cta_tile_shape: cute.Shape,
use_2cta_instrs: bool,
layout_d: LayoutEnum,
elem_ty_d: Type[Numeric],
*,
layout_c: LayoutEnum = None,
elem_ty_c: Union[Type[Numeric], None] = None,
loc=None,
ip=None,
) -> cute.Tile:
"""Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one.
:param cta_tile_shape: A tuple or list representing the dimensions of the CTA tile, where
cta_tile_shape[0] corresponds to the height (M) and cta_tile_shape[1]
corresponds to the width (N) of the tile.
:type cta_tile_shape: cute.Shape
:param use_2cta_instrs: A flag indicating whether the configuration is for a 2SM setup.
:type use_2cta_instrs: bool
:param layout_d: The layout enum of the output tensor D.
:type layout_d: LayoutEnum
:param elem_ty_d: The element type of output tensor D.
:type elem_ty_d: Type[Numeric]
:param layout_c: The layout enum of the input tensor C. Defaults to None.
:type layout_c: LayoutEnum, optional
:param elem_ty_c: The element type for input tensor C. Defaults to None.
:type elem_ty_c: Union[Type[Numeric], None], optional
:return: Returns epilog tiler, which is used in subsequent epilog partitions.
:rtype: cute.Tile
:raises ValueError: If the computed tile cute.size does not meet minimum requirements based on CTA dimensions.
"""
def validate_type(ty, ty_name):
if not isinstance(ty, NumericMeta):
raise TypeError(f"{ty_name} must be Numeric, but got {ty}")
validate_type(elem_ty_d, "elem_ty_d")
if elem_ty_c is not None:
validate_type(elem_ty_c, "elem_ty_c")
cta_m, cta_n = cta_tile_shape[:2]
(warp_m, warp_n) = (2, 2) if (cta_m == 64 and use_2cta_instrs) else (4, 1)
disable_source = elem_ty_c is None
max_bits = (
elem_ty_d.width if disable_source else max(elem_ty_c.width, elem_ty_d.width)
)
dp_full = 32
tile_m = min(cta_m, dp_full * warp_m)
n_perf = 0
if disable_source:
if max_bits == 4:
compute_elts = 8192
else:
compute_elts = 4096
n_perf = compute_elts // tile_m
else:
if max_bits == 32:
n_perf = 16 if (cta_m > 64 and cta_n <= 128) else 32
elif max_bits == 16:
n_perf = 32 if cta_n <= 128 else 64
else:
n_perf = 64
d_is_m_major = layout_d.is_m_major_c()
c_is_m_major = True if layout_c is None else layout_c.is_m_major_c()
n_min_d = (
8 * warp_n
if d_is_m_major
else (128 * warp_n if elem_ty_d.width == 6 else 128 // elem_ty_d.width * warp_n)
)
n_min_c = (
8 * warp_n
if (c_is_m_major or disable_source)
else (128 * warp_n if elem_ty_c.width == 6 else 128 // elem_ty_c.width * warp_n)
)
tile_n = min(cta_n, max(n_perf, n_min_c, n_min_d))
if cta_n < n_min_c or cta_n < n_min_d:
raise ValueError(f"CTA tile too small: {cta_tile_shape=}")
# stride by tmem warp layout and return a by-mode tiler
tile_m_layout = cute.make_layout(tile_m, loc=loc, ip=ip)
tile_n_layout = cute.make_layout(
(tile_n // warp_n, warp_n), stride=(1, cta_n // warp_n), loc=loc, ip=ip
)
return (tile_m_layout, cute.coalesce(tile_n_layout, loc=loc, ip=ip))
@dsl_user_op
def get_smem_store_op(
layout_d: LayoutEnum,
elem_ty_d: Type[Numeric],
elem_ty_acc: Type[Numeric],
tiled_tmem_load: cute.TiledCopy,
*,
loc=None,
ip=None,
) -> cute.CopyAtom:
"""Selects the largest vectorized smem store atom available subject to
constraint of gmem layout and chosen TMEM_LOAD's thread-value ownership.
:param layout_d: The layout enum of the output tensor D.
:type layout_d: LayoutEnum
:param elem_ty_d: The element type for output tensor D.
:type elem_ty_d: Type[Numeric]
:param elem_ty_acc: The element type for accumulator.
:type elem_ty_acc: Type[Numeric]
:param tiled_tmem_load: An instance of TiledCopy that represents the tmem load operation.
:type tiled_tmem_load: cute.TiledCopy
:return: Either SmemStoreMatrix or SimtSyncCopy, based on the input parameters.
:rtype: cute.CopyAtom
"""
def validate_type(ty, ty_name):
if not isinstance(ty, NumericMeta):
raise TypeError(f"{ty_name} must be a Numeric, but got {ty}")
validate_type(elem_ty_d, "elem_ty_d")
validate_type(elem_ty_acc, "elem_ty_acc")
is_m_major = layout_d.is_m_major_c()
is_n_major = layout_d.is_n_major_c()
if not is_tmem_load(tiled_tmem_load):
return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip)
num_dp, num_bits, num_rep, pack = get_tmem_copy_properties(tiled_tmem_load)
use_stmatrix_m8n8_4x = (
all(
[
elem_ty_acc.width == 32,
elem_ty_d.width == 32,
is_n_major,
num_dp == 16,
num_bits == 128,
num_rep in (2, 4, 8, 16, 32, 64),
pack == Pack.NONE,
]
)
or all(
[
elem_ty_acc.width == 32,
elem_ty_d.width == 16,
num_dp == 16,
num_bits == 256,
num_rep in (2, 4, 8, 16, 32),
pack == Pack.NONE,
]
)
or all(
[
elem_ty_acc.width == 16,
elem_ty_d.width == 16,
num_dp == 16,
num_bits == 128,
num_rep in (2, 4, 8, 16, 32, 64),
pack == Pack.PACK_16b_IN_32b,
]
)
)
use_stmatrix_m16n8_4x = all(
[
elem_ty_acc.width == 32,
elem_ty_d.width == 8,
is_m_major,
num_dp == 16,
num_bits == 256,
num_rep in (4, 8, 16, 32),
pack == Pack.NONE,
]
)
use_stmatrix_m8n8_2x = (
all(
[
elem_ty_acc.width == 32,
elem_ty_d.width == 32,
is_n_major,
num_dp == 16,
num_bits == 128,
num_rep == 1,
pack == Pack.NONE,
]
)
or all(
[
elem_ty_acc.width == 32,
elem_ty_d.width == 16,
num_dp == 16,
num_bits == 256,
num_rep == 1,
pack == Pack.NONE,
]
)
or all(
[
elem_ty_acc.width == 16,
elem_ty_d.width == 16,
num_dp == 16,
num_bits == 128,
num_rep == 1,
pack == Pack.PACK_16b_IN_32b,
]
)
)
use_stmatrix_m16n8_2x = all(
[
elem_ty_acc.width == 32,
elem_ty_d.width == 8,
is_m_major,
num_dp == 16,
num_bits == 256,
num_rep == 2,
pack == Pack.NONE,
]
)
use_stmatrix_m16n8_1x = all(
[
elem_ty_acc.width == 32,
elem_ty_d.width == 8,
is_m_major,
num_dp == 16,
num_bits == 256,
num_rep == 1,
pack == Pack.NONE,
]
)
if use_stmatrix_m8n8_4x:
op = StMatrix8x8x16bOp(is_m_major, 4)
return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip)
elif use_stmatrix_m8n8_2x:
op = StMatrix8x8x16bOp(is_m_major, 2)
return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip)
elif use_stmatrix_m16n8_4x:
op = StMatrix16x8x8bOp(4)
return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip)
elif use_stmatrix_m16n8_2x:
op = StMatrix16x8x8bOp(2)
return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip)
elif use_stmatrix_m16n8_1x:
op = StMatrix16x8x8bOp(1)
return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip)
else:
op = CopyUniversalOp()
return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip)
@dsl_user_op
def get_tmem_load_op(
cta_tile_shape: cute.Shape,
layout_d: LayoutEnum,
elem_ty_d: Type[Numeric],
elem_ty_acc: Type[Numeric],
epi_tile: cute.Tile,
use_2cta_instrs: bool,
*,
loc=None,
ip=None,
) -> cute.CopyAtom:
"""Finds a performant TMEM_LOAD copy op for the selected epilogue
tile (epi_tile), element types, and tcgen05.mma instruction used.
:param cta_tile_shape: A tuple or list representing the dimensions of the CTA tile.
:type cta_tile_shape: cute.Shape
:param layout_d: The layout enum of the output tensor D.
:type layout_d: LayoutEnum
:param elem_ty_d: The element type for output tensor D.
:type elem_ty_d: Type[Numeric]
:param elem_ty_acc: The element type for accumulation.
:type elem_ty_acc: Type[Numeric]
:param epi_tile: The epilogue tile configuration.
:type epi_tile: cute.Tile
:param use_2cta_instrs: A flag indicating whether the configuration is for 2 SMs.
:type use_2cta_instrs: bool
:return: An instance of Sm100TmemLoad with the computed configuration.
:rtype: cute.CopyAtom
:raises ValueError: If the function cannot handle the given combination of accumulation
and dimension types, or if it cannot determine the appropriate configuration based on
the input parameters.
"""
is_m_major = layout_d.is_m_major_c()
acc_bits = elem_ty_acc.width
d_bits = elem_ty_d.width
tmem_warp_shape_mn = (
(2, 2) if (cta_tile_shape[0] == 64 and use_2cta_instrs) else (4, 1)
)
epilog_tile_shape_mn = cute.product_each(
cute.shape(epi_tile, loc=loc, ip=ip), loc=loc, ip=ip
)
epilog_warp_tile_shape_mn = cute.shape_div(
epilog_tile_shape_mn, tmem_warp_shape_mn, loc=loc, ip=ip
)
num_dp = cute.size(epilog_warp_tile_shape_mn[0], loc=loc, ip=ip)
if num_dp not in {16, 32}:
raise ValueError("Cta tile and 2sm config does not generate correct num dp.")
num_col_bits = cute.size(epilog_warp_tile_shape_mn[1], loc=loc, ip=ip) * acc_bits
tmem_dp = 0
tmem_bit = 0
tmem_rep = 0
tmem_pack16b = False
if acc_bits == 32 and d_bits == 32:
if num_dp == 16:
if is_m_major:
tmem_dp = 16
tmem_bit = 256
else:
tmem_dp = 16
tmem_bit = 128
else:
tmem_dp = 32
tmem_bit = 32
elif acc_bits == 32 and d_bits == 16:
if num_dp == 16:
if is_m_major:
tmem_dp = 16
tmem_bit = 256
else:
tmem_dp = 16
tmem_bit = 256
else:
if is_m_major:
tmem_dp = 16
tmem_bit = 256
else:
tmem_dp = 32
tmem_bit = 32
elif acc_bits == 32 and d_bits == 8:
if num_dp == 16:
if is_m_major:
tmem_dp = 16
tmem_bit = 256
else:
tmem_dp = 16
tmem_bit = 32
else:
if is_m_major:
tmem_dp = 16
tmem_bit = 256
else:
tmem_dp = 32
tmem_bit = 32
elif acc_bits == 16 and d_bits == 16:
tmem_pack16b = True
if num_dp == 16:
if is_m_major:
tmem_dp = 16
tmem_bit = 128
else:
tmem_dp = 16
tmem_bit = 128
else:
if is_m_major:
tmem_dp = 16
tmem_bit = 128
else:
tmem_dp = 32
tmem_bit = 32
elif acc_bits == 32 and d_bits == 6:
if not num_dp == 32:
raise ValueError("Num dp must be 32.")
tmem_dp = 32
tmem_bit = 32
elif acc_bits == 32 and d_bits == 4:
if not num_dp == 32:
raise ValueError("Num dp must be 32.")
tmem_dp = 32
tmem_bit = 32
else:
raise ValueError(
f"Can not handle acc/d type combination: {elem_ty_acc=}, {elem_ty_d=}"
)
num_bit_div = tmem_bit
if tmem_dp == 16 and tmem_bit == 32:
num_bit_div = 64
if (num_col_bits % (num_bit_div * 128) == 0) and (
(tmem_dp == 16 and tmem_bit == 64)
or (tmem_dp == 16 and tmem_bit == 32)
or (tmem_dp == 32 and tmem_bit == 32)
):
tmem_rep = 128
elif (num_col_bits % (num_bit_div * 64) == 0) and (
(tmem_dp == 16 and tmem_bit == 128)
or (tmem_dp == 16 and tmem_bit == 64)
or (tmem_dp == 16 and tmem_bit == 32)
or (tmem_dp == 32 and tmem_bit == 32)
):
tmem_rep = 64
elif num_col_bits % (num_bit_div * 32) == 0:
tmem_rep = 32
elif num_col_bits % (num_bit_div * 16) == 0:
tmem_rep = 16
elif num_col_bits % (num_bit_div * 8) == 0:
tmem_rep = 8
elif num_col_bits % (num_bit_div * 4) == 0:
tmem_rep = 4
elif num_col_bits % (num_bit_div * 2) == 0:
tmem_rep = 2
elif num_col_bits % (num_bit_div * 1) == 0:
tmem_rep = 1
else:
raise ValueError("Can not pick tmem_rep based on cta tile shape and tmem atom.")
if tmem_dp == 16 and tmem_bit == 64:
op = Ld16x64bOp(
Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE
)
return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip)
elif tmem_dp == 16 and tmem_bit == 128:
op = Ld16x128bOp(
Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE
)
return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip)
elif tmem_dp == 16 and tmem_bit == 256:
op = Ld16x256bOp(
Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE
)
return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip)
elif tmem_dp == 16 and tmem_bit == 32:
op = Ld16x32bx2Op(
Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE
)
return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip)
elif tmem_dp == 32 and tmem_bit == 32:
op = Ld32x32bOp(
Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE
)
return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip)
else:
raise ValueError()
def get_num_tmem_alloc_cols(
tmem_tensors: Union[cute.Tensor, List[cute.Tensor]], rounding=True
) -> int:
"""Get the total number of TMEM allocation columns for the given TMEM tensors.
:param tmem_tensors: The TMEM tensors to get the number of allocation columns for.
:type tmem_tensors: Union[cute.Tensor, List[cute.Tensor]]
:param rounding: Whether to round up the number of allocation columns to the nearest power of 2.
:type rounding: bool
:return: The total number of TMEM allocation columns.
:rtype: int
:raises ValueError: If the number of TMEM allocation columns exceeds the maximum capacity of 512 or is less than 32.
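For example (hypothetical per-tensor column offsets): tensors occupying 40 and 48
columns sum to 88, which ``rounding=True`` rounds up to 128 columns.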
"""
# Turn tmem_tensors into a list
if isinstance(tmem_tensors, cute.Tensor):
tmem_tensors = [tmem_tensors]
# For each tensor in tmem_tensors, find the tmem_tensor_col_offset
num_tmem_alloc_cols_per_tensor = [
find_tmem_tensor_col_offset(t) for t in tmem_tensors
]
# Sum up the num_tmem_alloc_cols_per_tensor
num_tmem_alloc_cols = sum(num_tmem_alloc_cols_per_tensor)
# Round up num_tmem_cols_total to the nearest power of 2
if rounding:
num_tmem_alloc_cols = 1 << ceil(log2(num_tmem_alloc_cols))
# Validate the number of TMEM allocation columns
SM100_TMEM_CAPACITY_COLUMNS = 512
SM100_TMEM_MIN_ALLOC_COLUMNS = 32
if (
num_tmem_alloc_cols > SM100_TMEM_CAPACITY_COLUMNS
or num_tmem_alloc_cols < SM100_TMEM_MIN_ALLOC_COLUMNS
):
raise ValueError(
f"TMEM allocation columns {num_tmem_alloc_cols} exceeds the maximum capacity of {SM100_TMEM_CAPACITY_COLUMNS} or less than {SM100_TMEM_MIN_ALLOC_COLUMNS}"
)
return num_tmem_alloc_cols
def get_smem_layout_atom_ab(
major_mode: OperandMajorMode,
element_type: Type[Numeric],
smem_shape_mn_k: Tuple[int, int],
*,
loc=None,
ip=None,
) -> SmemLayoutAtomKind:
"""Simple heuristics to select the optimal SMEM layout atom based on the
majorness, the data type, and the major mode size.
:param major_mode: The major mode of the SMEM tensor (K-major or MN-major).
:type major_mode: OperandMajorMode
:param element_type: The element type for the SMEM tensor.
:type element_type: Type[Numeric]
:param smem_shape_mn_k: The shape of the SMEM tensor.
:type smem_shape_mn_k: Tuple[int, int]
:return: The SMEM layout atom kind
:rtype: SmemLayoutAtomKind
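For example (a hypothetical case): a K-major Float16 operand with 64 elements along K
spans 64 * 16 = 1024 contiguous bits and selects ``SmemLayoutAtomKind.K_SW128``.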
"""
is_k_major = major_mode == OperandMajorMode.K
major_mode_size = smem_shape_mn_k[1] if is_k_major else smem_shape_mn_k[0]
assert major_mode_size % 8 == 0
sw128_num_contiguous_bits = 1024
sw64_num_contiguous_bits = 512
sw32_num_contiguous_bits = 256
inter_num_contiguous_bits = 128
major_mode_size_bits = major_mode_size * element_type.width
assert major_mode_size_bits % inter_num_contiguous_bits == 0
if not is_k_major:
if (element_type.width == 32) and (
major_mode_size_bits % sw128_num_contiguous_bits == 0
):
return SmemLayoutAtomKind.MN_SW128_32B
if major_mode_size_bits % sw128_num_contiguous_bits == 0:
return SmemLayoutAtomKind.MN_SW128
if major_mode_size_bits % sw64_num_contiguous_bits == 0:
return SmemLayoutAtomKind.MN_SW64
if major_mode_size_bits % sw32_num_contiguous_bits == 0:
return SmemLayoutAtomKind.MN_SW32
return SmemLayoutAtomKind.MN_INTER
if major_mode_size_bits % sw128_num_contiguous_bits == 0:
return SmemLayoutAtomKind.K_SW128
if major_mode_size_bits % sw64_num_contiguous_bits == 0:
return SmemLayoutAtomKind.K_SW64
if major_mode_size_bits % sw32_num_contiguous_bits == 0:
return SmemLayoutAtomKind.K_SW32
return SmemLayoutAtomKind.K_INTER
@dsl_user_op
def make_smem_layout_a(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
a_dtype: Type[Numeric],
num_stages: int,
*,
loc=None,
ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
"""This function helps with:
1. Get the partitioned shape of the A tensor based on the tiled_mma & MMA tiler.
2. Select the heuristic SMEM layout atom based on the A tensor's majorness, the data type, and the major mode size.
3. Tile the SMEM layout atom to the MMA tile shape.
4. Stage the SMEM layout based on the number of stages.
:param tiled_mma: The tiled MMA used to partition tensor A
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The MMA tile shape
:type mma_tiler_mnk: cute.Tile
:param a_dtype: The element type for tensor A
:type a_dtype: Type[Numeric]
:param num_stages: The number of pipeline stages for tensor A
:type num_stages: int
:return: SMEM layout for tensor A
:rtype: Union[cute.Layout, cute.ComposedLayout]
"""
is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K
a_smem_shape = tiled_mma.partition_shape_A(
cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip)
)
a_smem_shape_mn_k = (
cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1],
cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2],
)
a_smem_layout_atom = make_smem_layout_atom(
get_smem_layout_atom_ab(
tiled_mma.op.a_major_mode,
a_dtype,
a_smem_shape_mn_k,
loc=loc,
ip=ip,
),
a_dtype,
loc=loc,
ip=ip,
)
a_smem_layout_staged = tile_to_mma_shape(
a_smem_layout_atom,
cute.append(a_smem_shape, num_stages, loc=loc, ip=ip),
order=((1, 0, 2) if not is_k_major else (0, 1, 2)),
loc=loc,
ip=ip,
)
return a_smem_layout_staged
@dsl_user_op
def make_smem_layout_b(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
b_dtype: Type[Numeric],
num_stages: int,
*,
loc=None,
ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
"""This function helps:
1. Get the partitioned shape of the B tensor based on the tiled_mma & MMA tiler.
2. Select the heuristic SMEM layout atom based on the B tensor's majorness, the data type, and the major mode size.
3. Tile the SMEM layout atom to the MMA tile shape.
4. Stage the SMEM layout based on the number of stages.
:param tiled_mma: The tiled MMA which is used to partition the B tensor.
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The MMA tile shape.
:type mma_tiler_mnk: cute.Tile
:param b_dtype: The element type for the B tensor.
:type b_dtype: Type[Numeric]
:param num_stages: The number of pipeline stages for tensor B.
:type num_stages: int
:return: SMEM layout for the B tensor.
:rtype: Union[cute.Layout, cute.ComposedLayout]
"""
is_k_major = tiled_mma.op.b_major_mode == OperandMajorMode.K
b_smem_shape = tiled_mma.partition_shape_B(
cute.dice(mma_tiler_mnk, (None, 1, 1), loc=loc, ip=ip)
)
b_smem_shape_nk = (
cute.size(b_smem_shape[0][0], loc=loc, ip=ip) * b_smem_shape[1],
cute.size(b_smem_shape[0][1], loc=loc, ip=ip) * b_smem_shape[2],
)
b_smem_layout_atom = make_smem_layout_atom(
get_smem_layout_atom_ab(
tiled_mma.op.b_major_mode,
b_dtype,
b_smem_shape_nk,
loc=loc,
ip=ip,
),
b_dtype,
loc=loc,
ip=ip,
)
b_smem_layout_staged = tile_to_mma_shape(
b_smem_layout_atom,
cute.append(b_smem_shape, num_stages, loc=loc, ip=ip),
order=((1, 0, 2) if not is_k_major else (0, 1, 2)),
loc=loc,
ip=ip,
)
return b_smem_layout_staged
@dsl_user_op
def get_smem_layout_atom_epi(
layout: LayoutEnum,
element_type: Type[Numeric],
epi_tile: cute.Tile,
*,
loc=None,
ip=None,
) -> SmemLayoutAtomKind:
"""Simple heuristics to select the optimal SMEM layout atom for epilog tensors.
:param layout: The layout enum for the SMEM tensor.
:type layout: LayoutEnum
:param element_type: The element type for the SMEM tensor.
:type element_type: Type[Numeric]
:param epi_tile: The epilogue tile shape.
:type epi_tile: cute.Tile
:return: The SMEM layout atom kind
:rtype: SmemLayoutAtomKind
"""
# Get the max contiguous tile usable by TMA
tma_shape = tuple(
(
# assumes get<0>(epi_tile) is coalesced and unit stride
cute.coalesce(cute.right_inverse(x, loc=loc, ip=ip), loc=loc, ip=ip).shape
if isinstance(x, cute.Layout)
else x
)
for x in epi_tile
)
if layout.is_m_major_c():
# ColMajor C/D (M-major)
return get_smem_layout_atom_ab(
OperandMajorMode.MN, element_type, tma_shape, loc=loc, ip=ip
)
else:
# RowMajor C/D (N-major)
return get_smem_layout_atom_ab(
OperandMajorMode.K, element_type, tma_shape, loc=loc, ip=ip
)
@dsl_user_op
def make_smem_layout_epi(
epi_dtype: Type[Numeric],
epi_layout: LayoutEnum,
epi_tile: cute.Tile,
epi_stage: int,
*,
loc=None,
ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
"""This function helps:
1. Select the heuristic SMEM layout atom based on the epilog tile shape,
the epilog tensor's majorness, and the element type.
2. Tile the SMEM layout atom to the epilogue tile shape.
3. Stage the SMEM layout based on the number of stages.
:param epi_dtype: The element type for the epilog tensor.
:type epi_dtype: Type[Numeric]
:param epi_layout: The layout enum for the epilog tensor.
:type epi_layout: LayoutEnum
:param epi_tile: The epilogue tile shape.
:type epi_tile: cute.Tile
:param epi_stage: The number of pipeline stages for the epilogue tensor.
:type epi_stage: int
:return: SMEM layout for epilog tensors (usually C & D which are processed in the epilog)
:rtype: Union[cute.Layout, cute.ComposedLayout]
"""
epilog_shape = cute.product_each(
cute.shape(epi_tile, loc=loc, ip=ip), loc=loc, ip=ip
)
c_smem_layout_atom = make_smem_layout_atom(
get_smem_layout_atom_epi(
epi_layout,
epi_dtype,
epi_tile,
loc=loc,
ip=ip,
),
epi_dtype,
loc=loc,
ip=ip,
)
epi_smem_layout_staged = cute.tile_to_shape(
c_smem_layout_atom,
cute.append(epilog_shape, epi_stage, loc=loc, ip=ip),
order=((1, 0, 2) if not epi_layout.is_n_major_c() else (0, 1, 2)),
loc=loc,
ip=ip,
)
return epi_smem_layout_staged
class SmemCapacity(Enum):
SM100_SMEM_CAPACITY_BYTES = (228 - 1) * 1024
SM120_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
# Dictionary to map compute capability to SMEM capacity
SMEM_CAPACITY = {
"sm100": SmemCapacity.SM100_SMEM_CAPACITY_BYTES.value,
"sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value,
}
@dsl_user_op
def make_trivial_tiled_mma(
ab_dtype: Type[Numeric],
a_leading_mode: OperandMajorMode,
b_leading_mode: OperandMajorMode,
acc_dtype: Type[Numeric],
cta_group: CtaGroup,
mma_tiler_mn: Tuple[int, int],
a_source: OperandSource = OperandSource.SMEM,
*,
loc=None,
ip=None,
) -> cute.TiledMma:
"""Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape.
By default, the MMA atom is created with SMEM operand source for A.
:param ab_dtype: Data type of operands A and B.
:type ab_dtype: type[Numeric]
:param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N).
:type a_leading_mode: tcgen05.OperandMajorMode
:param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N).
:type b_leading_mode: tcgen05.OperandMajorMode
:param acc_dtype: Data type of the accumulator.
:type acc_dtype: type[Numeric]
:param cta_group: The CTA group to use.
:type cta_group: tcgen05.CtaGroup
:param mma_tiler_mn: The (M, N) shape of the MMA tiler.
:type mma_tiler_mn: Tuple[int, int]
:param a_source: The source of operand A (SMEM by default or TMEM).
:type a_source: OperandSource
:return: A tiled MMA atom.
:rtype: cute.TiledMma
:raises TypeError: If the data type is not supported.
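A minimal sketch (hypothetical FP16 MMA with a (128, 128) tiler; assumes the
tcgen05 enums imported in this module, e.g. ``CtaGroup.ONE``)::

    tiled_mma = make_trivial_tiled_mma(
        Float16, OperandMajorMode.K, OperandMajorMode.K,
        Float32, CtaGroup.ONE, (128, 128),
    )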
"""
if ab_dtype in {Float16, BFloat16}:
mma_op = MmaF16BF16Op(
ab_dtype,
acc_dtype,
(*mma_tiler_mn, 16),
cta_group,
a_source,
a_leading_mode,
b_leading_mode,
)
elif ab_dtype in {TFloat32, Float32}:
mma_op = MmaTF32Op(
(*mma_tiler_mn, 8),
cta_group,
a_source,
a_leading_mode,
b_leading_mode,
)
elif ab_dtype in {
Uint8,
Int8,
}:
mma_op = MmaI8Op(
ab_dtype,
(*mma_tiler_mn, 32),
cta_group,
a_source,
a_leading_mode,
b_leading_mode,
)
elif ab_dtype in {Float8E4M3FN, Float8E5M2}:
mma_op = MmaFP8Op(
ab_dtype,
acc_dtype,
(*mma_tiler_mn, 32),
cta_group,
a_source,
a_leading_mode,
b_leading_mode,
)
else:
raise TypeError(f"unsupported ab_dtype, got {ab_dtype}")
return cute.make_tiled_mma(cute.make_mma_atom(mma_op))

View File

@@ -0,0 +1,466 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import List, Tuple
import cutlass.cute as cute
from cutlass.cutlass_dsl import Int32, extract_mlir_values, new_from_mlir_values
from cutlass._mlir import ir
from cutlass.utils.static_persistent_tile_scheduler import PersistentTileSchedulerParams
class GroupSearchResult:
"""
The result of the group search for grouped gemm.
:param group_idx: The result group index
:type group_idx: Int32
:param cta_tile_idx_m: CTA tile index along M dimension after rasterization
:type cta_tile_idx_m: Int32
:param cta_tile_idx_n: CTA tile index along N dimension after rasterization
:type cta_tile_idx_n: Int32
:param problem_shape_m: The M dimension of the gemm problem
:type problem_shape_m: Int32
:param problem_shape_n: The N dimension of the gemm problem
:type problem_shape_n: Int32
:param problem_shape_k: The K dimension of the gemm problem
:type problem_shape_k: Int32
:param cta_tile_count_k: Number of tiles along K dimension
:type cta_tile_count_k: Int32
"""
def __init__(
self,
group_idx: Int32,
cta_tile_idx_m: Int32,
cta_tile_idx_n: Int32,
problem_shape_m: Int32,
problem_shape_n: Int32,
problem_shape_k: Int32,
cta_tile_count_k: Int32,
) -> None:
self.group_idx = group_idx
self.cta_tile_idx_m = cta_tile_idx_m
self.cta_tile_idx_n = cta_tile_idx_n
self.problem_shape_m = problem_shape_m
self.problem_shape_n = problem_shape_n
self.problem_shape_k = problem_shape_k
self.cta_tile_count_k = cta_tile_count_k
def __extract_mlir_values__(self) -> List[ir.Value]:
values = extract_mlir_values(self.group_idx)
values.extend(extract_mlir_values(self.cta_tile_idx_m))
values.extend(extract_mlir_values(self.cta_tile_idx_n))
values.extend(extract_mlir_values(self.problem_shape_m))
values.extend(extract_mlir_values(self.problem_shape_n))
values.extend(extract_mlir_values(self.problem_shape_k))
values.extend(extract_mlir_values(self.cta_tile_count_k))
return values
def __new_from_mlir_values__(self, values: List[ir.Value]) -> "GroupSearchResult":
assert len(values) == 7
return GroupSearchResult(*tuple(values))
class GroupedGemmGroupSearchState:
"""
The state of group index search for grouped gemm.
The state will be initialized once and updated in every round of group index search.
:param start_group_idx: The group idx to start the search with
:type start_group_idx: Int32
:param tile_count_prev_group: Number of tiles before the matched group
:type tile_count_prev_group: Int32
:param tile_count_searched: Number of tiles we have searched. When the matched group is found,
it records the number of tiles including the matched group
:type tile_count_searched: Int32
"""
def __init__(
self,
start_group_idx: Int32,
tile_count_prev_group: Int32,
tile_count_searched: Int32,
) -> None:
self.start_group_idx = start_group_idx
self.tile_count_prev_group = tile_count_prev_group
self.tile_count_searched = tile_count_searched
def __extract_mlir_values__(self) -> List[ir.Value]:
values = extract_mlir_values(self.start_group_idx)
values.extend(extract_mlir_values(self.tile_count_prev_group))
values.extend(extract_mlir_values(self.tile_count_searched))
return values
def __new_from_mlir_values__(
self, values: List[ir.Value]
) -> "GroupedGemmGroupSearchState":
start_group_idx = new_from_mlir_values(self.start_group_idx, [values[0]])
tile_count_prev_group = new_from_mlir_values(
self.tile_count_prev_group, [values[1]]
)
tile_count_searched = new_from_mlir_values(
self.tile_count_searched, [values[2]]
)
return GroupedGemmGroupSearchState(
start_group_idx, tile_count_prev_group, tile_count_searched
)
def create_initial_search_state() -> GroupedGemmGroupSearchState:
"""
Create an initial search state for grouped gemm.
:return: A new search state with initial values
:rtype: GroupedGemmGroupSearchState
"""
return GroupedGemmGroupSearchState(
start_group_idx=Int32(0),
tile_count_prev_group=Int32(0),
tile_count_searched=Int32(0),
)
class GroupedGemmTileSchedulerHelper:
"""
A helper to translate the raw block index (x, y, z) from tile scheduler to real CTA tile index for grouped gemm.
:param group_count: Number of groups in current grouped gemm problem
:type group_count: int
:param tile_sched_params: Parameter used to create the tile scheduler this helper works with
:type tile_sched_params: PersistentTileSchedulerParams
:param cluster_tile_shape_mnk: The shape of cluster tile as (m, n, k)
:type cluster_tile_shape_mnk: tuple[int, int, int]
:param search_state: The initial search state
:type search_state: GroupedGemmGroupSearchState
"""
def __init__(
self,
group_count: int,
tile_sched_params: PersistentTileSchedulerParams,
cluster_tile_shape_mnk: tuple[int, int, int],
search_state: GroupedGemmGroupSearchState,
) -> None:
self.tile_sched_params = tile_sched_params
self.group_count = group_count
self.lane_idx = cute.arch.lane_idx()
self.cluster_tile_shape_mnk = cluster_tile_shape_mnk
self.search_state = search_state
def __extract_mlir_values__(self) -> List[ir.Value]:
values = extract_mlir_values(self.tile_sched_params)
values.extend(extract_mlir_values(self.search_state))
return values
def __new_from_mlir_values__(
self, values: List[ir.Value]
) -> "GroupedGemmTileSchedulerHelper":
tile_sched_params = new_from_mlir_values(self.tile_sched_params, values)
search_state = new_from_mlir_values(self.search_state, values[1:])
return GroupedGemmTileSchedulerHelper(
self.group_count,
tile_sched_params,
self.cluster_tile_shape_mnk,
search_state,
)
def delinearize_z(
self,
cta_tile_coord: tuple,
problem_shape_mnkl: cute.Tensor,
) -> GroupSearchResult:
"""
Delinearize the linear z index and return GroupSearchResult.
This function should be used by warps that need to know the CTA tile index on M and N dimensions.
:param cta_tile_coord: The raw CTA coordinate from tile scheduler
:type cta_tile_coord: tuple of Int32
:param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for each group
:type problem_shape_mnkl: cute.Tensor
:return: The search result containing group index and tile coordinates
:rtype: GroupSearchResult
"""
# delinear the z coord
linear_idx = cta_tile_coord[2]
group_idx, problem_mnkl = self._group_search_and_load_problem_shape(
linear_idx,
problem_shape_mnkl,
self.search_state.start_group_idx,
self.search_state.tile_count_prev_group,
)
# linear index local to current group
cluster_tile_idx_in_current_group = (
linear_idx - self.search_state.tile_count_prev_group
)
cluster_count_m, cluster_count_n, cluster_count_k = cute.ceil_div(
(problem_mnkl[0], problem_mnkl[1], problem_mnkl[2]),
(
self.cluster_tile_shape_mnk[0],
self.cluster_tile_shape_mnk[1],
self.cluster_tile_shape_mnk[2],
),
)
# decompose to get indices on M and N
cta_tile_idx_m, cta_tile_idx_n = self._compute_cta_tile_coord(
cluster_tile_idx_in_current_group,
cta_tile_coord,
cluster_count_m,
cluster_count_n,
)
return GroupSearchResult(
group_idx,
cta_tile_idx_m,
cta_tile_idx_n,
problem_mnkl[0],
problem_mnkl[1],
problem_mnkl[2],
cluster_count_k,
)
def search_cluster_tile_count_k(
self,
cta_tile_coord: tuple,
problem_shape_mnkl: cute.Tensor,
) -> Tuple[Int32, Int32]:
"""
Search for the group matching the given linear index and compute the number of tiles along the K dimension for that group.
This function should be used by warps that only need the number of tiles along the K dimension.
:param cta_tile_coord: The raw CTA coordinate from tile scheduler
:type cta_tile_coord: tuple of Int32
:param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups
:type problem_shape_mnkl: cute.Tensor
:return: A tuple containing cluster count along K dimension and the group index
:rtype: Tuple[Int32, Int32]
"""
group_idx, problem_mnk = self._group_search_and_load_problem_shape(
cta_tile_coord[2],
problem_shape_mnkl,
self.search_state.start_group_idx,
self.search_state.tile_count_prev_group,
)
cluster_count_k = (
problem_mnk[2] + self.cluster_tile_shape_mnk[2] - 1
) // self.cluster_tile_shape_mnk[2]
return cluster_count_k, group_idx
@cute.jit
def _prefix_sum(self, value_per_thread: Int32) -> Int32:
"""
Perform prefix sum within a full warp.
:param value_per_thread: The value for this thread to contribute to the prefix sum
:type value_per_thread: Int32
:return: The prefix sum result for this thread
:rtype: Int32
"""
clamp_value = 0
idx = 1
sum_per_thread = value_per_thread
while idx < cute.arch.WARP_SIZE:
value = cute.arch.shuffle_sync_up(
sum_per_thread, idx, mask_and_clamp=clamp_value
)
if self.lane_idx >= idx:
sum_per_thread += value
idx = idx << 1
return sum_per_thread
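# Worked example (comments only): with WARP_SIZE = 32 the loop runs for strides
# 1, 2, 4, 8, 16, i.e. log2(32) shuffle-up steps. If every lane contributes
# value_per_thread = 1, lane i ends the scan holding i + 1; for per-lane inputs
# [3, 1, 4, 1, ...] lanes 0..3 hold [3, 4, 8, 9], an inclusive prefix sum.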
def _get_problem_for_group(
self, problem_shape_mnkl: cute.Tensor, group_idx: Int32
) -> cute.Tensor:
"""
Load gemm problem (m,n,k,l) for the specified group from global memory to register.
:param problem_shape_mnkl: Tensor in global memory with layout (group_count, 4):(4, 1)
:type problem_shape_mnkl: cute.Tensor
:param group_idx: The index of the group to load
:type group_idx: Int32
:return: The problem shape tensor for the specified group
:rtype: cute.Tensor
"""
cur_problem_mnkl = cute.make_fragment(
cute.make_layout(4), problem_shape_mnkl.element_type
)
cute.autovec_copy(problem_shape_mnkl[(group_idx, None)], cur_problem_mnkl)
return cur_problem_mnkl
def _get_cluster_tile_count_mn(self, problem_shape: cute.Tensor) -> Int32:
"""
Compute the total cluster tile count over the M and N dimensions.
:param problem_shape: Tensor containing problem shape (m, n, k, l)
:type problem_shape: cute.Tensor
:return: The total cluster tile count for M and N dimensions
:rtype: Int32
"""
cur_ntile_m = (
problem_shape[0] + self.cluster_tile_shape_mnk[0] - 1
) // self.cluster_tile_shape_mnk[0]
cur_ntile_n = (
problem_shape[1] + self.cluster_tile_shape_mnk[1] - 1
) // self.cluster_tile_shape_mnk[1]
cur_ntile_mn = cur_ntile_m * cur_ntile_n
return cur_ntile_mn
def _compute_cta_tile_coord(
self,
cluster_tile_idx: Int32,
cta_tile_coord_in_cluster: tuple,
cluster_tile_count_m: Int32,
cluster_tile_count_n: Int32,
) -> tuple:
"""
Compute CTA tile indices along M and N dimensions based on the linear index within a group.
It uses the AlongM mode to decompose the linear index onto M and N dimensions.
:param cluster_tile_idx: The linear index within a group
:type cluster_tile_idx: Int32
:param cta_tile_coord_in_cluster: CTA indices along M and N dimensions within a cluster
:type cta_tile_coord_in_cluster: tuple of Int32
:param cluster_tile_count_m: The number of clusters along M dimension of the matched group
:type cluster_tile_count_m: Int32
:param cluster_tile_count_n: The number of clusters along N dimension of the matched group
:type cluster_tile_count_n: Int32
:return: A tuple containing CTA tile indices along M and N dimensions
:rtype: tuple of (Int32, Int32)
"""
cluster_layout_mn = cute.make_layout(
(cluster_tile_count_m, cluster_tile_count_n)
)
(mi, ni) = cluster_layout_mn.get_hier_coord(cluster_tile_idx)
cta_tile_idx_m = (
mi * self.tile_sched_params.cluster_shape_mn[0]
+ cta_tile_coord_in_cluster[0]
)
cta_tile_idx_n = (
ni * self.tile_sched_params.cluster_shape_mn[1]
+ cta_tile_coord_in_cluster[1]
)
return (cta_tile_idx_m, cta_tile_idx_n)
@cute.jit
def _group_search(
self,
linear_idx: Int32,
problem_shape_mnkl: cute.Tensor,
init_group_idx: Int32,
init_tile_count_searched: Int32,
) -> GroupedGemmGroupSearchState:
"""
Search which group the linear index belongs to.
:param linear_idx: The linear index to be decomposed
:type linear_idx: Int32
:param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups
:type problem_shape_mnkl: cute.Tensor
:param init_group_idx: The group idx to start the search with
:type init_group_idx: Int32
:param init_tile_count_searched: The number of tiles we have searched
:type init_tile_count_searched: Int32
:return: The updated search state
:rtype: GroupedGemmGroupSearchState
"""
c_0 = Int32(0).ir_value()
last_lane_idx = cute.arch.WARP_SIZE - 1
tile_count_searched = init_tile_count_searched
start_group_idx = init_group_idx
not_found = linear_idx >= tile_count_searched
tile_count_prev_group = self.search_state.tile_count_prev_group
while not_found:
# get group to search for current lane
cur_group_idx = start_group_idx + self.lane_idx
# check if the group to be checked is out of range
inside_group_bound = cur_group_idx < self.group_count
cur_ntile_mn = c_0
if inside_group_bound:
# get problem size of current group
cur_problem_mnkl = self._get_problem_for_group(
problem_shape_mnkl, cur_group_idx
)
cur_ntile_mn = self._get_cluster_tile_count_mn(cur_problem_mnkl)
# compute tile count from the beginning up to and including the current group
total_cluster_tile_count_ps_per_thread = self._prefix_sum(cur_ntile_mn)
cluster_tile_count_end_per_thread = (
total_cluster_tile_count_ps_per_thread + tile_count_searched
)
group_not_in_window = linear_idx >= cluster_tile_count_end_per_thread
hitted_group_idx_in_search_window = cute.arch.popc(
cute.arch.vote_ballot_sync(group_not_in_window)
)
not_found = hitted_group_idx_in_search_window == cute.arch.WARP_SIZE
start_group_idx = hitted_group_idx_in_search_window + start_group_idx
hit_the_1st_problem_in_search_window = (
hitted_group_idx_in_search_window == c_0
)
tile_count_prev_group = tile_count_searched
if hit_the_1st_problem_in_search_window == False:
tile_count_prev_group = cute.arch.shuffle_sync(
cluster_tile_count_end_per_thread,
hitted_group_idx_in_search_window - 1,
)
# If no matched group, then get new_cluster_tile_count_end from last lane
# Otherwise, get new_cluster_tile_count_end from the hitted group
lane_idx_for_cluster_tile_count_end = hitted_group_idx_in_search_window
if not_found:
lane_idx_for_cluster_tile_count_end = last_lane_idx
tile_count_searched = cute.arch.shuffle_sync(
cluster_tile_count_end_per_thread,
lane_idx_for_cluster_tile_count_end,
)
return GroupedGemmGroupSearchState(
start_group_idx,
tile_count_prev_group,
tile_count_searched,
)
def _group_search_and_load_problem_shape(
self,
linear_idx: Int32,
problem_shape_mnkl: cute.Tensor,
start_group_idx: Int32,
tile_count_searched: Int32,
) -> Tuple[Int32, cute.Tensor]:
"""
Perform group search and load problem shape for the matched group.
:param linear_idx: The linear index to be decomposed
:type linear_idx: Int32
:param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups
:type problem_shape_mnkl: cute.Tensor
:param start_group_idx: The group idx to start the search with
:type start_group_idx: Int32
:param tile_count_searched: The number of tiles we have searched
:type tile_count_searched: Int32
:return: A tuple containing the final group index and the problem shape tensor
:rtype: Tuple[Int32, cute.Tensor]
"""
self.search_state = self._group_search(
linear_idx,
problem_shape_mnkl,
start_group_idx,
tile_count_searched,
)
# get final group search state
final_group_idx = self.search_state.start_group_idx
# TODO: revisit whether it is better to broadcast problem_shape_mnk in group_search
problem_mnkl = self._get_problem_for_group(problem_shape_mnkl, final_group_idx)
return final_group_idx, problem_mnkl
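# A hedged usage sketch (comments only, illustrative names): a persistent kernel
# would translate each raw coordinate from the tile scheduler before touching the
# matched group's tensors. GroupSearchResult is assumed to expose the group index
# and the CTA tile indices computed above.
#
#     result = helper.delinearize_z(cta_tile_coord, problem_shape_mnkl)
#     # result selects the group (A/B/C pointers, problem shape) and the
#     # (m, n) CTA tile indices of the output tile to compute.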


@@ -0,0 +1,174 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from cuda.bindings import driver, nvrtc
import cutlass.cute as cute
"""
This module provides the HardwareInfo class, which queries hardware information for a given GPU device.
It provides methods such as querying the maximum number of active clusters for a given cluster size.
Prerequisites:
- The CUDA driver is initialized via `driver.cuInit` or other CUDA APIs.
- A CUDA context is created via `driver.cuCtxCreate` or other CUDA APIs.
"""
class HardwareInfo:
"""
device_id: CUDA device ID for which to query hardware information.
"""
def __init__(self, device_id: int = 0):
count = self._checkCudaErrors(driver.cuDeviceGetCount())
if device_id >= count:
raise ValueError(
f"Device ID {device_id} is out of range for device count {count}"
)
self.device_id = device_id
self.device = self._checkCudaErrors(driver.cuDeviceGet(device_id))
self.context = self._checkCudaErrors(driver.cuCtxGetCurrent())
self.driver_version = self._checkCudaErrors(driver.cuDriverGetVersion())
# Get the maximum number of active clusters for a given cluster size
def get_max_active_clusters(self, cluster_size: int) -> int:
self._get_device_function()
if self._cuda_driver_version_lt(11, 8):
raise RuntimeError(
"CUDA Driver version < 11.8, cannot get _max_active_clusters"
)
if cluster_size <= 0 or cluster_size > 32:
raise ValueError(
f"Cluster size must be between 1 and 32, {cluster_size} is not supported"
)
max_shared_memory_per_block = self._checkCudaErrors(
driver.cuDeviceGetAttribute(
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
self.device,
)
)
self._checkCudaErrors(
driver.cuFuncSetAttribute(
self.kernel,
driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
max_shared_memory_per_block,
)
)
max_dynamic_shared_memory = self._checkCudaErrors(
driver.cuOccupancyAvailableDynamicSMemPerBlock(
self.kernel, 1, 1 # numBlocks # blockSize
)
)
max_active_blocks = self._checkCudaErrors(
driver.cuOccupancyMaxActiveBlocksPerMultiprocessor(
self.kernel, 1, max_dynamic_shared_memory # blockSize,
)
)
# Allow non-portable cluster sizes so occupancy can also be queried for them
self._checkCudaErrors(
driver.cuFuncSetAttribute(
self.kernel,
driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED,
1,
)
)
# prepare launch configuration
launch_config = driver.CUlaunchConfig()
launch_config.blockDimX = 128
launch_config.blockDimY = 1
launch_config.blockDimZ = 1
launch_config.sharedMemBytes = max_dynamic_shared_memory
launch_config.numAttrs = 1
# max possible cluster size is 32
cluster_dims_attr = driver.CUlaunchAttribute()
cluster_dims_attr.id = (
driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
)
value = driver.CUlaunchAttributeValue()
value.clusterDim.x = cluster_size
value.clusterDim.y = 1
value.clusterDim.z = 1
cluster_dims_attr.value = value
launch_config.attrs = [cluster_dims_attr]
launch_config.gridDimX = cluster_size
launch_config.gridDimY = max_active_blocks
launch_config.gridDimZ = 1
num_clusters = self._checkCudaErrors(
driver.cuOccupancyMaxActiveClusters(self.kernel, launch_config)
)
return num_clusters
def get_l2_cache_size_in_bytes(self) -> int:
return self._checkCudaErrors(
driver.cuDeviceGetAttribute(
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE,
self.device,
)
)
def get_device_multiprocessor_count(self) -> int:
return self._checkCudaErrors(
driver.cuDeviceGetAttribute(
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
self.device,
)
)
def _checkCudaErrors(self, result):
if result[0].value:
raise RuntimeError(
"CUDA error code={}({})".format(
result[0].value, self._cudaGetErrorEnum(result[0])
)
)
# CUDA APIs always return the status as the first element of the result tuple
if len(result) == 1:
return None
elif len(result) == 2:
return result[1]
else:
return result[1:]
def _cudaGetErrorEnum(self, error) -> str:
if isinstance(error, driver.CUresult):
err, name = driver.cuGetErrorName(error)
return name if err == driver.CUresult.CUDA_SUCCESS else "<unknown>"
elif isinstance(error, nvrtc.nvrtcResult):
return nvrtc.nvrtcGetErrorString(error)[1]
else:
raise RuntimeError("Unknown error type: {}".format(error))
def _cuda_driver_version_ge(self, major: int, minor: int) -> bool:
return self.driver_version >= (major * 1000 + 10 * minor)
def _cuda_driver_version_lt(self, major: int, minor: int) -> bool:
return not self._cuda_driver_version_ge(major, minor)
@cute.kernel
def _empty_kernel(self):
return
@cute.jit
def _host_function(self):
self._empty_kernel().launch(
grid=[1, 1, 1],
block=[1, 1, 1],
)
# Compile an empty kernel used to compute occupancy
def _get_device_function(self) -> None:
self.compiled_kernel = cute.compile(self._host_function)
self.module = next(iter(self.compiled_kernel.cuda_modules.modules)).cuda_module
self.kernel = next(iter(self.compiled_kernel.cuda_modules.modules)).kernel_ptr
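# A minimal usage sketch (comments only; assumes the CUDA driver and a current
# context are already initialized, per the module docstring):
#
#     hw = HardwareInfo(device_id=0)
#     sm_count = hw.get_device_multiprocessor_count()
#     l2_bytes = hw.get_l2_cache_size_in_bytes()
#     max_clusters = hw.get_max_active_clusters(cluster_size=2)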


@@ -0,0 +1,195 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Type, Tuple
from enum import Enum
from cutlass.utils.layout import LayoutEnum
from cutlass.cutlass_dsl import (
Float16,
BFloat16,
Float8E5M2,
Float8E4M3FN,
Numeric,
NumericMeta,
dsl_user_op,
)
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu.common import CopyUniversalOp
from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp
from cutlass.cute.nvgpu.warpgroup import (
MmaF16BF16Op,
MmaF8Op,
OperandMajorMode,
OperandSource,
)
@dsl_user_op
def sm90_get_smem_store_op(
layout_d: LayoutEnum,
elem_ty_d: Type[Numeric],
elem_ty_acc: Type[Numeric],
*,
loc=None,
ip=None,
) -> cute.CopyAtom:
"""
Select the largest vectorized smem store atom available, subject to the constraints of the gmem layout.
Parameters:
-----------
layout_d : LayoutEnum
The layout enum of the output tensor D.
elem_ty_d : Type[Numeric]
The element type for output tensor D.
elem_ty_acc : Type[Numeric]
The element type for accumulator.
Returns:
--------
A copy atom based on StMatrix8x8x16bOp for 16-bit output element types, otherwise a CopyUniversalOp-based copy atom.
"""
def validate_type(ty, ty_name):
if not isinstance(ty, NumericMeta):
raise TypeError(f"{ty_name} must be a Numeric, but got {ty}")
validate_type(elem_ty_d, "elem_ty_d")
validate_type(elem_ty_acc, "elem_ty_acc")
is_m_major = layout_d.is_m_major_c()
if elem_ty_d.width == 16:
return cute.make_copy_atom(
StMatrix8x8x16bOp(is_m_major, 4), elem_ty_d, loc=loc, ip=ip
)
else:
return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip)
class SmemCapacity(Enum):
SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024
# Dictionary to map compute capability to SMEM capacity
SMEM_CAPACITY = {
"sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value,
}
def make_trivial_tiled_mma(
a_dtype: Type[Numeric],
b_dtype: Type[Numeric],
a_leading_mode: OperandMajorMode,
b_leading_mode: OperandMajorMode,
acc_dtype: Type[Numeric],
atom_layout_mnk: Tuple[int, int, int],
tiler_mn: Tuple[int, int],
) -> cute.TiledMma:
"""Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape.
By default, the MMA atom is created with SMEM operand source for A.
:param a_dtype: Data type of operand A.
:type a_dtype: type[Numeric]
:param b_dtype: Data type of operand B.
:type b_dtype: type[Numeric]
:param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N).
:type a_leading_mode: warpgroup.OperandMajorMode
:param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N).
:type b_leading_mode: warpgroup.OperandMajorMode
:param acc_dtype: Data type of the accumulator.
:type acc_dtype: type[Numeric]
:param atom_layout_mnk: An integer tuple describing the tiling of the atom across threads.
:type atom_layout_mnk: Tuple[int, int, int]
:param tiler_mn: The shape (M, N) of the cta tiler.
:type tiler_mn: Tuple[int, int]
:return: A tiled MMA atom.
:rtype: cute.TiledMma
:raises TypeError: If the data type is not supported.
"""
if a_dtype in {Float16, BFloat16}:
if cutlass.const_expr(a_dtype != b_dtype):
raise TypeError(f"Type mismatch: {a_dtype} != {b_dtype}")
if cutlass.const_expr(a_dtype.width != b_dtype.width):
raise TypeError(f"Type width mismatch: {a_dtype.width} != {b_dtype.width}")
mma_op = MmaF16BF16Op(
a_dtype,
acc_dtype,
(*tiler_mn, 16),
OperandSource.SMEM,
a_leading_mode,
b_leading_mode,
)
elif a_dtype in {Float8E4M3FN, Float8E5M2} and b_dtype in {
Float8E4M3FN,
Float8E5M2,
}:
mma_op = MmaF8Op(
a_dtype,
b_dtype,
acc_dtype,
(*tiler_mn, 32),
OperandSource.SMEM,
a_leading_mode,
b_leading_mode,
)
else:
raise TypeError(f"unsupported a_dtype and b_dtype, got {a_dtype} and {b_dtype}")
return cute.make_tiled_mma(cute.make_mma_atom(mma_op), atom_layout_mnk)
def get_smem_layout_atom(
layout: LayoutEnum,
element_type: Type[Numeric],
major_mode_size: int,
*,
loc=None,
ip=None,
):
"""Select the optimal shared memory layout atom based on parameters.
:param layout: Layout enum of the tensor
:type layout: LayoutEnum
:param element_type: Data type of the elements
:type element_type: type[cutlass.Numeric]
:param major_mode_size: Size of the major mode dimension
:type major_mode_size: int
:return: Selected shared memory layout atom kind
:rtype: cute.nvgpu.warpgroup.SmemLayoutAtomKind
"""
assert major_mode_size % 8 == 0
sw128_num_contiguous_bits = 1024
sw64_num_contiguous_bits = 512
sw32_num_contiguous_bits = 256
major_mode_size_bits = major_mode_size * element_type.width
if layout.sm90_mma_major_mode() == OperandMajorMode.MN:
if major_mode_size_bits % sw128_num_contiguous_bits == 0:
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW128
if major_mode_size_bits % sw64_num_contiguous_bits == 0:
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW64
if major_mode_size_bits % sw32_num_contiguous_bits == 0:
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW32
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_INTER
if major_mode_size_bits % sw128_num_contiguous_bits == 0:
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128
if major_mode_size_bits % sw64_num_contiguous_bits == 0:
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW64
if major_mode_size_bits % sw32_num_contiguous_bits == 0:
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW32
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER
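# Worked example (comments only): for a K-major Float16 operand with
# major_mode_size = 64, major_mode_size_bits = 64 * 16 = 1024, which is divisible
# by 1024 bits, so K_SW128 is selected; with major_mode_size = 32 (512 bits) the
# check falls through to K_SW64, and an MN-major operand follows the MN_* branch.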


@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from enum import Enum
import cutlass.cute as cute
from cutlass.cute.nvgpu import warpgroup
from cutlass.cute.nvgpu import tcgen05
class LayoutEnum(Enum):
ROW_MAJOR = "row_major"
COL_MAJOR = "col_major"
def mma_major_mode(self):
return (
tcgen05.OperandMajorMode.K
if self == LayoutEnum.ROW_MAJOR
else tcgen05.OperandMajorMode.MN
)
def sm90_mma_major_mode(self):
return (
warpgroup.OperandMajorMode.K
if self == LayoutEnum.ROW_MAJOR
else warpgroup.OperandMajorMode.MN
)
def is_k_major_a(self):
return self == LayoutEnum.ROW_MAJOR
def is_m_major_a(self):
return self == LayoutEnum.COL_MAJOR
def is_k_major_b(self):
return self == LayoutEnum.COL_MAJOR
def is_n_major_b(self):
return self == LayoutEnum.ROW_MAJOR
def is_n_major_c(self):
return self == LayoutEnum.ROW_MAJOR
def is_m_major_c(self):
return self == LayoutEnum.COL_MAJOR
@staticmethod
def from_tensor(tensor: cute.Tensor) -> "LayoutEnum":
ret = None
if tensor.leading_dim == 1:
ret = LayoutEnum.ROW_MAJOR
elif tensor.leading_dim == 0:
ret = LayoutEnum.COL_MAJOR
else:
raise ValueError(f"Invalid leading dimension: {tensor.leading_dim}")
return ret
__all__ = ["LayoutEnum"]
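# Example (comments only): a tensor whose leading (stride-1) dimension is mode 1
# maps to LayoutEnum.ROW_MAJOR, which is K-major for operand A and N-major for
# operands B and C; a tensor with leading dimension 0 maps to COL_MAJOR, i.e.
# M-major for A and C and K-major for B.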


@@ -0,0 +1,984 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
from cutlass.cutlass_dsl import Boolean, Int32, Int64, T, if_generate, and_, or_
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass.cute as cute
##############################################################################
# Agent class
##############################################################################
class Agent(enum.Enum):
"""
Agent indicates what is participating in the pipeline synchronization.
"""
# Arbitrary grouping of N threads
Thread = enum.auto()
# Same as Thread, but includes all threads in the block
ThreadBlock = enum.auto()
# Same as Thread, but includes all threads in the cluster
ThreadBlockCluster = enum.auto()
class CooperativeGroup:
"""
CooperativeGroup contains size and alignment restrictions for an Agent.
"""
def __init__(self, agent: Agent, size: int = 1, alignment: int = 1):
if agent is Agent.Thread:
assert size > 0
if size == 32:
assert (
size == alignment
), "Error: Alignment does not match number of threads in a warp."
elif size == 128:
assert (
size == alignment
), "Error: Alignment does not match number of threads in a warpgroup."
elif agent is Agent.ThreadBlock:
assert False, "Error: Not yet supported."
elif agent is Agent.ThreadBlockCluster:
assert False, "Error: Not yet supported."
else:
# Should never reach this state
size = 0
if size <= 0:
raise ValueError(
"Error: The number of threads in a CooperativeGroup must be more than 0."
)
# Size indicates how many threads are participating in this CooperativeGroup
self.size = size
# Agent indicates the type of thread group
self.agent = agent
class _PipelineOp(enum.Enum):
"""
PipelineOp assigns an operation to an agent corresponding to a specific hardware feature.
"""
# async-threads
AsyncThread = enum.auto()
# Blackwell (SM100a) MMA instruction
TCGen05Mma = enum.auto()
# Tensor Memory Accelerator load
TmaLoad = enum.auto()
# TMA Store consuming smem produced by AsyncThread
TmaStore = enum.auto()
def _get_pipeline_op(type_str):
return _PipelineOp(type_str)
##############################################################################
# SyncObjectArray class
##############################################################################
class SyncObjectArray(ABC):
"""
SyncObjectArray is an abstract base class for different types of hardware synchronizations (e.g. smem barriers, named barriers, fences)
"""
@abstractmethod
def wait(self):
pass
@abstractmethod
def arrive(self):
pass
@abstractmethod
def get_barrier(self):
pass
class MbarrierArray(SyncObjectArray):
"""
MbarrierArray implements an abstraction for an array of smem barriers.
"""
def __init__(
self,
barrier_storage: cute.Pointer,
num_stages: int,
agent: tuple[_PipelineOp, CooperativeGroup],
tx_count: int = 0,
):
self.barrier_storage = barrier_storage
self.tx_count = tx_count
self.num_stages = num_stages
self.op_type, self.cg = agent
self.arrive_count = self.cg.size
if self.num_stages <= 0:
raise ValueError("Error: Mbarrier stage count must be greater than 0.")
if self.arrive_count <= 0:
raise ValueError("Error: Mbarrier arrive count must be greater than 0.")
if self.op_type is _PipelineOp.TmaLoad and self.tx_count <= 0:
raise ValueError(
"Error: Mbarrier tx count must be greater than 0 for TMA ops."
)
# Using a tensor to store mbarrier i64 ptrs
self.mbarrier_array = cute.make_fragment(cute.make_layout(num_stages), Int64)
for i in range(num_stages):
self.mbarrier_array[i] = _cute_ir.ptrtoint(
T.i64(), (self.barrier_storage + i).value
)
# Mbarrier initialization in constructor
self.mbarrier_init()
# Mbarrier initialization
def mbarrier_init(self):
"""
Initializes an array of mbarriers using warp 0.
"""
def then_body():
for index in range(self.num_stages):
cute.arch.mbarrier_init_arrive_cnt(
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), self.arrive_count
)
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
if_generate(warp_idx == 0, then_body)
def arrive(self, index: int, dst: int):
"""
Select the arrive corresponding to this MbarrierArray's PipelineOp
:param index: Index of the mbarrier in the array to arrive on
:type index: int
:param dst: Destination parameter for selective arrival, which can be either a mask or destination cta rank. When None, both TCGen05Mma and AsyncThread will arrive on their local mbarrier.
- For TCGen05Mma, dst serves as a multicast mask (e.g., 0b1011 allows arrive signal to be multicast to CTAs in the cluster with rank = 0, 1, and 3).
- For AsyncThread, dst serves as a destination cta rank (e.g., 3 means threads will arrive on the mbarrier with rank = 3 in the cluster).
:type dst: int | None
"""
if self.op_type is _PipelineOp.AsyncThread:
self.arrive_mbarrier(index, dst)
elif self.op_type is _PipelineOp.TCGen05Mma:
self.arrive_tcgen05mma(index, dst)
elif self.op_type in [_PipelineOp.TmaLoad]:
self.arrive_and_expect_tx(index, self.tx_count)
else:
print(_get_pipeline_op(self.op_type))
assert False, "Error: MbarrierArray is not supported for this PipelineOp."
def arrive_mbarrier(self, index: int, dst_rank: int):
if dst_rank is None:
cute.arch.mbarrier_arrive(_mbarrier_i64_to_ptr(self.mbarrier_array[index]))
else:
cute.arch.mbarrier_arrive(
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), dst_rank
)
def arrive_tcgen05mma(self, index: int, mask: int):
if mask is None:
with cute.arch.elect_one():
cute.nvgpu.tcgen05.commit(
_mbarrier_i64_to_ptr(self.mbarrier_array[index])
)
else:
with cute.arch.elect_one():
cute.nvgpu.tcgen05.commit(
_mbarrier_i64_to_ptr(self.mbarrier_array[index]),
mask,
cute.nvgpu.tcgen05.CtaGroup.TWO,
)
def arrive_and_expect_tx(self, index: int, tx_count: int):
with cute.arch.elect_one():
cute.arch.mbarrier_init_tx_bytes(
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), tx_count
)
def try_wait(self, index: int, phase: int):
return cute.arch.mbarrier_try_wait(
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), phase
)
def wait(self, index: int, phase: int):
cute.arch.mbarrier_wait(_mbarrier_i64_to_ptr(self.mbarrier_array[index]), phase)
def get_barrier(self, index: int) -> cute.Pointer:
return _mbarrier_i64_to_ptr(self.mbarrier_array[index])
class TmaStoreFence(SyncObjectArray):
"""
TmaStoreFence is used for a multi-stage epilogue buffer.
"""
def __init__(
self,
num_stages: int = 0,
):
if num_stages <= 0:
raise ValueError("Mbarrier stage count must be greater than 0.")
self.num_stages = num_stages
def arrive(self):
cute.arch.cp_async_bulk_commit_group()
def wait(self):
cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True)
# TmaStoreFence doesn't have mbarriers
def get_barrier(self):
assert (
False
), "Error: TmaStoreFence doesn't use mbarriers and cannot return a barrier."
def tail(self):
cute.arch.cp_async_bulk_wait_group(0, read=True)
##############################################################################
# PipelineState class
##############################################################################
class PipelineUserType(enum.Enum):
Producer = enum.auto()
Consumer = enum.auto()
class PipelineState:
"""
Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
"""
def __init__(self, stages: int, count, index, phase):
self._stages = stages
self._count = count
self._index = index
self._phase = phase
def clone(self) -> "PipelineState":
return PipelineState(self.stages, self._count, self.index, self.phase)
@property
def index(self) -> Int32:
return self._index
@property
def count(self) -> Int32:
return self._count
@property
def stages(self) -> int:
return self._stages
@property
def phase(self) -> Int32:
return self._phase
def reset_count(self):
self._count = Int32(0)
def advance(self):
self._index += 1
self._count += 1
def then_body(index, phase):
new_index = Int32(0)
new_phase = phase ^ 1
return new_index, new_phase
def else_body(index, phase):
return index, phase
self._index, self._phase = if_generate(
self._index == self.stages,
then_body,
else_body,
[self.index, self.phase],
[Int32, Int32],
)
def reverse(self):
self._index -= 1
self._count -= 1
def then_body(index, phase):
new_index = Int32(self.stages - 1)
new_phase = phase ^ 1
return new_index, new_phase
def else_body(index, phase):
return index, phase
self._index, self._phase = if_generate(
self._index == -1,
then_body,
else_body,
[self.index, self.phase],
[Int32, Int32],
)
def __get_mlir_types__(self):
return [self._count.type, self._index.type, self._phase.type]
def __extract_mlir_values__(self):
count = self._count
index = self._index
phase = self._phase
return [count.ir_value(), index.ir_value(), phase.ir_value()]
# This can be overridden by derived classes
def __new_from_mlir_values__(self, values):
return PipelineState(
self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2])
)
def make_pipeline_state(type: PipelineUserType, stages: int):
"""
Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
"""
if type is PipelineUserType.Producer:
return PipelineState(
stages,
Int32(0),
Int32(0),
Int32(1),
)
elif type is PipelineUserType.Consumer:
return PipelineState(
stages,
Int32(0),
Int32(0),
Int32(0),
)
else:
assert (
False
), "Error: invalid PipelineUserType specified for make_pipeline_state."
##############################################################################
# Pipeline classes
##############################################################################
@dataclass(frozen=True)
class PipelineAsync:
"""
PipelineAsync is a generic pipeline class where both the producer and consumer are
AsyncThreads. It also serves as a base class for specialized pipeline classes.
"""
sync_object_array_full: SyncObjectArray
sync_object_array_empty: SyncObjectArray
num_stages: Int32
producer_mask: Int32
consumer_mask: Int32
@staticmethod
def _make_sync_object_array(
barrier_storage: cute.Pointer,
num_stages: Int32,
agent: tuple[_PipelineOp, CooperativeGroup],
tx_count: int = 0,
) -> SyncObjectArray:
"""
Returns a SyncObjectArray corresponding to an agent's PipelineOp.
"""
if agent[0] in [
_PipelineOp.AsyncThread,
_PipelineOp.TmaLoad,
_PipelineOp.TCGen05Mma,
]:
return MbarrierArray(
barrier_storage=barrier_storage,
num_stages=num_stages,
agent=agent,
tx_count=tx_count,
)
elif agent[0] is _PipelineOp.TmaStore:
# Path taken for AsyncTmaStore
return TmaStoreFence(num_stages=num_stages)
else:
assert False, "Error: Invalid PipelineOp specified."
@staticmethod
def create(
barrier_storage: cute.Pointer,
num_stages: Int32,
producer_group: CooperativeGroup,
consumer_group: CooperativeGroup,
producer_mask: Int32 = None,
consumer_mask: Int32 = None,
):
"""
This helper function computes any necessary attributes and returns an instance of PipelineAsync.
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:type consumer_group: CooperativeGroup
:param producer_mask: Mask for signaling arrives for the producer agent
:type producer_mask: Int32 | None
:param consumer_mask: Mask for signaling arrives for the consumer agent
:type consumer_mask: Int32 | None
"""
producer_type = _PipelineOp.AsyncThread
consumer_type = _PipelineOp.AsyncThread
producer = (producer_type, producer_group)
consumer = (consumer_type, consumer_group)
sync_object_array_full = PipelineAsync._make_sync_object_array(
barrier_storage.align(min_align=8), num_stages, producer
)
sync_object_array_empty = PipelineAsync._make_sync_object_array(
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
)
pipeline_init_wait()
return PipelineAsync(
sync_object_array_full,
sync_object_array_empty,
num_stages,
producer_mask,
consumer_mask,
)
def producer_acquire(
self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
):
if_generate(
try_acquire_token is None or try_acquire_token == 0,
lambda: self.sync_object_array_empty.wait(state.index, state.phase),
)
def producer_try_acquire(self, state: PipelineState):
return self.sync_object_array_empty.try_wait(state.index, state.phase)
def producer_commit(self, state: PipelineState):
self.sync_object_array_full.arrive(state.index, self.producer_mask)
def consumer_wait(
self, state: PipelineState, try_wait_token: Optional[Boolean] = None
):
if_generate(
try_wait_token is None or try_wait_token == 0,
lambda: self.sync_object_array_full.wait(state.index, state.phase),
)
def consumer_try_wait(self, state: PipelineState):
return self.sync_object_array_full.try_wait(state.index, state.phase)
def consumer_release(self, state: PipelineState):
self.sync_object_array_empty.arrive(state.index, self.consumer_mask)
def producer_get_barrier(self, state: PipelineState) -> cute.Pointer:
return self.sync_object_array_full.get_barrier(state.index)
def producer_tail(self, state: PipelineState):
"""
Make sure the empty signal of the last used buffer is visible to the producer.
Producer tail is usually executed by the producer before exit, to avoid dangling
mbarrier arrive signals after the kernel exits.
:param state: The pipeline state that points to next useful buffer
:type state: PipelineState
"""
# Assume state points to the next useful buffer,
# so we only need to advance num_stages - 1 times to reach the last used buffer
for i in range(self.num_stages - 1):
state.advance()
self.producer_acquire(state)
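# A hedged usage sketch (comments only, illustrative names): a typical producer
# loop pairs the calls above with a per-agent PipelineState. `smem_barrier_ptr`,
# `producer_cg`, `consumer_cg` and `num_iters` are assumed to be provided by the
# surrounding kernel.
#
#     pipeline = PipelineAsync.create(smem_barrier_ptr, 4, producer_cg, consumer_cg)
#     state = make_pipeline_state(PipelineUserType.Producer, 4)
#     for _ in range(num_iters):
#         pipeline.producer_acquire(state)
#         # ... fill buffer state.index ...
#         pipeline.producer_commit(state)
#         state.advance()
#     pipeline.producer_tail(state)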
@dataclass(frozen=True)
class PipelineTmaAsync(PipelineAsync):
"""
PipelineTmaAsync is used for TMA producers and AsyncThread consumers (e.g. Hopper mainloops).
"""
is_signalling_thread: bool
@staticmethod
def init_empty_barrier_arrive_signal(cta_layout_vmnk: cute.Layout):
"""
Initialize the empty-barrier arrive signal.
This function returns the destination cta rank and a boolean indicating whether the current thread is a signalling thread.
"""
# Logic to optimally schedule Empty Arrives
cluster_shape_mnk = cta_layout_vmnk.shape
tidx, _, _ = cute.arch.thread_idx()
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
is_signalling_thread = tidx < cute.size(cluster_shape_mnk)
dst_rank = tidx % cute.size(cluster_shape_mnk)
m = cluster_shape_mnk[0]
# Check if same row
is_same_row_l = dst_rank % m
is_same_row_r = cta_rank_in_cluster % m
is_same_row = is_same_row_l == is_same_row_r
# Check if same column
is_same_col_l = dst_rank // m
is_same_col_r = cta_rank_in_cluster // m
is_same_col = is_same_col_l == is_same_col_r
is_same_row_or_col = or_(is_same_row, is_same_col)
is_signalling_thread_final = and_(is_signalling_thread, is_same_row_or_col)
return dst_rank, is_signalling_thread_final
@staticmethod
def create(
barrier_storage: cute.Pointer,
num_stages: Int32,
producer_group: CooperativeGroup,
consumer_group: CooperativeGroup,
tx_count: int,
cta_layout_vmnk: Optional[cute.Layout] = None,
):
"""
This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:type consumer_group: CooperativeGroup
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
:type tx_count: int
:param cta_layout_vmnk: Layout of the cluster shape
:type cta_layout_vmnk: cute.Layout | None
"""
producer_type = _PipelineOp.TmaLoad
consumer_type = _PipelineOp.AsyncThread
producer = (producer_type, producer_group)
consumer = (consumer_type, consumer_group)
sync_object_array_full = PipelineAsync._make_sync_object_array(
barrier_storage.align(min_align=8), num_stages, producer, tx_count
)
sync_object_array_empty = PipelineAsync._make_sync_object_array(
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
)
dst_rank, is_signalling_thread = (
PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk)
)
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
dst_rank = None
producer_mask = None
pipeline_init_wait(cta_layout_vmnk)
return PipelineTmaAsync(
sync_object_array_full,
sync_object_array_empty,
num_stages,
producer_mask,
dst_rank,
is_signalling_thread,
)
def producer_acquire(
self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
):
"""
TMA producer acquire conditionally waits on the buffer being empty and sets the transaction barrier.
"""
if_generate(
try_acquire_token is None or try_acquire_token == 0,
lambda: self.sync_object_array_empty.wait(state.index, state.phase),
)
self.sync_object_array_full.arrive(state.index, self.producer_mask)
def producer_commit(self, state: PipelineState):
"""
TMA producer commit is a NOP. The transaction barrier signals the commit upon completion of the TMA.
"""
pass
def consumer_release(self, state: PipelineState):
"""
TMA consumer release conditionally signals the empty buffer to the producer.
"""
if_generate(
self.is_signalling_thread,
lambda: self.sync_object_array_empty.arrive(
state.index, self.consumer_mask
),
)
@dataclass(frozen=True)
class PipelineTmaUmma(PipelineAsync):
"""
PipelineTmaUmma is used for TMA producers and UMMA consumers (e.g. Blackwell mainloops).
"""
is_leader_cta: bool
@staticmethod
def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout):
"""
Computes a mask for signaling arrivals to multicasting threadblocks.
"""
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2
)
tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1
)
block_in_cluster_coord_vmnk_peer = (
cta_in_cluster_coord_vmnk[0] ^ 1,
*cta_in_cluster_coord_vmnk[1:],
)
tma_mcast_mask_a_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2
)
tma_mcast_mask_b_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1
)
return (
tma_mcast_mask_a
| tma_mcast_mask_b
| tma_mcast_mask_a_peer
| tma_mcast_mask_b_peer
)
@staticmethod
def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout):
"""
Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders.
"""
bidx, bidy, _ = cute.arch.block_idx()
mma_coord_vmnk = (
bidx % cute.size(cta_layout_vmnk, mode=[0]),
bidx // cute.size(cta_layout_vmnk, mode=[0]),
bidy,
None,
)
return mma_coord_vmnk[0] == 0
@staticmethod
def create(
barrier_storage: cute.Pointer,
num_stages: Int32,
producer_group: CooperativeGroup,
consumer_group: CooperativeGroup,
tx_count: int,
cta_layout_vmnk: Optional[cute.Layout] = None,
):
"""
This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:type consumer_group: CooperativeGroup
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
:type tx_count: int
:param cta_layout_vmnk: Layout of the cluster shape
:type cta_layout_vmnk: cute.Layout | None
"""
producer_type = _PipelineOp.TmaLoad
consumer_type = _PipelineOp.TCGen05Mma
producer = (producer_type, producer_group)
consumer = (consumer_type, consumer_group)
sync_object_array_full = PipelineAsync._make_sync_object_array(
barrier_storage.align(min_align=8), num_stages, producer, tx_count
)
sync_object_array_empty = PipelineAsync._make_sync_object_array(
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
)
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
# No mcast mask if not using clusters
producer_mask = None
# All threadblocks are leaders if not using clusters
is_leader_cta = True
else:
producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
consumer_mask = producer_mask
pipeline_init_wait(cta_layout_vmnk)
return PipelineTmaUmma(
sync_object_array_full,
sync_object_array_empty,
num_stages,
producer_mask,
consumer_mask,
is_leader_cta,
)
def producer_acquire(
self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
):
"""
TMA producer acquire conditionally waits on the buffer being empty and sets the transaction barrier for leader threadblocks.
"""
if_generate(
try_acquire_token is None or try_acquire_token == 0,
lambda: self.sync_object_array_empty.wait(state.index, state.phase),
)
if_generate(
self.is_leader_cta,
lambda: self.sync_object_array_full.arrive(state.index, self.producer_mask),
)
def producer_commit(self, state: PipelineState):
"""
TMA producer commit is a NOP. The transaction barrier signals the commit upon completion of the TMA.
"""
pass
@dataclass(frozen=True)
class PipelineUmmaAsync(PipelineAsync):
"""
PipelineUmmaAsync is used for UMMA producers and AsyncThread consumers (e.g. Blackwell accumulator pipelines).
"""
@staticmethod
def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout):
"""
Computes a mask to signal completion of tmem buffers for 2CTA kernels.
"""
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
return cute.make_layout_image_mask(
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mode=0
)
@staticmethod
def _compute_peer_cta_rank():
"""
Computes a mask to signal release of tmem buffers for 2CTA kernels.
"""
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
return cta_rank_in_cluster // 2 * 2
@staticmethod
def create(
barrier_storage: cute.Pointer,
num_stages: Int32,
producer_group: CooperativeGroup,
consumer_group: CooperativeGroup,
cta_layout_vmnk: Optional[cute.Layout] = None,
):
"""
This helper function computes any necessary attributes and returns an instance of PipelineUmmaAsync.
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:type consumer_group: CooperativeGroup
:param cta_layout_vmnk: Layout of the cluster shape
:type cta_layout_vmnk: cute.Layout | None
"""
producer_type = _PipelineOp.TCGen05Mma
consumer_type = _PipelineOp.AsyncThread
producer = (producer_type, producer_group)
consumer = (consumer_type, consumer_group)
sync_object_array_full = PipelineAsync._make_sync_object_array(
barrier_storage.align(min_align=8), num_stages, producer
)
sync_object_array_empty = PipelineAsync._make_sync_object_array(
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
)
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
# Set mask to None if not using clusters (i.e. 1CTA kernels)
producer_mask = None
else:
producer_mask = PipelineUmmaAsync._compute_tmem_sync_mask(cta_layout_vmnk)
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1:
# Set mask to None if not using 2CTA instructions
consumer_mask = None
else:
consumer_mask = PipelineUmmaAsync._compute_peer_cta_rank()
pipeline_init_wait(cta_layout_vmnk)
return PipelineUmmaAsync(
sync_object_array_full,
sync_object_array_empty,
num_stages,
producer_mask,
consumer_mask,
)
def producer_tail(self, state: PipelineState):
"""
Make sure the empty signal of the last used buffer is visible to the producer.
Producer tail is usually executed by the producer before exit, to avoid dangling
mbarrier arrive signals after the kernel exits.
:param state: The pipeline state that points to next useful buffer
:type state: PipelineState
"""
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
is_leader_cta = cta_rank_in_cluster % 2 == 0
def then_body():
# Assume state points to the next useful buffer,
# so we only need to advance num_stages - 1 times to reach the last used buffer
for i in range(self.num_stages - 1):
state.advance()
self.producer_acquire(state)
if_generate(is_leader_cta, then_body)
@dataclass(frozen=True)
class PipelineTmaStore(PipelineAsync):
"""
PipelineTmaStore is used for synchronizing TMA stores in the epilogue. It does not use mbarriers.
"""
@staticmethod
def create(
num_stages: Int32,
producer_group: CooperativeGroup,
):
"""
This helper function computes any necessary attributes and returns an instance of PipelineTmaStore.
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:type producer_group: CooperativeGroup
"""
producer_type = _PipelineOp.TmaStore
producer = (producer_type, producer_group)
sync_object_array_full = PipelineAsync._make_sync_object_array(
None, num_stages, producer
)
return PipelineTmaStore(sync_object_array_full, None, num_stages, None, None)
def producer_acquire(self):
self.sync_object_array_full.wait()
def producer_commit(self):
self.sync_object_array_full.arrive()
def consumer_wait(self):
assert False, "Error: PipelineTmaStore does not have a consumer agent."
def consumer_release(self):
assert False, "Error: PipelineTmaStore does not have a consumer agent."
def producer_tail(self):
self.sync_object_array_full.tail()
##############################################################################
# Helper functions
##############################################################################
def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
"""
Fences the mbarrier init and syncs the threadblock or cluster
"""
cute.arch.mbarrier_init_fence()
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
# If not using clusters, sync the threadblock
_sync(Agent.ThreadBlock)
else:
# If using clusters, sync the cluster
_sync(Agent.ThreadBlockCluster)
def _sync(group: Agent):
"""
Syncs all threads within an agent.
"""
if group is Agent.Thread:
assert False, "Error: Not supported."
elif group is Agent.ThreadBlock:
cute.arch.sync_threads()
elif group is Agent.ThreadBlockCluster:
cute.arch.cluster_arrive()
cute.arch.cluster_wait()
else:
assert (
False
), "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead."
def _mbarrier_i64_to_ptr(val: Int64) -> cute.Pointer:
"""
Converts a smem pointer of type Int64 to cute.Pointer with 8B alignment
"""
return cute.make_ptr(
Int64,
val.ir_value(),
mem_space=_cute_ir.AddressSpace.smem,
assumed_align=8,
)


@@ -0,0 +1,217 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Type, Union, overload
from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta
import cutlass.cute as cute
from cutlass.cute.arch import get_dyn_smem
class SmemAllocator:
"""
A class for managing shared memory allocation on GPU.
This class manages a chunk of shared memory and provides APIs for sub-allocation
inside the chunk.
Attributes
----------
_base : cute.Pointer
The current base address of the shared memory, held as an i8-typed dynamic value.
_allocated_bytes : int
The number of bytes allocated so far in shared memory.
Methods
-------
allocate(size_or_type, byte_alignment)
Allocates the given number of bytes (or a struct) in shared memory with the given byte alignment.
allocate_array(element_type, num_elems)
Allocates num_elems values of element_type in shared memory.
allocate_tensor(element_type, layout, byte_alignment, swizzle)
Allocates a tensor in shared memory with the given element type, layout and byte alignment.
Notes
-----
This class is responsible for managing the allocation of tensors in shared memory.
"""
def __init__(self):
"""
Initializes the SmemAllocator instance with dynamic smem base ptr,
which is i8 type and aligned to 1024.
"""
self._base = get_dyn_smem(Int8, alignment=1024)
self._allocated_bytes = 0
@overload
def allocate(self, size_or_type: int, byte_alignment: int): ...
@overload
def allocate(self, size_or_type: cute.struct, byte_alignment: int): ...
def allocate(self, size_or_type, byte_alignment: int = 1) -> int:
"""
Allocates a block of memory with the specified size and byte alignment.
This method adjusts the base cute.Pointer to ensure that the allocated memory
is aligned according to the specified byte alignment. It updates the internal
state to reflect the new base cute.Pointer and the total allocated bytes.
Parameters
----------
size_or_type : int or struct
The number of bytes to allocate or struct class.
byte_alignment : int
The byte alignment requirement for the allocation. Defaults to 1 (no alignment).
Returns
----------
A cute.Pointer to the start of the allocated memory block or struct instance.
Raises
----------
ValueError
If num_bytes is negative or if byte_alignment is less than 1.
"""
if isinstance(size_or_type, cute.struct):
alignment = max(byte_alignment, size_or_type.__alignof__())
base_ptr = self.allocate(size_or_type.__sizeof__(), alignment)
return size_or_type(base_ptr)
num_bytes = size_or_type
if num_bytes < 0:
raise ValueError("num_bytes must be non-negative")
if byte_alignment < 1:
raise ValueError("byte_alignment must be at least 1")
self._base = self._base.align(byte_alignment)
ptr = self._base
self._base += num_bytes
if self._allocated_bytes % byte_alignment != 0:
self._allocated_bytes += (
byte_alignment - self._allocated_bytes % byte_alignment
)
self._allocated_bytes += num_bytes
return ptr
def allocate_array(self, element_type: Type[Numeric], num_elems: int = 1):
"""
Allocates num_elems values of element_type in shared memory.
This method calls allocate() to obtain a byte pointer to the start of the allocation,
then calls cute.recast_ptr() to recast that byte cute.Pointer to element_type.
Parameters
----------
element_type : Type[Numeric]
The type of the values in the tensor.
num_elems : int, optional
The number of elements for each allocation. Defaults to 1.
Returns
----------
An element_type cute.Pointer to the start of the allocated memory block.
Raises
----------
ValueError
If num_elems is less than 1.
"""
if num_elems < 1:
raise ValueError("num_elems must be at least 1")
if not isinstance(element_type, NumericMeta):
raise TypeError(
f"value_ty must be a type of Numeric, but got {element_type}"
)
ptr = self.allocate(
element_type.width // 8 * num_elems, element_type.width // 8
)
return cute.recast_ptr(ptr, dtype=element_type)
def allocate_tensor(
self,
element_type: Type[Numeric],
layout: Union[int, cute.Layout, cute.ComposedLayout],
byte_alignment: int = 1,
swizzle: cute.Swizzle = None,
):
"""
Allocates a tensor in the shared memory with value type, layout and byte alignment.
Parameters
----------
element_type : Type[Numeric]
The type of the values in the tensor.
layout : int | DynamicInt | cute.Layout | cute.ComposedLayout
The layout of the tensor.
byte_alignment : int, optional
The byte alignment requirement for the allocation. Defaults to 1 (no alignment).
swizzle : cute.Swizzle
A swizzle for the iterator (for position-dependent swizzling).
Returns
-------
tensor : cute.Tensor
The allocated tensor with specified value type, layout and byte alignment.
Notes
-----
The base address is updated to point to the next available memory location.
"""
if not isinstance(element_type, NumericMeta):
raise TypeError(
f"value_ty must be a type of Numeric, but got {element_type}"
)
if (
isinstance(layout, cute.ComposedLayout)
and isinstance(layout.inner, cute.Swizzle)
) and (swizzle is not None):
raise TypeError(
f"iterator swizzle with swizzle layout is currently not supported"
)
if isinstance(layout, int):
layout = cute.make_layout(layout)
profile = layout(0)
if isinstance(profile, tuple):
raise TypeError(
f"cannot allocate a shared memory tensor with a non-integer iterator"
)
if not cute.is_static(layout.type):
raise NotImplementedError(f"dynamic layout is not supported: {layout.type}")
# At least align the allocation to the natural alignment given by the element type
if element_type.width // 8 > byte_alignment:
byte_alignment = element_type.width // 8
# Relevant only for sub-byte data types: verify that the entire allocation is byte-aligned
cosize_in_bits = cute.cosize(layout) * element_type.width
assert isinstance(cosize_in_bits, int)
if cosize_in_bits % 8 != 0:
raise ValueError("invalid allocation that is not byte-aligned")
num_bytes = cosize_in_bits // 8
ptr = self.allocate(num_bytes, byte_alignment)
ptr = cute.recast_ptr(ptr, swizzle, dtype=element_type)
res = cute.make_tensor(ptr, layout)
return res
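# A minimal usage sketch (comments only, inside a kernel; Float16 is assumed to be
# imported from cutlass.cutlass_dsl):
#
#     smem = SmemAllocator()
#     a_smem = smem.allocate_tensor(
#         Float16, cute.make_layout((128, 64)), byte_alignment=16
#     )
#     flags = smem.allocate_array(Int8, num_elems=4)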


@@ -0,0 +1,384 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Tuple
from cutlass.cutlass_dsl import (
Boolean,
Integer,
Int32,
min,
extract_mlir_values,
new_from_mlir_values,
dsl_user_op,
)
from cutlass._mlir import ir
import cutlass.cute as cute
##############################################################################
# Static persistent tile scheduler
##############################################################################
class WorkTileInfo:
"""A class to represent information about a work tile.
:ivar tile_idx: The index of the tile.
:type tile_idx: cute.Coord
:ivar is_valid_tile: Whether the tile is valid.
:type is_valid_tile: Boolean
"""
def __init__(self, tile_idx: cute.Coord, is_valid_tile: Boolean):
self._tile_idx = tile_idx
self._is_valid_tile = Boolean(is_valid_tile)
def __extract_mlir_values__(self) -> list[ir.Value]:
values = extract_mlir_values(self.tile_idx)
values.extend(extract_mlir_values(self.is_valid_tile))
return values
def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo":
assert len(values) == 4
new_tile_idx = new_from_mlir_values(self._tile_idx, values[:-1])
new_is_valid_tile = new_from_mlir_values(self._is_valid_tile, [values[-1]])
return WorkTileInfo(new_tile_idx, new_is_valid_tile)
@property
def is_valid_tile(self) -> Boolean:
"""Check latest tile returned by the scheduler is valid or not. Any scheduling
requests after all tasks completed will return an invalid tile.
:return: The validity of the tile.
:rtype: Boolean
"""
return self._is_valid_tile
@property
def tile_idx(self) -> cute.Coord:
"""
Get the index of the tile.
:return: The index of the tile.
:rtype: cute.Coord
"""
return self._tile_idx
class PersistentTileSchedulerParams:
"""A class to represent parameters for a persistent tile scheduler.
This class is designed to manage and compute the layout of clusters and tiles
in a batched gemm problem.
:ivar cluster_shape_mn: Shape of the cluster in (m, n) dimensions (K dimension cta count must be 1).
:type cluster_shape_mn: tuple
:ivar problem_layout_ncluster_mnl: Layout of the problem in terms of
number of clusters in (m, n, l) dimensions.
:type problem_layout_ncluster_mnl: cute.Layout
"""
def __init__(
self,
problem_shape_ntile_mnl: cute.Shape,
cluster_shape_mnk: cute.Shape,
*,
loc=None,
ip=None,
):
"""
Initializes the PersistentTileSchedulerParams with the given parameters.
:param problem_shape_ntile_mnl: The shape of the problem in terms of
number of CTA (Cooperative Thread Array) in (m, n, l) dimensions.
:type problem_shape_ntile_mnl: cute.Shape
:param cluster_shape_mnk: The shape of the cluster in (m, n, k) dimensions; the k extent must be 1.
:type cluster_shape_mnk: cute.Shape
:raises ValueError: If cluster_shape_k is not 1.
"""
if cluster_shape_mnk[2] != 1:
raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}")
self.problem_shape_ntile_mnl = problem_shape_ntile_mnl
# cluster_shape_mnk is kept for reconstruction
self._cluster_shape_mnk = cluster_shape_mnk
self.cluster_shape_mn = cluster_shape_mnk[:2]
self._loc = loc
# By default, we follow m major (col-major) raster order, so make a col-major layout
self.problem_layout_ncluster_mnl = cute.make_layout(
cute.ceil_div(
self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip
),
loc=loc,
ip=ip,
)
def __extract_mlir_values__(self):
values, self._values_pos = [], []
for obj in [self.problem_shape_ntile_mnl, self._cluster_shape_mnk]:
obj_values = extract_mlir_values(obj)
values += obj_values
self._values_pos.append(len(obj_values))
return values
def __new_from_mlir_values__(self, values):
obj_list = []
for obj, n_items in zip(
[self.problem_shape_ntile_mnl, self._cluster_shape_mnk], self._values_pos
):
obj_list.append(new_from_mlir_values(obj, values[:n_items]))
values = values[n_items:]
return PersistentTileSchedulerParams(*(tuple(obj_list)), loc=self._loc)
@dsl_user_op
def get_grid_shape(
self, max_active_clusters: Int32, *, loc=None, ip=None
) -> Tuple[Integer, Integer, Integer]:
"""
Computes the grid shape based on the maximum active clusters allowed.
:param max_active_clusters: The maximum number of active clusters that
can run in one wave.
:type max_active_clusters: Int32
:return: A tuple containing the grid shape in (m, n, persistent_clusters).
- m: the cluster extent in m (cluster_shape_mn[0]).
- n: the cluster extent in n (cluster_shape_mn[1]).
- persistent_clusters: number of persistent clusters that can run concurrently.
"""
# Total ctas in problem size
num_ctas_mnl = tuple(
x * y
for x, y in zip(
self.problem_layout_ncluster_mnl.shape, self.cluster_shape_mn
)
) + (self.problem_layout_ncluster_mnl.shape[2],)
num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
num_ctas_per_cluster = cute.size(self.cluster_shape_mn, loc=loc, ip=ip)
# Total ctas that can run in one wave
num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster
num_persistent_ctas = min(num_ctas_in_problem, num_ctas_per_wave)
num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster
return (*self.cluster_shape_mn, num_persistent_clusters)
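# --- Editor's worked example (illustrative, not part of the original file).
# Suppose problem_shape_ntile_mnl = (8, 8, 4) and cluster_shape_mn = (2, 1):
# the cluster layout has shape ceil_div((8, 8), (2, 1)) x 4 = (4, 8, 4), i.e.
# 128 clusters and 256 CTAs in total.  With max_active_clusters = 100 the
# scheduler launches min(256, 100 * 2) // 2 = 100 persistent clusters, so the
# returned grid shape is (2, 1, 100).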
class StaticPersistentTileScheduler:
"""A scheduler for static persistent tile execution in CUTLASS/CuTe kernels.
:ivar params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl
:type params: PersistentTileSchedulerParams
:ivar num_persistent_clusters: Number of persistent clusters that can be launched
:type num_persistent_clusters: Int32
:ivar cta_id_in_cluster: ID of the CTA within its cluster
:type cta_id_in_cluster: cute.Coord
:ivar _num_tiles_executed: Counter for executed tiles
:type _num_tiles_executed: Int32
:ivar _current_work_linear_idx: Current cluster index
:type _current_work_linear_idx: Int32
"""
def __init__(
self,
params: PersistentTileSchedulerParams,
num_persistent_clusters: Int32,
current_work_linear_idx: Int32,
cta_id_in_cluster: cute.Coord,
num_tiles_executed: Int32,
):
"""
Initializes the StaticPersistentTileScheduler with the given parameters.
:param params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl.
:type params: PersistentTileSchedulerParams
:param num_persistent_clusters: Number of persistent clusters that can be launched.
:type num_persistent_clusters: Int32
:param current_work_linear_idx: Current cluster index.
:type current_work_linear_idx: Int32
:param cta_id_in_cluster: ID of the CTA within its cluster.
:type cta_id_in_cluster: cute.Coord
:param num_tiles_executed: Counter for executed tiles.
:type num_tiles_executed: Int32
"""
self.params = params
self.num_persistent_clusters = num_persistent_clusters
self._current_work_linear_idx = current_work_linear_idx
self.cta_id_in_cluster = cta_id_in_cluster
self._num_tiles_executed = num_tiles_executed
def __extract_mlir_values__(self) -> list[ir.Value]:
values = extract_mlir_values(self.num_persistent_clusters)
values.extend(extract_mlir_values(self._current_work_linear_idx))
values.extend(extract_mlir_values(self.cta_id_in_cluster))
values.extend(extract_mlir_values(self._num_tiles_executed))
return values
def __new_from_mlir_values__(
self, values: list[ir.Value]
) -> "StaticPersistentTileScheduler":
assert len(values) == 6
new_num_persistent_clusters = new_from_mlir_values(
self.num_persistent_clusters, [values[0]]
)
new_current_work_linear_idx = new_from_mlir_values(
self._current_work_linear_idx, [values[1]]
)
new_cta_id_in_cluster = new_from_mlir_values(
self.cta_id_in_cluster, values[2:5]
)
new_num_tiles_executed = new_from_mlir_values(
self._num_tiles_executed, [values[5]]
)
return StaticPersistentTileScheduler(
self.params,
new_num_persistent_clusters,
new_current_work_linear_idx,
new_cta_id_in_cluster,
new_num_tiles_executed,
)
# called by host
@dsl_user_op
@staticmethod
def create(
params: PersistentTileSchedulerParams,
block_idx: Tuple[Integer, Integer, Integer],
grid_dim: Tuple[Integer, Integer, Integer],
*,
loc=None,
ip=None,
):
"""Initialize the static persistent tile scheduler.
:param params: Parameters for the persistent
tile scheduler.
:type params: PersistentTileSchedulerParams
:param block_idx: The 3d block index in the format (bidx, bidy, bidz).
:type block_idx: Tuple[Integer, Integer, Integer]
:param grid_dim: The 3d grid dimensions for kernel launch.
:type grid_dim: Tuple[Integer, Integer, Integer]
:return: A StaticPersistentTileScheduler object.
:rtype: StaticPersistentTileScheduler
"""
# Calculate the number of persistent clusters by dividing the total grid size
# by the number of CTAs per cluster
num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size(
params.cluster_shape_mn, loc=loc, ip=ip
)
bidx, bidy, bidz = block_idx
# Initialize the workload index to the cluster index in the grid
current_work_linear_idx = Int32(bidz)
# CTA id in the cluster
cta_id_in_cluster = (
Int32(bidx % params.cluster_shape_mn[0]),
Int32(bidy % params.cluster_shape_mn[1]),
Int32(0),
)
# Initialize number of tiles executed to zero
num_tiles_executed = Int32(0)
return StaticPersistentTileScheduler(
params,
num_persistent_clusters,
current_work_linear_idx,
cta_id_in_cluster,
num_tiles_executed,
)
# called by host
@staticmethod
def get_grid_shape(
params: PersistentTileSchedulerParams,
max_active_clusters: Int32,
*,
loc=None,
ip=None,
) -> Tuple[Integer, Integer, Integer]:
"""Calculates the grid shape to be launched on GPU using problem shape,
threadblock shape, and active cluster size.
:param params: Parameters for grid shape calculation.
:type params: PersistentTileSchedulerParams
:param max_active_clusters: Maximum active clusters allowed.
:type max_active_clusters: Int32
:return: The calculated 3d grid shape.
:rtype: Tuple[Integer, Integer, Integer]
"""
return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip)
# private method
def _get_current_work_for_linear_idx(
self, current_work_linear_idx: Int32, *, loc=None, ip=None
) -> WorkTileInfo:
"""Compute current tile coord given current_work_linear_idx and cta_id_in_cluster.
:param current_work_linear_idx: The linear index of the current work.
:type current_work_linear_idx: Int32
:return: An object containing information about the current tile coordinates
and validity status.
:rtype: WorkTileInfo
"""
is_valid = current_work_linear_idx < cute.size(
self.params.problem_layout_ncluster_mnl, loc=loc, ip=ip
)
cur_cluster_coord = self.params.problem_layout_ncluster_mnl.get_hier_coord(
current_work_linear_idx, loc=loc, ip=ip
)
# cur_tile_coord is a tuple of i32 values
cur_tile_coord = tuple(
Int32(x) * Int32(z) + Int32(y)
for x, y, z in zip(
cur_cluster_coord,
self.cta_id_in_cluster,
(*self.params.cluster_shape_mn, Int32(1)),
)
)
return WorkTileInfo(cur_tile_coord, is_valid)
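# --- Editor's worked example (illustrative, not part of the original file).
# With cluster_shape_mn = (2, 1) and problem_layout_ncluster_mnl of shape
# (4, 8, 4) (column-major), linear index 6 maps to the hierarchical cluster
# coordinate (2, 1, 0).  A CTA with cta_id_in_cluster = (1, 0, 0) then owns
# tile coordinate (2 * 2 + 1, 1 * 1 + 0, 0 * 1 + 0) = (5, 1, 0), and the work
# is valid because 6 < size((4, 8, 4)) = 128.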
@dsl_user_op
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
return self._get_current_work_for_linear_idx(
self._current_work_linear_idx, loc=loc, ip=ip
)
@dsl_user_op
def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo:
return self.get_current_work(loc=loc, ip=ip)
@dsl_user_op
def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None):
self._current_work_linear_idx += Int32(advance_count) * Int32(
self.num_persistent_clusters
)
self._num_tiles_executed += Int32(1)
@property
def num_tiles_executed(self) -> Int32:
return self._num_tiles_executed
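# --- Editor's usage sketch (illustrative, not part of the original file).
# A typical persistent main loop built on the scheduler above; `process_tile`
# and the way block_idx / grid_dim are obtained inside the kernel are
# hypothetical and outside the scope of this sketch.
def _example_persistent_loop(params, block_idx, grid_dim, process_tile):
    ts = StaticPersistentTileScheduler.create(params, block_idx, grid_dim)
    work = ts.initial_work_tile_info()
    while work.is_valid_tile:
        process_tile(work.tile_idx)
        ts.advance_to_next_work()
        work = ts.get_current_work()
    return ts.num_tiles_executed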

View File

@@ -0,0 +1,140 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Tuple
from cutlass.cutlass_dsl import const_expr
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
import cutlass.cute as cute
class TensorMapUpdateMode(Enum):
"""
Enum class defining tensor map update modes.
Modes:
GMEM: Update tensormap in global memory
SMEM: Load tensormap from global memory to shared memory,
update it in shared memory, then store back to global memory
"""
GMEM = auto() # Update tensormap in global memory
SMEM = auto() # Update tensormap in shared memory
@dataclass(frozen=True)
class TensorMapManager:
"""
Manages TensorMap operations including initialization and updates.
Provides utilities to convert a tensormap pointer across different memory spaces.
"""
tensormap_update_mode: TensorMapUpdateMode
bytes_per_tensormap: int
# Convert a given cute.Pointer or cutlass.Int64 to a cute.Pointer to the tensormap.
# address_space: the address space of the resulting tensormap pointer; it can be generic or gmem.
def get_tensormap_ptr(
self,
ptr: cute.Pointer,
address_space=_cute_ir.AddressSpace.gmem,
) -> cute.Pointer:
if address_space not in [
_cute_ir.AddressSpace.gmem,
_cute_ir.AddressSpace.generic,
]:
raise ValueError(f"Invalid address space: {address_space} for tensormap")
gmem_ptr_i64 = ptr.toint().ir_value()
gmem_ptr_i64_align_ty = _cute_ir.ConstrainedIntType.get(
self.bytes_per_tensormap, gmem_ptr_i64.type.width
)
gmem_ptr_i64_align = _cute_ir.assume(gmem_ptr_i64_align_ty, gmem_ptr_i64)
gmem_ptr_ty = _cute_ir.PtrType.get(
_cute_nvgpu_ir.TmaDescriptorTiledType.get(),
address_space,
self.bytes_per_tensormap,
)
return _cute_ir.inttoptr(gmem_ptr_ty, gmem_ptr_i64_align)
# Initialize the tensormap pointed to by dst_ptr with the one embedded in copy_atom.
# dst_ptr should point to a global memory or shared memory location.
# warp_id selects which warp performs the initialization.
@cute.jit
def init_tensormap_from_atom(
self, copy_atom: cute.CopyAtom, dst_ptr: cute.Pointer, warp_id: int
) -> None:
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
if warp_idx == warp_id:
with cute.arch.elect_one():
cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr)
cute.arch.sync_warp()
return
# Perform a fence operation to ensure previous `init_tensormap_from_atom` calls have been completed
def fence_tensormap_initialization(
self,
) -> None:
if self.tensormap_update_mode == TensorMapUpdateMode.GMEM:
cute.arch.fence_acq_rel_cta()
return
# Perform a fence operation to ensure previous `update_tensormap` calls have been completed
def fence_tensormap_update(
self,
tensormap_ptr: cute.Pointer,
) -> None:
cute.nvgpu.cpasync.fence_tma_desc_acquire(tensormap_ptr)
return
@cute.jit
def update_tensormap(
self,
tensor_gmem: Tuple[cute.Tensor, ...],
tma_copy_atom: Tuple[cute.CopyAtom, ...],
tensormap_gmem_ptr: Tuple[cute.Pointer, ...],
warp_id: int,
tensormap_smem_ptr: Tuple[cute.Pointer, ...],
) -> None:
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
# updates before touching tensormap in global memory
if warp_idx == warp_id:
if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
for copy_atom, tensor, smem_ptr in zip(
tma_copy_atom, tensor_gmem, tensormap_smem_ptr
):
cute.nvgpu.cpasync.update_tma_descriptor(
copy_atom, tensor, smem_ptr
)
# wait until it's safe to update tensormap in global memory
with cute.arch.elect_one():
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)
cute.arch.sync_warp()
# updates to tensormap in global memory
if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
else:
for copy_atom, tensor, gmem_ptr in zip(
tma_copy_atom, tensor_gmem, tensormap_gmem_ptr
):
cute.nvgpu.cpasync.update_tma_descriptor(
copy_atom, tensor, gmem_ptr
)
cute.arch.sync_warp()
cute.nvgpu.cpasync.fence_tma_desc_release()
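# --- Editor's usage sketch (illustrative, not part of the original file).
# Rough call sequence for patching TMA descriptors per group using the manager
# above; every variable name here is hypothetical.
def _example_tensormap_flow(manager, copy_atoms, gmem_tensors,
                            tensormap_gmem_ptrs, tensormap_smem_ptrs, warp_id=0):
    # Seed each global-memory descriptor slot from the descriptor baked into
    # the corresponding copy atom, then make the writes visible.
    for atom, dst in zip(copy_atoms, tensormap_gmem_ptrs):
        manager.init_tensormap_from_atom(atom, dst, warp_id)
    manager.fence_tensormap_initialization()
    # Re-point the descriptors at the current group's tensors (optionally
    # staging through shared memory) and fence before the TMA loads use them.
    manager.update_tensormap(gmem_tensors, copy_atoms, tensormap_gmem_ptrs,
                             warp_id, tensormap_smem_ptrs)
    for ptr in tensormap_gmem_ptrs:
        manager.fence_tensormap_update(ptr)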

View File

@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from .cutlass import *
from ..base_dsl.ast_helpers import (
loop_selector,
if_selector,
if_executor,
while_selector,
while_executor,
range_constexpr,
range_dynamic,
const_expr,
dynamic_expr,
assert_executor,
bool_cast,
)
from ..base_dsl import *
from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values
from ..base_dsl.typing import _binary_op_type_promote
from ..base_dsl._mlir_helpers.gpu import *
from ..base_dsl._mlir_helpers.op import dsl_user_op
from ..base_dsl.runtime import *
from ..base_dsl.runtime import cuda as cuda_helpers
from ..base_dsl.compiler import compile
from ..base_dsl.runtime.dlpack_runtime import *
from ..base_dsl.runtime.jit_arg_adapters import *

File diff suppressed because it is too large

View File

@@ -0,0 +1,515 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import List, Tuple
from cutlass._mlir import ir
from cutlass._mlir.dialects import scf, arith
from cutlass._mlir.extras import types as T
from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values
from ..base_dsl.ast_helpers import *
from ..base_dsl.utils.logger import log
from ..base_dsl import typing as t
from ..base_dsl.typing import Int32, Float32, Boolean, Numeric, get_mlir_types
from . import cutlass as cutlass_dsl
# =============================================================================
# AST Helpers
# =============================================================================
class LoopUnroll(ir.Attribute):
def __init__(self, **kwargs):
valid_keys = set(["count", "full"])
def to_mlir_attr(val):
if isinstance(val, bool):
return "true" if val else "false"
elif isinstance(val, int):
return f"{val} : i32"
else:
raise DSLNotImplemented(f"{type(val)} is not supported")
cfg = {key: to_mlir_attr(kwargs[key]) for key in valid_keys if key in kwargs}
if kwargs.get("count", None) == 1:
cfg["disable"] = "true"
unroll = "<" + ", ".join(f"{key} = {value}" for key, value in cfg.items()) + ">"
super().__init__(
ir.Attribute.parse(f"#llvm.loop_annotation<unroll = {unroll}>")
)
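# --- Editor's note (illustrative, not part of the original file). Under an
# active MLIR context, LoopUnroll(count=4) parses to the attribute
#   #llvm.loop_annotation<unroll = <count = 4 : i32>>
# while LoopUnroll(full=True) yields <full = true>; count=1 additionally emits
# <disable = true> so the backend skips unrolling entirely.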
class ScfGenerator:
"""
Encapsulates common scf dialect functionality: pack, unpack, and SCF execution.
"""
def __init__(self):
pass
@staticmethod
def fill_none(ir_values, unpacked_values):
i = 0
for idx, item in enumerate(unpacked_values):
if item is not None:
unpacked_values[idx] = ir_values[i]
i += 1
@staticmethod
def _normalize_region_result_to_list(region_result: Any) -> List[Any]:
"""
Convert region_result to a list if it is not already a list
If region_result is a list, return it as is.
If region_result is None, return an empty list.
If region_result is not a list, return a list containing region_result as the only element.
"""
if region_result is None:
region_result_list = []
elif not isinstance(region_result, list):
region_result_list = [region_result]
else:
region_result_list = region_result
return region_result_list
@staticmethod
def check_region_result(region_values, ir_values):
for i, (expected_value, actual_value) in enumerate(
zip(ir_values, region_values)
):
expected_value_type = get_mlir_types(expected_value)
actual_value_type = get_mlir_types(actual_value)
if expected_value_type != actual_value_type:
return False, i, expected_value_type, actual_value_type
return True, -1, None, None
def scf_execute_dynamic(
self,
op_type_name: str,
used_args: List[Any],
mix_iter_args: List[Any],
mix_iter_arg_names: List[str],
create_op_func: Callable[
[List[ir.Value], Dict[int, Tuple[int, int]], List[Any]], ir.Operation
],
region_builders: List[
Callable[
[
"ir.Operation",
List["ir.Value"], # block_args
List[Any], # used_args
List["ir.Value"], # dyn_yield_ops
Dict[int, Tuple[int, int]],
List[Any],
],
Any,
]
],
# block_term_op_builder[region_builder] = scf_op_builder
# e.g. scf.ConditionOp for while loop
block_term_op_builder: Dict[Callable, Callable] = {},
) -> Any:
# 1) Unpack
ir_values, dyn_unpacked_values, dyn_indices, dyn_class_types = (
cutlass_dsl.unpack_to_irvalue(mix_iter_args, op_type_name)
)
# 2) Create the SCF op
op = create_op_func(ir_values, dyn_indices, dyn_class_types)
log().debug("Generated scf.%s \n[%s]", op_type_name, op)
# 3) Build the regions
for i, builder in enumerate(region_builders):
region = op.regions[i]
block = region.blocks[0]
with ir.InsertionPoint(block):
block_args = list(block.arguments)
region_result = builder(
op,
block_args,
used_args,
dyn_unpacked_values,
dyn_indices,
dyn_class_types,
)
# Use custom terminator if provided for this builder, otherwise use default YieldOp
if builder in block_term_op_builder:
# Use the provided terminator generator
block_term_op_builder[builder](region_result)
else:
# Normalize region_result
region_result_list = ScfGenerator._normalize_region_result_to_list(
region_result
)
# Default behavior - generate YieldOp
region_values, unpacked_values, _, _ = (
cutlass_dsl.unpack_to_irvalue(region_result_list, op_type_name)
)
is_match, mismatch_idx, expected_type, actual_type = (
ScfGenerator.check_region_result(region_values, ir_values)
)
if not is_match:
# From unpacked index, we need to find the original index
original_idx = -1
for unpacked_idx, (original_idx, length) in dyn_indices.items():
if (
mismatch_idx >= original_idx
and mismatch_idx < original_idx + length
):
original_idx = unpacked_idx
break
raise DSLRuntimeError(
f"`{op_type_name}` expects {expected_type} type for varible `{mix_iter_arg_names[original_idx]}`, but got {actual_type}.",
suggestion=f"Please make sure `{mix_iter_arg_names[original_idx]}` type is not changed inside of `{op_type_name}`.",
)
scf.YieldOp(region_values)
log().debug("Completed scf.%s \n[%s]", op_type_name, op)
ScfGenerator.fill_none(op.results, unpacked_values)
# 4) Pack final results
final_results = cutlass_dsl.pack_from_irvalue(
unpacked_values, dyn_indices, dyn_class_types
)
# 5) Return in a nice pattern
if not final_results:
return
if len(final_results) == 1:
return final_results[0]
return final_results
def _loop_execute_range_dynamic(
func: Callable,
start: Any,
stop: Any,
step: Any,
used_args: List[Any] = [],
mix_iter_args: List[Any] = [],
mix_iter_arg_names: List[str] = [],
unroll: int = -1,
unroll_full: bool = False,
):
"""
Example: build an scf.for with optional unroll, using our universal helper.
"""
scf_gen = ScfGenerator()
def create_for_op(
dyn_yield_ops: List[ir.Value],
dyn_indices: Dict[int, Tuple[int, int]],
dyn_class_types: List[Any],
):
for d in dyn_yield_ops:
if not isinstance(d, ir.Value):
raise DSLRuntimeError(
f"Invalid dyn_yield_ops: {dyn_yield_ops} \n\tExpected ir.Value, got {type(d)}"
)
# Convert Python ints or values to IR constants if needed
start_ = t.as_numeric(start)
stop_ = t.as_numeric(stop)
step_ = t.as_numeric(step)
assert start_ is not t.Int32, "Start is required for scf.for"
assert stop_ is not t.Int32, "Stop is required for scf.for"
assert step_ is not t.Int32, "Step is required for scf.for"
start_ = start_.ir_value()
stop_ = stop_.ir_value()
step_ = step_.ir_value()
# Possibly attach unroll attributes
unroll_attr = None
if unroll_full:
unroll_attr = LoopUnroll(full=True)
elif unroll != -1:
unroll_attr = LoopUnroll(count=unroll)
log().debug("Unroll attribute: %s", unroll_attr)
log().debug(
"Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s",
start_,
type(start_),
stop_,
type(stop_),
step_,
type(step_),
)
# Create scf.ForOp, passing iteration args if any
try:
if not dyn_yield_ops:
for_op = scf.ForOp(start_, stop_, step_)
else:
for_op = scf.ForOp(start_, stop_, step_, list(dyn_yield_ops))
except Exception as e:
yield_ops = "\n".join(
f"\t\t{i} => {d} : type : {type(d)}"
for i, d in enumerate(dyn_yield_ops)
)
raise DSLRuntimeError(
f"Failed to create scf.ForOp \n\t\tstart={start_}: type : {type(start_)}"
f"\n\t\tstop={stop_}: type : {type(stop_)}\n\t\tstep={step_}: type : {type(step_)}"
f", \n\tdyn_yield_ops:\n{yield_ops}"
) from e
if unroll_attr is not None:
for_op.attributes["loop_annotation"] = unroll_attr
return for_op
def for_body_builder(
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
):
# Insert induction variable at the beginning
dyn_yield_ops.insert(0, block_args[0])
ScfGenerator.fill_none(block_args, dyn_yield_ops)
block_args = dyn_yield_ops
# scf.ForOp block_args are typically [induction_var, iter_args...]
# But MLIR also gives you op.induction_variable
iv = t.as_numeric(op.induction_variable)
log().debug(
"For body builder: %s block_args: %s used_args: %s",
iv,
block_args,
used_args,
)
if len(block_args) <= 1:
# No iteration arguments, or only the induction var
func(iv, *used_args)
return [] # yield nothing
else:
# block_args[1:] are iteration variables
func_args = [*used_args]
func_args.extend(
cutlass_dsl.pack_from_irvalue(
block_args[1:], dyn_indices, dyn_class_types
)
)
updated_func_args = func(iv, *func_args)
return updated_func_args
# Now call the universal SCF executor with a single region builder
return scf_gen.scf_execute_dynamic(
op_type_name="for",
used_args=used_args,
mix_iter_args=mix_iter_args,
mix_iter_arg_names=mix_iter_arg_names,
create_op_func=create_for_op,
region_builders=[for_body_builder],
)
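# --- Editor's usage sketch (illustrative, not part of the original file).
# Carrying one accumulator through a dynamic range loop with the helper above;
# assumes an active DSL tracing context.
def _example_dynamic_sum(n):
    def body(iv, acc):
        # Must return the updated iteration arguments as a list.
        return [acc + iv]
    return _loop_execute_range_dynamic(
        body, start=0, stop=n, step=1,
        mix_iter_args=[Int32(0)], mix_iter_arg_names=["acc"],
        unroll=4,
    )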
def _if_execute_dynamic(
pred: "ir.Value",
then_block: Callable,
else_block: Callable = None,
used_args: List[Any] = [],
mix_yield_args: List[Any] = [],
mix_yield_arg_names: List[str] = [],
if_constexpr=None, # ignoring for brevity
):
"""
Build an scf.if with optional else, using our universal helper.
"""
scf_gen = ScfGenerator()
def create_if_op(
dyn_yield_ops: List[ir.Value],
dyn_indices: Dict[int, Tuple[int, int]],
dyn_class_types: List[Any],
):
# Assume final result types match the dynamic yields
result_types = [arg.type for arg in dyn_yield_ops]
pred_ = t.as_numeric(pred)
if not isinstance(pred_, Boolean):
# Convert to Boolean through comparison
pred_ = pred_ == True
try:
if_op = scf.IfOp(
pred_.ir_value(),
hasElse=(else_block is not None),
results_=result_types,
)
except Exception as e:
raise DSLRuntimeError(
f"Failed to create scf.IfOp \n\t\tpred={pred_}: type : {type(pred_)}"
) from e
return if_op
def then_builder(
if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
):
flat_args = [*used_args]
flat_args.extend(
cutlass_dsl.pack_from_irvalue(dyn_yield_ops, dyn_indices, dyn_class_types)
)
return then_block(*flat_args)
region_builders = [then_builder]
if else_block is not None:
def else_builder(
if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
):
flat_args = [*used_args]
flat_args.extend(
cutlass_dsl.pack_from_irvalue(
dyn_yield_ops, dyn_indices, dyn_class_types
)
)
return else_block(*flat_args)
region_builders.append(else_builder)
return scf_gen.scf_execute_dynamic(
op_type_name="if",
used_args=used_args,
mix_iter_args=mix_yield_args,
mix_iter_arg_names=mix_yield_arg_names,
create_op_func=create_if_op,
region_builders=region_builders,
)
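# --- Editor's usage sketch (illustrative, not part of the original file).
# Selecting between two Int32 values on a dynamic predicate; the single
# yielded value is threaded through `mix_yield_args`.
def _example_dynamic_select(pred, a, b):
    def then_block(res):
        return [a]
    def else_block(res):
        return [b]
    return _if_execute_dynamic(
        pred, then_block, else_block,
        mix_yield_args=[Int32(0)], mix_yield_arg_names=["res"],
    )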
def _while_execute_dynamic(
while_before_block: Callable,
while_after_block: Callable = None,
used_args=[],
yield_args=[],
yield_arg_names=[],
):
"""
Create and return an SCF WhileOp for dynamic loops.
Generate the dynamic loop body using SCF WhileOp.
Args:
while_before_block: Function that returns (condition, updated_values)
while_after_block: Function that returns updated values
used_args: Additional arguments used in the loop body
yield_args: Values that are updated in the loop
See create_while_function in ast_preprocessor.py for details on the input structure.
"""
log().debug("_while_execute_dynamic")
while_op_type_name = "while"
scf_gen = ScfGenerator()
def create_while_op(
dyn_yield_ops: List[ir.Value],
dyn_indices: Dict[int, Tuple[int, int]],
dyn_class_types: List[Any],
):
# Create the while operation with the types from yield_args
result_types = [arg.type for arg in dyn_yield_ops]
try:
while_op = scf.WhileOp(result_types, dyn_yield_ops)
while_op.before.blocks.append(*result_types)
while_op.after.blocks.append(*result_types)
log().debug("[%s]", while_op)
return while_op
except Exception as e:
yield_ops = "\n".join(
f"\t\t{i} => {d} : type : {type(d)}"
for i, d in enumerate(dyn_yield_ops)
)
raise DSLRuntimeError(
f"Failed to create scf.WhileOp with yield_ops:\n{yield_ops}"
) from e
def before_block_builder(
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
):
# Build the before (condition) block
ScfGenerator.fill_none(block_args, dyn_yield_ops)
block_args = dyn_yield_ops
flat_args = [*used_args]
flat_args.extend(
cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types)
)
log().debug("before block args: %s", flat_args)
cond, before_results = while_before_block(*flat_args)
if not isinstance(before_results, (list, ir.OpResultList)):
before_results = [before_results]
log().debug("cond [%s]", cond)
log().debug(
"before_results [%s]",
before_results,
)
return cond, before_results
def before_block_terminator(cond_and_results):
# Generate a condition op instead of yield op
cond = cond_and_results[0]
before_result_list = ScfGenerator._normalize_region_result_to_list(
cond_and_results[1]
)
ir_cond_list, _, _, _ = cutlass_dsl.unpack_to_irvalue(
[cond], while_op_type_name
)
ir_cond = ir_cond_list[0]
ir_results_list, _, _, _ = cutlass_dsl.unpack_to_irvalue(
before_result_list, while_op_type_name
)
log().debug(
"creating scf.ConditionOp with [%s], [%s]",
ir_cond,
ir_results_list,
)
scf.ConditionOp(ir_cond, ir_results_list)
def after_block_builder(
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
):
# Build the after (body) block
ScfGenerator.fill_none(block_args, dyn_yield_ops)
block_args = dyn_yield_ops
flat_args = [*used_args]
flat_args.extend(
cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types)
)
log().debug("after block args: %s", flat_args)
after_results = while_after_block(*flat_args)
if not isinstance(after_results, (list, ir.OpResultList)):
after_results = [after_results]
log().debug(
"after_results [%s]",
after_results,
)
return after_results
# Call the universal SCF executor with two region builders
return scf_gen.scf_execute_dynamic(
op_type_name=while_op_type_name,
used_args=used_args,
mix_iter_args=yield_args,
mix_iter_arg_names=yield_arg_names,
create_op_func=create_while_op,
region_builders=[before_block_builder, after_block_builder],
block_term_op_builder={
before_block_builder: before_block_terminator
}, # Only customize the before block
)
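# --- Editor's usage sketch (illustrative, not part of the original file).
# Doubling a carried Int32 until it reaches a dynamic bound.
def _example_dynamic_double(bound, start):
    def before(x):
        # Return (condition, values forwarded to the after/body block).
        return x < bound, [x]
    def after(x):
        # Return the updated values fed back into the condition block.
        return [x + x]
    return _while_execute_dynamic(
        before, after,
        yield_args=[Int32(start)], yield_arg_names=["x"],
    )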

View File

@@ -0,0 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.0.0.dev1

View File

@@ -133,7 +133,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
this.__version__ = '3.9.2'
this.__version__ = '4.0.0'
from cutlass.backend import create_memory_pool
from cutlass.emit.pytorch import pytorch

View File

@@ -111,6 +111,7 @@
args.sync()
"""
from __future__ import annotations
from typing import Optional
from cutlass.utils.lazy_import import lazy_import

View File

@@ -1,3 +1,34 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import importlib
from typing import Any
@@ -8,4 +39,3 @@ def lazy_import(mod_name: str) -> Any:
return getattr(module, name)
return Lazy()

View File

@@ -193,3 +193,4 @@ class CUDAEventProfiler:
flops_ += m * n * batch_count * 2
return flops_

View File

@@ -75,15 +75,10 @@ audit_csv_runtime_fields = [
]
def hash_cutlass_string(input_string):
# Regex pattern to match instruction shape
instruction_shape_pattern = r"[a-zA-Z]\d+x\d+x\d+" # Matches '_s128x128x64', '_h64x128x16', etc.
mma_cluster_shape_pattern = r"_\d+x\d+x\d+" # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')
# Remove instruction shape (e.g., '_s128x128x64', '_h64x128x16')
output = re.sub(instruction_shape_pattern, "", input_string)
# Remove MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')
output = re.sub(mma_cluster_shape_pattern, "", output)
output = re.sub(mma_cluster_shape_pattern, "", input_string)
return output
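# --- Editor's worked example (illustrative, not part of the original file).
# The remaining pattern r"_\d+x\d+x\d+" strips both the CTA/MMA tile and the
# cluster shape from a (hypothetical) procedural name:
#   re.sub(r"_\d+x\d+x\d+", "", "gemm_f16_f16_128x128x256_2x1x1_tnt")
#     -> "gemm_f16_f16_tnt"
# so kernels differing only in those shapes hash to the same audit entry.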
@@ -288,7 +283,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
# TODO: randomize beta values for wider coverage
beta_values = [0.5]
is_supported_arch = (arch in ["100a", "101a", "120a"])
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"])
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
@@ -300,23 +295,23 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
#
sm100_mma_data_type_general = [
'x16gemm_f16_f16_f16_f16_f16',
'x16gemm_f16_f16_f16_void_f16',
'x16gemm_f16_f16_f32_f16_f16',
'x8tf32gemm_f32_f32_f32_f32_f32',
'x16bf16gemm_f32_f32_f32_f32_f32',
'gemm_f16_f16_f16_f16_f16',
'gemm_f16_f16_f16_void_f16',
'gemm_f16_f16_f32_f16_f16',
'tf32gemm_f32_f32_f32_f32_f32',
'bf16gemm_f32_f32_f32_f32_f32',
]
sm100_mma_data_type_runtime_dtype = [
'x32gemm_f4_f4_f32_f32_f32',
'x32gemm_f6_f6_f32_f32_f32',
'x32gemm_f8_f8_f32_f32_f32',
'gemm_f4_f4_f32_f32_f32',
'gemm_f6_f6_f32_f32_f32',
'gemm_f8_f8_f32_f32_f32',
]
sm100_mma_data_type_mergeable = [
'x32gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification
'x32gemm_e2m1_e2m1_f32_f32_f32',
'x32gemm_e3m2_e3m2_f32_f32_f32',
'gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification
'gemm_e2m1_e2m1_f32_f32_f32',
'gemm_e3m2_e3m2_f32_f32_f32',
]
sm100_mma_cluster_size = [
@@ -331,22 +326,15 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'ntn'
]
sm100_mma_instruction_shape = [
# [0] .1CTA, General
['64x128', '128x128', '128x256'],
# [1] .2CTA, General
['128x128', '256x128', '256x256'],
]
# regex list must be in kernel procedural name order
mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
#
# Block Scale Gemm
@@ -354,19 +342,19 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
block_scaled_data_type_base = [
# runtime datatypes
'x32gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'x32gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
'x32gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
]
block_scaled_data_type_mergeable = [
'x32gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'x32gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
'x32gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
'gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
]
block_scaled_data_type = block_scaled_data_type_base + block_scaled_data_type_mergeable
@@ -377,56 +365,43 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
]
block_scaled_layouts = ['tnt']
block_scaled_instruction_shape = [
# .1CTA
['128x128', '128x192', '128x256'],
# .2CTA
['256x128', '256x192', '256x256'],
]
# regex list must be in kernel procedural name order
mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
if arch == "100a":
if arch == "100a" or arch == "100f":
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
elif arch == "101a":
elif arch == "101a" or arch == "101f":
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
elif arch == "120a":
elif arch == "120a" or arch == "120f":
# blockscaled sm120_mma kernels
blockscaled_sm120_mma_kernel_cta_tiles = [
[ '128x128' ]
]
# sm120 MMA instruction shapes
blockscaled_sm120_mma_instruction_shapes = [
[ 's16x8x64gemm',
's16x8x32gemm'
]
]
# Restrict to two layouts to reduce L0 build and test time.
blockscaled_sm120_mma_layouts = [ 'tn' ]
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_instruction_shapes[0], blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
problem_waves = [0.5, 1.25, 2.5]
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
else:
error_message = "unsupported arch, only support sm100a, sm101a, sm120a"
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
raise Exception(error_message)
# Statically encoded kernels are still added to generated_kernels
@@ -445,14 +420,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
]
# Restrict to two layouts to reduce L1 build and test time.
sm100_mma_layouts = ['tnt', 'ntn']
sm100_mma_instruction_shape = [
# .1CTA
['64x128', '128x128', '128x256'],
# .2CTA
['128x128', '256x128', '256x256']
]
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
block_scaled_data_type = [
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
@@ -463,15 +432,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
block_scaled_layouts = ['tnt']
block_scaled_instruction_shape = [
# .1CTA
['128x128', '128x192', '128x256'],
# .2CTA
['256x128', '256x192', '256x256'],
]
# regex list must be in kernel procedural name order
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({block_scaled_filter_regex_1sm})|" \

View File

@@ -183,10 +183,7 @@ class GemmOperation:
math_op = self.tile_description.math_instruction.math_operation
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
if self.is_3x:
inst_shape = "{0}x{1}x{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
else:
inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) if not self.is_3x else ""
inst_shape += math_op_string
@@ -194,7 +191,9 @@ class GemmOperation:
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
short_math_name = self.short_math_name() if not self.is_3x else ""
return "%s%s%s%s" % (short_math_name, inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
# Generates a string representing the MMA instruction.
def extended_name(self):
@@ -337,18 +336,36 @@ class GemmOperation:
def opcode_class_name(self):
return OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
def get_collective_tile_shape(self):
"""
Get the tile shape passed to the collective builder.
On Blackwell, this is different from operation.tile_description.tile_shape.
"""
is_sm100_kernel = (self.arch == 100)
if not is_sm100_kernel:
return self.tile_description.tile_shape
opcode_class_main = self.tile_description.math_instruction.opcode_class
instruction_shape = self.tile_description.math_instruction.instruction_shape
tile_shape_m, tile_shape_n, tile_shape_k = self.tile_description.tile_shape
if opcode_class_main in [OpcodeClass.TensorOp, OpcodeClass.BlockScaledTensorOp, OpcodeClass.SparseTensorOp]:
tile_shape_m = instruction_shape[0]
tile_shape_n = instruction_shape[1]
return (tile_shape_m, tile_shape_n, tile_shape_k)
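# --- Editor's worked example (illustrative, not part of the original file).
# For an SM100 tensor-op kernel with tile_description.tile_shape = (256, 128, 64)
# and math_instruction.instruction_shape = (128, 128, 16), the collective
# builder receives (128, 128, 64): m and n come from the MMA instruction, k is
# kept from the CTA tile.  Kernels targeting other architectures pass
# tile_description.tile_shape through unchanged.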
# Generates the full kernel function name
def procedural_name(self):
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
if self.arch >= 90:
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}"
tile_shape = self.get_collective_tile_shape()
return kernel_name_template.format(
p = self.prefix,
ar = self.arch,
op = opcode_class_name,
ex = self.extended_name_3x(),
ct = '_' + 'x'.join([str(i) for i in self.tile_description.tile_shape]) if self.tile_description.tile_shape[0] > 0 else "",
ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "",
cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]),
l = self.tile_description.stages,
s = self.layout_name_3x(),
@@ -920,28 +937,8 @@ ${compile_guard_end}
instruction_shape = operation.tile_description.math_instruction.instruction_shape
cluster_m = operation.tile_description.cluster_shape[0]
cluster_n = operation.tile_description.cluster_shape[1]
tile_shape_m, tile_shape_n, tile_shape_k = tile_shape
# account for static/dynamic cluster shapes
cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1]
# Shape passed to epilogue builder
is_sm100_kernel = (operation.arch == 100)
if is_sm100_kernel:
cta_m_per_mma_instruction = 2 if "2sm" in operation.procedural_name() else 1
if cluster_m <= 0:
cta_m = cta_m // cta_m_per_mma_instruction
if opcode_class_main in [OpcodeClass.TensorOp
, OpcodeClass.BlockScaledTensorOp
, OpcodeClass.SparseTensorOp
]:
tile_shape_m = instruction_shape[0]
tile_shape_n = instruction_shape[1]
tile_shape_m, tile_shape_n, tile_shape_k = operation.get_collective_tile_shape()
# stage count set to zero indicates builder automatic stage selection
if operation.tile_description.stages > 0:

View File

@@ -1003,14 +1003,11 @@ class ConvOperation3x:
math_op = self.tile_description.math_instruction.math_operation
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
inst_shape = "{0}x{1}x{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
inst_shape += math_op_string
if self.tile_description.math_instruction.element_a != self.A.element and \
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, ConvKindNames[self.conv_kind])
return "%s%s%s" % (math_op_string, intermediate_type, ConvKindNames[self.conv_kind])
def extended_name(self):
'''Generates a string representing the MMA atom. Assumes accumulator type is C type.'''
@@ -5997,8 +5994,8 @@ def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version):
math_instructions = generate_mixed_dtype_math_instructions_sm90(instantiation_level, valid_types_for_a_b_acc)
valid_types_for_d = [DataType.f32]
valid_types_for_c = [DataType.f32]
valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2]
valid_types_for_c = copy.deepcopy(valid_types_for_d)
tile_descriptions = generate_tile_descriptions_sm90(
math_instructions=math_instructions,
@@ -6009,6 +6006,12 @@ def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version):
math_inst = tile_desc.math_instruction
data_types = []
# Limit C/D types to avoid a giant number of instantiations.
# A typical use case for mixed dtype in DL is weight quantization (tensor A),
# therefore we can limit the output type to that of activation (tensor B).
valid_types_for_c = [math_inst.element_b]
valid_types_for_d = [math_inst.element_b]
for c_type, d_type in product(valid_types_for_c, valid_types_for_d):
data_types.append(
generate_data_types_from_math_instruction(
@@ -6791,6 +6794,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1]
, DynamicClusterShape
]
tile_schedulers = [
TileSchedulerType.Default
]
@@ -6838,6 +6846,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
@@ -6937,6 +6950,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1]
, DynamicClusterShape
]
tile_schedulers = [
TileSchedulerType.Default
]
@@ -7090,6 +7108,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
@@ -7247,6 +7270,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1]
, DynamicClusterShape
]
tile_schedulers = [
TileSchedulerType.Default,
]
@@ -7456,6 +7484,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
@@ -7916,6 +7949,13 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_1sm = [
[2,1,1],
[1,1,1]
, DynamicClusterShape
]
# 1xSM MMA kernels
for math_inst in math_instructions_1sm:
tile_descriptions = []
@@ -7985,6 +8025,12 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_2sm = [
[2,1,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
@@ -8138,6 +8184,13 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_1sm = [
[1,1,1],
[2,1,1]
, DynamicClusterShape
]
# 1xSM MMA kernels
for math_inst in math_instructions_1sm:
tile_descriptions = []
@@ -8211,6 +8264,13 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_2sm = [
[2,1,1],
[4,1,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
@@ -8417,6 +8477,13 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_1sm = [
[1,1,1],
[2,1,1]
, DynamicClusterShape
]
# 1xSM MMA kernels
for math_inst in math_instructions_1sm:
tile_descriptions = []
@@ -8537,6 +8604,13 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_2sm = [
[2,1,1],
[4,1,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
@@ -8689,6 +8763,11 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1]
, DynamicClusterShape
]
tile_schedulers = [
TileSchedulerType.Default,
]
@@ -8788,6 +8867,11 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
@@ -8925,6 +9009,9 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_1sm:
if 101 in manifest.compute_capabilities :
if cluster_shape == [4,4,1] :
continue
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
tile_descriptions.append(
TileDescription([
@@ -8953,6 +9040,9 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_2sm:
if 101 in manifest.compute_capabilities :
if cluster_shape == [4,4,1] :
continue
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
tile_descriptions.append(
TileDescription([
@@ -9044,6 +9134,9 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_1sm:
if 101 in manifest.compute_capabilities :
if cluster_shape == [4,4,1] :
continue
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
tile_descriptions.append(
TileDescription([
@@ -9072,6 +9165,9 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_2sm:
if 101 in manifest.compute_capabilities :
if cluster_shape == [4,4,1] :
continue
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
tile_descriptions.append(
TileDescription([
@@ -9163,6 +9259,9 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_1sm:
if 101 in manifest.compute_capabilities :
if cluster_shape == [4,4,1] :
continue
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
tile_descriptions.append(
TileDescription([
@@ -9191,6 +9290,9 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_2sm:
if 101 in manifest.compute_capabilities :
if cluster_shape == [4,4,1] :
continue
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
tile_descriptions.append(
TileDescription([
@@ -9287,6 +9389,9 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_1sm:
if 101 in manifest.compute_capabilities :
if cluster_shape == [4,4,1] :
continue
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
tile_descriptions.append(
TileDescription([
@@ -9319,6 +9424,9 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_2sm:
if 101 in manifest.compute_capabilities :
if cluster_shape == [4,4,1] :
continue
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
tile_descriptions.append(
TileDescription([
@@ -9417,6 +9525,9 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_1sm:
if 101 in manifest.compute_capabilities :
if cluster_shape == [4,4,1] :
continue
multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
tile_descriptions.append(
TileDescription([
@@ -9476,6 +9587,9 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in sm100_cluster_shape_2sm:
if 101 in manifest.compute_capabilities :
if cluster_shape == [4,4,1] :
continue
multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
tile_descriptions.append(
TileDescription([
@@ -9578,6 +9692,12 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_1sm = [
[1,2,1], [1,1,1], [1,4,1]
, DynamicClusterShape
]
tile_schedulers = [
TileSchedulerType.StreamK,
]
@@ -9612,6 +9732,12 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_2sm = [
[2,1,1], [2,2,1], [2,4,1], [4,1,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
@@ -9658,6 +9784,12 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_1sm = [
[1,2,1], [1,1,1]
, DynamicClusterShape
]
tile_schedulers = [
TileSchedulerType.StreamK
]
@@ -9726,6 +9858,12 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_2sm = [
[2,1,1], [2,2,1], [2,4,1], [4,1,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
@@ -9809,6 +9947,12 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_1sm = [
[1,2,1], [2,1,1], [1,1,1]
, DynamicClusterShape
]
tile_schedulers = [
TileSchedulerType.StreamK,
]
@@ -9861,6 +10005,12 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
, DynamicClusterShape
]
if 101 in manifest.compute_capabilities :
cluster_shapes_2sm = [
[2,1,1], [2,2,1], [2,4,1], [4,1,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
@@ -9960,6 +10110,9 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]]
if 101 in manifest.compute_capabilities :
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]]
  # tile_descriptions is a 2-level list.
  # Each inner list corresponds to one cluster shape.
for math_inst, output_type in math_instructions_w_output_1sm:
@@ -10023,6 +10176,8 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
data_types_and_instruction_shapes_2sm)
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]]
if 101 in manifest.compute_capabilities :
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]]
for math_inst, output_type in math_instructions_w_output_2sm:
tile_descriptions = []
@@ -10103,6 +10258,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
data_types_and_instruction_shapes_1sm)
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]]
if 101 in manifest.compute_capabilities :
cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]]
for math_inst, output_type in math_instructions_w_output_1sm:
tile_descriptions = []
@@ -10166,6 +10323,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
data_types_and_instruction_shapes_2sm)
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]]
if 101 in manifest.compute_capabilities :
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]]
for math_inst, output_type in math_instructions_w_output_2sm:
tile_descriptions = []
@@ -10629,6 +10788,8 @@ def GenerateSM100(manifest, cuda_version):
#
# Dense Gemm
#
  architectures = manifest.args.architectures.split(';') if len(manifest.args.architectures) else ['50',]
GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version)
GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version)
@@ -10636,7 +10797,8 @@ def GenerateSM100(manifest, cuda_version):
GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version)
GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version)
if '100f' not in architectures and '101f' not in architectures:
GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version)
GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version)
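
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the diff above: `architectures` holds the
# raw ';'-separated target strings, so the family-specific 'f' targets can be
# distinguished from the arch-specific 'a' targets even though both resolve
# to the same integer compute capability. Int8 kernels are emitted only when
# no 'f' family target for SM100/SM101 is requested. The values below are
# hypothetical:
architectures_example = '100f;120a'.split(';')            # ['100f', '120a']
emit_int8_example = ('100f' not in architectures_example
                     and '101f' not in architectures_example)   # False here
# ---------------------------------------------------------------------------
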
# grouped GEMM
@@ -10657,7 +10819,8 @@ def GenerateSM100(manifest, cuda_version):
#
GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version)
GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version)
GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version)
if '100f' not in architectures and '101f' not in architectures:
GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version)
GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version)
GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version)
@@ -11166,7 +11329,7 @@ if __name__ == "__main__":
GenerateSM89(manifest, args.cuda_version)
GenerateSM90(manifest, args.cuda_version)
blackwell_enabled_arch = any(arch in ["100a", "101a", "120a"] for arch in archs)
blackwell_enabled_arch = any(arch in ["100a", "100f", "101a", "101f", "120a", "120f"] for arch in archs)
if blackwell_enabled_arch:
GenerateSM100(manifest, args.cuda_version)
GenerateSM120(manifest, args.cuda_version)
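
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the diff above: the Blackwell generators
# now also run for the family-specific 'f' targets. Hypothetical input:
archs_example = ['90a', '100f']
blackwell_example = any(a in ["100a", "100f", "101a", "101f", "120a", "120f"]
                        for a in archs_example)      # True because of '100f'
# ---------------------------------------------------------------------------
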


@@ -523,10 +523,14 @@ class Manifest:
arch_conditional_cc = [
'90a',
'100a',
'100f',
'101a',
'120a'
'101f',
'120a',
'120f'
]
architectures = [x if x not in arch_conditional_cc else x.split('a')[0] for x in architectures]
architectures = [x if x not in arch_conditional_cc else x.split('f')[0] for x in architectures]
self.compute_capabilities = [int(x) for x in architectures]
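
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the diff above: with the 'f' entries added
# to arch_conditional_cc, both the 'a' and the 'f' suffix collapse to the
# base compute capability before the int() conversion. A hypothetical helper
# with the same net effect:
def to_compute_capability(arch: str) -> int:
    return int(arch.rstrip('af'))   # '100a' -> 100, '100f' -> 100, '90' -> 90

assert [to_compute_capability(a) for a in ['90a', '100f', '101a', '120f']] == \
       [90, 100, 101, 120]
# ---------------------------------------------------------------------------
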


@@ -375,6 +375,13 @@ def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level:
mma_multipliers, cluster_sizes = get_mma_multipliers(level), get_cluster_sizes(level, is_aligned)
for math_inst, mma_mul, cluster_size in product(math_instructions, mma_multipliers, cluster_sizes):
    # The generator can stamp out duplicate kernels because it doesn't explicitly set the
    # instruction shape for SM90 kernels, and the 3.x collective API doesn't directly
    # expose it when using the auto kernel schedule.
math_inst_stub = copy.deepcopy(math_inst)
math_inst_stub.instruction_shape = [0, 0, 0]
tile_desc = TileDescription(
threadblock_shape=[
math_inst.instruction_shape[0] * mma_mul[0],
@@ -383,7 +390,7 @@ def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level:
],
stages=0,
warp_count=[4, 1, 1],
math_instruction=math_inst,
math_instruction=math_inst_stub,
min_compute=90,
max_compute=90,
cluster_shape=cluster_size)
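
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the diff above: the stub deep-copies the
# math instruction and zeroes the instruction shape in the copy used for the
# TileDescription, so descriptions that differ only in a field the SM90 auto
# kernel schedule never consumes compare equal and can be filtered as
# duplicates. SimpleNamespace stands in for the real MathInstruction here.
import copy
from types import SimpleNamespace

inst_a = SimpleNamespace(instruction_shape=[64, 128, 16], element_a='f16')
inst_b = SimpleNamespace(instruction_shape=[64, 256, 16], element_a='f16')

def stubbed(inst):
    stub = copy.deepcopy(inst)          # original instruction stays untouched
    stub.instruction_shape = [0, 0, 0]
    return stub

assert inst_a != inst_b                      # distinct instruction shapes
assert stubbed(inst_a) == stubbed(inst_b)    # collapse to one description
# ---------------------------------------------------------------------------
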
@@ -551,6 +558,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
b_type_size = DataTypeSize[data_types["b_type"]]
if a_type_size != b_type_size and CudaToolkitVersionSatisfies(cuda_version, 12, 1):
schedules = []
stream_k_schedules = []
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
if a_type_size > b_type_size:
epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
@@ -579,7 +587,11 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
KernelScheduleType.TmaWarpSpecializedCooperative,
epilogue_schedule
])
return schedules, []
stream_k_schedules.append([
KernelScheduleType.TmaWarpSpecializedCooperative,
epilogue_schedule
])
return schedules, stream_k_schedules
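
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the diff above: for mixed-input GEMMs
# (a_type_size != b_type_size) the function used to return an empty stream-K
# list; it now mirrors the cooperative schedule into both lists. The strings
# below are stand-ins for the KernelScheduleType / EpilogueScheduleType
# members used above.
schedules_example = [['TmaWarpSpecializedCooperative', 'TmaWarpSpecialized']]
stream_k_example = [['TmaWarpSpecializedCooperative', 'TmaWarpSpecialized']]
# The mixed-input path now returns (schedules_example, stream_k_example)
# instead of (schedules_example, []).
# ---------------------------------------------------------------------------
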
if not is_aligned and not is_blockwise(gemm_kind):
schedules = [[KernelScheduleType.CpAsyncWarpSpecialized,


@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='cutlass_library',
version='3.9.2',
version='4.0.0',
description='CUTLASS library generation scripts',
packages=['cutlass_library']
)


@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='3.9.2',
version='4.0.0',
description='Python implementation of CuTe',
packages=['pycute'],
)