Mirror of https://github.com/NVIDIA/cutlass.git (synced 2026-04-19 22:38:56 +00:00)

Release v4.0.0 (#2294)
python/CuTeDSL/EULA.txt (new file, 188 lines)
@@ -0,0 +1,188 @@
NVIDIA Software License Agreement

IMPORTANT NOTICE – PLEASE READ AND AGREE BEFORE USING THE SOFTWARE

This software license agreement (“Agreement”) is a legal agreement between you, whether an individual or entity, (“you”) and NVIDIA Corporation (“NVIDIA”) and governs the use of the NVIDIA CUTLASS DSLs software and materials that NVIDIA delivers to you under this Agreement (“Software”).
NVIDIA and you are each a “party” and collectively the “parties.”
This Agreement can be accepted only by an adult of legal age of majority in the country in which the Software is used.
If you don’t have the required age or authority to accept this Agreement, or if you don’t accept all the terms and conditions of this Agreement, do not use the Software.

1. License Grants

1.1. License Grant to You. The Software made available by NVIDIA to you is licensed, not sold.
Subject to the terms of this Agreement, NVIDIA grants you a limited, non-exclusive, revocable, non-transferable, and non-sublicensable (except as expressly granted in this Agreement), license to:

a. install and use copies of the Software,
b. configure the Software using configuration files provided (if applicable),
c. modify and create derivative works of any sample or example source code NVIDIA delivers to you as part of the Software (“Derivatives”) (if applicable), and
d. distribute python files in the Software package in source format as incorporated into a software application subject to the following distribution requirements:

i. Your application must have material additional functionality, beyond the included portions of the Software.
ii. The distributable portions of the Software shall only be accessed by your application.
iii. The following notice shall be included in modifications and derivative works of sample source code distributed: “This software contains source code provided by NVIDIA Corporation.”
iv. Unless a developer tool is identified in this Agreement as distributable, it is delivered for your internal use only.
v. The terms under which you distribute your application must be consistent with the terms of this Agreement, including (without limitation) terms relating to the license grant and license restrictions and protection of NVIDIA’s intellectual property rights.
vi. Additionally, you agree that you will protect the privacy, security and legal rights of your application users.

The foregoing (a) through (d) are, collectively, the “Purpose”, and the developed applications are only for use in systems with NVIDIA GPUs.

1.2. License Grant to NVIDIA. Subject to the terms of this Agreement, you grant NVIDIA and its affiliates a non-exclusive, perpetual, irrevocable, sublicensable, worldwide, royalty-free, fully paid-up and transferable license, under your intellectual property rights, to publicly perform, publicly display, reproduce, use, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution), import, create derivative works of and otherwise commercialize and exploit at NVIDIA’s discretion any Derivatives created by or for you.
You may, but are not required to, deliver any Derivatives to NVIDIA.

2. License Restrictions

Your license to use the Software and Derivatives is restricted as stated in this Section 2 (“License Restrictions”).
You will cooperate with NVIDIA and, upon NVIDIA’s written request, you will confirm in writing and provide reasonably requested information to verify your compliance with the terms of this Agreement.
You may not:

2.1. Use the Software or Derivatives for any purpose other than the Purpose;

2.2. Sell, rent, sublicense, transfer, distribute or otherwise make available to others (except authorized users as stated in Section 3 (“Authorized Users”)) any portion of the Software or Derivatives, except as expressly granted in Section 1.1 (“License Grant to You”);

2.3. Reverse engineer, decompile, or disassemble the Software components provided in binary form, nor attempt in any other manner to obtain source code of such Software;

2.4. Modify or create derivative works of the Software, except as expressly granted in Section 1.1 (“License Grant to You”);

2.5. Change or remove copyright or other proprietary notices in the Software;

2.6. Bypass, disable, or circumvent any technical limitation, encryption, security, digital rights management or authentication mechanism in the Software;

2.7. Use the Software or Derivatives in any manner that would cause them to become subject to an open source software license, subject to the terms in Section 6 (“Components Under Other Licenses”);

2.8. Use the Software or Derivatives in violation of any applicable law or regulation in relevant jurisdictions;

2.9. Indicate that a product or service developed with the Software or Derivatives is sponsored or endorsed by NVIDIA;

2.10. Replace any NVIDIA software components in the Software that are governed by this Agreement with other software that implements NVIDIA APIs;

2.11. Reverse engineer, decompile or disassemble any portion of the output generated using Software elements for the purpose of translating such output artifacts to target a non-NVIDIA platform; or

3. Authorized Users

You may allow employees and contractors of your entity or of your subsidiary(ies), and for educational institutions also enrolled students, to internally access and use the Software as authorized by this Agreement from your secure network to perform the work authorized by this Agreement on your behalf.
You are responsible for the compliance with the terms of this Agreement by your authorized users.
Any act or omission that if committed by you would constitute a breach of this Agreement will be deemed to constitute a breach of this Agreement if committed by your authorized users.

4. Pre-Release

Software versions identified as alpha, beta, preview, early access or otherwise as pre-release (“Pre-Release”) may not be fully functional, may contain errors or design flaws, and may have reduced or different security, privacy, availability and reliability standards relative to NVIDIA commercial offerings.
You use Pre-Release Software at your own risk. NVIDIA did not design or test the Software for use in production or business-critical systems.
NVIDIA may choose not to make available a commercial version of Pre-Release Software.
NVIDIA may also choose to abandon development and terminate the availability of Pre-Release Software at any time without liability.

5. Updates

NVIDIA may at any time and at its option, change, discontinue, or deprecate any part, or all, of the Software, or change or remove features or functionality, or make available patches, workarounds or other updates to the Software.
Unless the updates are provided with their separate governing terms, they are deemed part of the Software licensed to you under this Agreement, and your continued use of the Software is deemed acceptance of such changes.

6. Components Under Other Licenses

The Software may include or be distributed with components provided with separate legal notices or terms that accompany the components, such as open source software licenses and other license terms (“Other Licenses”).
The components are subject to the applicable Other Licenses, including any proprietary notices, disclaimers, requirements and extended use rights; except that this Agreement will prevail regarding the use of third-party open source software, unless a third-party open source software license requires its license terms to prevail.
Open source software license means any software, data or documentation subject to any license identified as an open source license by the Open Source Initiative (http://opensource.org), Free Software Foundation (http://www.fsf.org) or other similar open source organization or listed by the Software Package Data Exchange (SPDX) Workgroup under the Linux Foundation (http://www.spdx.org).

7. Ownership

7.1. NVIDIA Ownership. The Software, including all intellectual property rights, is and will remain the sole and exclusive property of NVIDIA or its licensors.
Except as expressly granted in this Agreement, (a) NVIDIA reserves all rights, interests and remedies in connection with the Software, and (b) no other license or right is granted to you by implication, estoppel or otherwise.

7.2. Your Ownership. Subject to the rights of NVIDIA and its suppliers in the Software, which continue to be licensed as stated in this Agreement, even when incorporated in your products or services, and the extent permitted by applicable law, as between you and NVIDIA, you hold all rights, title and interest in and to your products, services and Derivatives you develop as permitted in this Agreement including their respective intellectual property rights.

8. Feedback

You may, but you are not obligated to, provide suggestions, requests, fixes, modifications, enhancements, or other feedback regarding the Software (collectively, “Feedback”).
Feedback, even if designated as confidential by you, will not create any confidentiality obligation for NVIDIA or its affiliates.
If you provide Feedback, you grant NVIDIA, its affiliates and its designees a non-exclusive, perpetual, irrevocable, sublicensable, worldwide, royalty-free, fully paid-up and transferable license, under your intellectual property rights, to publicly perform, publicly display, reproduce, use, make, have made, sell, offer for sale, distribute (through multiple tiers of distribution), import, create derivative works of and otherwise commercialize and exploit the Feedback at NVIDIA’s discretion.

9. Termination

9.1. Termination. This Agreement will automatically terminate without notice from NVIDIA if you fail to comply with any of the terms in this Agreement or if you commence or participate in any legal proceeding against NVIDIA with respect to the Software.
Additionally, either party may terminate this Agreement at any time with thirty (30) days’ advance written notice to the other party.

9.2. Effect of Termination. Upon any expiration or termination of this Agreement, you will promptly (a) stop using and return, delete or destroy NVIDIA confidential information and all Software received under this Agreement, and (b) delete or destroy Derivatives created under this Agreement, unless an authorized NVIDIA representative provides prior written approval that you may keep a copy of the Derivatives solely for archival purposes.
Upon written request, you will certify in writing that you have complied with your obligations under this Section 9.2 (“Effect of Termination”).
9.3. Survival. Section 1.2 (“License Grant to NVIDIA”), Section 5 (“Updates”), Section 6 (“Components Under Other Licenses”), Section 7 (“Ownership”), Section 8 (“Feedback”), Section 9.2 (“Effect of Termination”), Section 9.3 (“Survival”), Section 10 (“Disclaimer of Warranties”), Section 11 (“Limitation of Liability”), Section 12 (“Use in Mission Critical Applications”), Section 13 (“Governing Law and Jurisdiction”), Section 14 (“Indemnity”) and Section 15 (“General”) will survive any expiration or termination of this Agreement.

10. Disclaimer of Warranties

THE SOFTWARE IS PROVIDED BY NVIDIA AS-IS AND WITH ALL FAULTS. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, NVIDIA DISCLAIMS ALL WARRANTIES AND REPRESENTATIONS OF ANY KIND, WHETHER EXPRESS, IMPLIED OR STATUTORY, RELATING TO OR ARISING UNDER THIS AGREEMENT, INCLUDING, WITHOUT LIMITATION, THE WARRANTIES OF TITLE, NONINFRINGEMENT, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, USAGE OF TRADE AND COURSE OF DEALING. NVIDIA DOES NOT WARRANT OR ASSUME RESPONSIBILITY FOR THE ACCURACY OR COMPLETENESS OF ANY THIRD-PARTY INFORMATION, TEXT, GRAPHICS, LINKS CONTAINED IN THE SOFTWARE.
WITHOUT LIMITING THE FOREGOING, NVIDIA DOES NOT WARRANT THAT THE SOFTWARE WILL MEET YOUR REQUIREMENTS, ANY DEFECTS OR ERRORS WILL BE CORRECTED, ANY CERTAIN CONTENT WILL BE AVAILABLE; OR THAT THE SOFTWARE IS FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS. NO INFORMATION OR ADVICE GIVEN BY NVIDIA WILL IN ANY WAY INCREASE THE SCOPE OF ANY WARRANTY EXPRESSLY PROVIDED IN THIS AGREEMENT.
NVIDIA does not warrant or assume responsibility for the accuracy or completeness of any third-party information, text, graphics or links contained in the Software.

11. Limitations of Liability

11.1. EXCLUSIONS. TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT WILL NVIDIA BE LIABLE FOR ANY (I) INDIRECT, PUNITIVE, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES, OR (II) DAMAGES FOR (a) THE COST OF PROCURING SUBSTITUTE GOODS, OR (b) LOSS OF PROFITS, REVENUES, USE, DATA OR GOODWILL ARISING OUT OF OR RELATED TO THIS AGREEMENT, WHETHER BASED ON BREACH OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY, OR OTHERWISE, AND EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES AND EVEN IF A PARTY’S REMEDIES FAIL THEIR ESSENTIAL PURPOSE.

11.2. DAMAGES CAP. ADDITIONALLY, TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, NVIDIA’S TOTAL CUMULATIVE AGGREGATE LIABILITY FOR ANY AND ALL LIABILITIES, OBLIGATIONS OR CLAIMS ARISING OUT OF OR RELATED TO THIS AGREEMENT WILL NOT EXCEED FIVE U.S. DOLLARS (US$5).

12. Use in Mission Critical Applications

You acknowledge that the Software provided under this Agreement is not designed or tested by NVIDIA for use in any system or application where the use or failure of such system or application developed with NVIDIA’s Software could result in injury, death or catastrophic damage (each, a “Mission Critical Application”).
Examples of Mission Critical Applications include use in avionics, navigation, autonomous vehicle applications, AI solutions for automotive products, military, medical, life support or other mission-critical or life-critical applications.
NVIDIA will not be liable to you or any third party, in whole or in part, for any claims or damages arising from these uses.
You are solely responsible for ensuring that systems and applications developed with the Software include sufficient safety and redundancy features and comply with all applicable legal and regulatory standards and requirements.

13. Governing Law and Jurisdiction

This Agreement will be governed in all respects by the laws of the United States and the laws of the State of Delaware, without regard to conflict of laws principles or the United Nations Convention on Contracts for the International Sale of Goods.
The state and federal courts residing in Santa Clara County, California will have exclusive jurisdiction over any dispute or claim arising out of or related to this Agreement, and the parties irrevocably consent to personal jurisdiction and venue in those courts; except that either party may apply for injunctive remedies or an equivalent type of urgent legal relief in any jurisdiction.

14. Indemnity

By using the Software you agree to defend, indemnify and hold harmless NVIDIA and its affiliates and their respective officers, directors, employees and agents from and against any claims, disputes, demands, liabilities, damages, losses, costs and expenses arising out of or in any way connected with (i) products or services that have been developed or deployed with or use the Software, or claims that they violate laws, or infringe, violate, or misappropriate any third party right; or (ii) use of the Software in breach of the terms of this Agreement.

15. General

15.1. Independent Contractors.
The parties are independent contractors, and this Agreement does not create a joint venture, partnership, agency, or other form of business association between the parties.
Neither party will have the power to bind the other party or incur any obligation on its behalf without the other party’s prior written consent.
Nothing in this Agreement prevents either party from participating in similar arrangements with third parties.

15.2. No Assignment.
NVIDIA may assign, delegate or transfer its rights or obligations under this Agreement by any means or operation of law.
You may not, without NVIDIA’s prior written consent, assign, delegate or transfer any of your rights or obligations under this Agreement by any means or operation of law, and any attempt to do so is null and void.

15.3. No Waiver.
No failure or delay by a party to enforce any term or obligation of this Agreement will operate as a waiver by that party, or prevent the enforcement of such term or obligation later.

15.4. Trade Compliance.
You agree to comply with all applicable export, import, trade and economic sanctions laws and regulations, as amended, including without limitation U.S. Export Administration Regulations and Office of Foreign Assets Control regulations.
You confirm (a) your understanding that export or reexport of certain NVIDIA products or technologies may require a license or other approval from appropriate authorities and (b) that you will not export or reexport any products or technology, directly or indirectly, without first obtaining any required license or other approval from appropriate authorities, (i) to any countries that are subject to any U.S. or local export restrictions (currently including, but not necessarily limited to, Belarus, Cuba, Iran, North Korea, Russia, Syria, the Region of Crimea, Donetsk People’s Republic Region and Luhansk People’s Republic Region); (ii) to any end-user who you know or have reason to know will utilize them in the design, development or production of nuclear, chemical or biological weapons, missiles, rocket systems, unmanned air vehicles capable of a maximum range of at least 300 kilometers, regardless of payload, or intended for military end-use, or any weapons of mass destruction; (iii) to any end-user who has been prohibited from participating in the U.S. or local export transactions by any governing authority; or (iv) to any known military or military-intelligence end-user or for any known military or military-intelligence end-use in accordance with U.S. trade compliance laws and regulations.

15.5. Government Rights.
The Software, documentation and technology (“Protected Items”) are “Commercial products” as this term is defined at 48 C.F.R. 2.101, consisting of “commercial computer software” and “commercial computer software documentation” as such terms are used in, respectively, 48 C.F.R. 12.212 and 48 C.F.R. 227.7202 & 252.227-7014(a)(1).
Before any Protected Items are supplied to the U.S. Government, you will (i) inform the U.S. Government in writing that the Protected Items are and must be treated as commercial computer software and commercial computer software documentation developed at private expense; (ii) inform the U.S. Government that the Protected Items are provided subject to the terms of the Agreement; and (iii) mark the Protected Items as commercial computer software and commercial computer software documentation developed at private expense.
In no event will you permit the U.S. Government to acquire rights in Protected Items beyond those specified in 48 C.F.R. 52.227-19(b)(1)-(2) or 252.227-7013(c) except as expressly approved by NVIDIA in writing.

15.6. Notices.
Please direct your legal notices or other correspondence to legalnotices@nvidia.com with a copy mailed to NVIDIA Corporation, 2788 San Tomas Expressway, Santa Clara, California 95051, United States of America, Attention: Legal Department.
If NVIDIA needs to contact you, you consent to receive the notices by email and agree that such notices will satisfy any legal communication requirements.

15.7. Severability.
If a court of competent jurisdiction rules that a provision of this Agreement is unenforceable, that provision will be deemed modified to the extent necessary to make it enforceable and the remainder of this Agreement will continue in full force and effect.

15.8. Amendment.
Any amendment to this Agreement must be in writing and signed by authorized representatives of both parties.

15.9. Construction.
The headings in the Agreement are included solely for convenience and are not intended to affect the meaning or interpretation of the Agreement.
As required by the context of the Agreement, the singular of a term includes the plural and vice versa.

15.10. Force Majeure.
Neither party will be liable during any period where an event or circumstance prevents or delays that party from performing its obligations under this Agreement and that event or circumstance: (i) is not within the reasonable control of that party and is not the result of that party’s negligence, and (ii) cannot be overcome or avoided by that party using reasonably diligent efforts.

15.11. Entire Agreement.
Regarding the subject matter of this Agreement, the parties agree that (a) this Agreement constitutes the entire and exclusive agreement between the parties and supersedes all prior and contemporaneous communications and (b) any additional or different terms or conditions, whether contained in purchase orders, order acknowledgments, invoices or otherwise, will not be binding and are null and void.

(v. May 8, 2025)
python/CuTeDSL/base_dsl/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

# Local module imports
from .dsl import *
from .runtime import *
from ._mlir_helpers import lru_cache_ir
from .env_manager import get_str_env_var, detect_gpu_arch
python/CuTeDSL/base_dsl/_mlir_helpers/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

"""
This module provides MLIR Dialect helper functions
"""

from . import arith
from .lru_cache_ir import lru_cache_ir


__all__ = ["arith", "lru_cache_ir"]

try:
    from . import gpu

    __all__.extend(["gpu"])
except ImportError:
    pass
python/CuTeDSL/base_dsl/_mlir_helpers/arith.py (new file, 691 lines)
@@ -0,0 +1,691 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

"""
This module provides MLIR Arith Dialect helper functions
"""

import array
import numpy as np

from ..common import *
from ..._mlir import ir  # type: ignore
from ..._mlir.extras import types as T  # type: ignore
from ..._mlir.dialects import arith, nvgpu, math, builtin  # type: ignore

from .lru_cache_ir import lru_cache_ir

# =============================================================================
# Arith Dialect Helper functions
# =============================================================================


def recast_type(src_type, res_elem_type) -> ir.Type:
    if isinstance(src_type, T.VectorType):
        if src_type.scalable:
            res_type = T.vector(
                *src_type.shape,
                res_elem_type,
                scalable=src_type.scalable,
                scalable_dims=src_type.scalable_dims,
            )
        else:
            res_type = T.vector(*src_type.shape, res_elem_type)
    elif isinstance(src_type, T.RankedTensorType):
        res_type = T.RankedTensorType.get(
            element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides
        )
    elif isinstance(src_type, T.UnrankedTensorType):
        res_type = T.UnrankedTensorType.get(element_type=res_elem_type)
    elif isinstance(src_type, T.MemRefType):
        res_type = T.MemRefType.get(
            element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides
        )
    else:
        res_type = res_elem_type
    return res_type


def is_scalar(ty) -> bool:
    return not isinstance(
        ty, (T.VectorType, T.RankedTensorType, T.UnrankedTensorType, T.MemRefType)
    )


def element_type(ty) -> ir.Type:
    if not is_scalar(ty):
        return ty.element_type
    else:
        return ty


def is_narrow_precision(ty) -> bool:
    narrow_types = {
        T.f8E8M0FNU(),
        T.f8E4M3FN(),
        T.f8E4M3(),
        T.f8E5M2(),
        T.f8E4M3B11FNUZ(),
        T.f4E2M1FN(),
        T.f6E3M2FN(),
        T.f6E2M3FN(),
    }
    return ty in narrow_types


def is_float_type(ty) -> bool:
    return (
        arith._is_float_type(ty)
        # TODO-upstream: the upstream predicate is not correct. Patch here and fix upstream later
        or is_narrow_precision(ty)
        or ty in (T.bf16(), T.tf32())
    )


def truncf_to_narrow(res_ty, src, loc, ip):
    res_elem_ty = element_type(res_ty)
    if res_elem_ty == T.f8E8M0FNU():
        rnd = nvgpu.RoundingMode.RP
    else:
        rnd = nvgpu.RoundingMode.RN
    return nvgpu.cvt_fptrunc(res_ty, src, rnd=rnd, loc=loc, ip=ip)


def extf_from_narrow(res_ty, src, loc, ip):
    src_elem_ty = element_type(src.type)

    # When source type is E8M0, temporary element type has to be bf16
    tmp_elem_ty = T.bf16() if src_elem_ty == T.f8E8M0FNU() else T.f16()
    tmp_ty = recast_type(src.type, tmp_elem_ty)

    # narrow -> bf16/f16 -> target type
    tmp = nvgpu.cvt_fpext(tmp_ty, src, loc=loc, ip=ip)
    return arith.extf(res_ty, tmp, loc=loc, ip=ip)


def bitcast(src, res_elem_type, *, loc=None, ip=None):
    res_type = recast_type(src.type, res_elem_type)
    return arith.bitcast(res_type, src, loc=loc, ip=ip)


def cvtf(src, res_elem_type, *, loc=None, ip=None):
    src_elem_type = element_type(src.type)

    if res_elem_type == src_elem_type:
        return src

    res_type = recast_type(src.type, res_elem_type)

    # Treat TF32 as F32 and use i32 as intermediate data
    # TODO-upstream: update arith to support tf32 <-> f32 conversion
    if src_elem_type == T.tf32():
        # tf32 -> i32
        tmp_type = recast_type(src.type, T.i32())
        src = builtin.unrealized_conversion_cast([tmp_type], [src], loc=loc, ip=ip)
        # i32 -> f32
        src = bitcast(src, T.f32(), loc=loc, ip=ip)
        # f32 -> X with `cvtf` recursively
        return cvtf(src, res_elem_type, loc=loc, ip=ip)

    if res_elem_type == T.tf32():
        # X -> f32 with `cvtf` recursively
        tmp = cvtf(src, T.f32(), loc=loc, ip=ip)
        # f32 -> i32
        tmp = bitcast(tmp, T.i32(), loc=loc, ip=ip)
        # i32 -> tf32
        return builtin.unrealized_conversion_cast([res_type], [tmp], loc=loc, ip=ip)

    if res_elem_type.width > src_elem_type.width:
        if is_narrow_precision(src_elem_type):
            return extf_from_narrow(res_type, src, loc, ip)
        else:
            return arith.extf(res_type, src, loc=loc, ip=ip)
    else:
        tmp_mlir_type = recast_type(src.type, T.f32())

        # f16 -- extf -> f32 -- truncf -> bf16
        # TODO-upstream: update arith to support bf16 <-> f16 conversion?
        if (src_elem_type == T.f16() and res_elem_type == T.bf16()) or (
            src_elem_type == T.bf16() and res_elem_type == T.f16()
        ):
            tmp = arith.extf(tmp_mlir_type, src, loc=loc, ip=ip)
            return arith.truncf(res_type, tmp, loc=loc, ip=ip)

        # f16, f32, ... -> {f8, f6, f4}
        elif is_narrow_precision(res_elem_type):
            return truncf_to_narrow(res_type, src, loc, ip)
        else:
            return arith.truncf(res_type, src, loc=loc, ip=ip)


def fptoi(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None):
    res_type = recast_type(src.type, res_elem_type)
    # TODO-upstream: update arith to support this kind of conversion
    if element_type(src.type) in (T.tf32(), T.bf16()):
        src = cvtf(src, T.f32(), loc=loc, ip=ip)

    if signed:
        return arith.fptosi(res_type, src, loc=loc, ip=ip)
    else:
        return arith.fptoui(res_type, src, loc=loc, ip=ip)


def itofp(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None):
    res_type = recast_type(src.type, res_elem_type)

    orig_res_type = res_type
    # TODO-upstream: update arith to support this kind of conversion
    if res_elem_type in (T.tf32(), T.bf16()):
        res_type = recast_type(src.type, T.f32())

    if signed and element_type(src.type).width > 1:
        res = arith.sitofp(res_type, src, loc=loc, ip=ip)
    else:
        res = arith.uitofp(res_type, src, loc=loc, ip=ip)

    if orig_res_type == res_type:
        return res

    return cvtf(res, element_type(orig_res_type), loc=loc, ip=ip)


def int_to_int(a, dst_elem_type, *, loc=None, ip=None):
    src_signed = a.signed
    dst_signed = dst_elem_type.signed
    src_width = element_type(a.type).width
    dst_width = dst_elem_type.width

    dst_mlir_type = recast_type(a.type, dst_elem_type.mlir_type)

    if dst_width == src_width:
        return a
    elif src_signed and not dst_signed:
        # Signed -> Unsigned
        if dst_width > src_width:
            return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
        else:
            return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip)
    elif src_signed == dst_signed:
        # Same signedness
        if dst_width > src_width:
            if src_signed and src_width > 1:
                return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip)
            else:
                return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
        else:
            return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip)
    else:
        # Unsigned -> Signed
        if dst_width > src_width:
            return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
        else:
            # For truncation from unsigned to signed, we need to handle overflow
            # First truncate to the target width
            trunc = arith.trunci(dst_mlir_type, a, loc=loc, ip=ip)
            # Then reinterpret as signed
            if dst_signed:
                return arith.bitcast(dst_mlir_type, trunc, loc=loc, ip=ip)
            return trunc
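# Illustrative, standalone sketch (not part of the original file): the branch
# structure of int_to_int above reduces to picking one arith op from the source
# and destination widths and signedness.  Mirroring that decision in plain
# Python makes it easy to check the cases without an MLIR context.
def pick_int_cast_op(src_width, src_signed, dst_width, dst_signed):
    if dst_width == src_width:
        return "noop"
    if src_signed and not dst_signed:                     # signed -> unsigned
        return "extui" if dst_width > src_width else "trunci"
    if src_signed == dst_signed:                          # same signedness
        if dst_width > src_width:
            return "extsi" if src_signed and src_width > 1 else "extui"
        return "trunci"
    # unsigned -> signed
    if dst_width > src_width:
        return "extui"
    return "trunci"                                       # then reinterpreted as signed

assert pick_int_cast_op(8, True, 32, True) == "extsi"     # i8  -> i32 sign-extends
assert pick_int_cast_op(1, True, 32, True) == "extui"     # i1 is always widened as unsigned
assert pick_int_cast_op(32, False, 8, False) == "trunci"  # u32 -> u8 truncates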
# =============================================================================
# Arith Ops Emitter Helpers
#  - assuming type of lhs and rhs match each other
#  - op name matches python module operator
# =============================================================================


def _cast(res_elem_ty, src, is_signed=None, *, loc=None, ip=None):
    """
    This function provides a simplified interface to the upstream op builder:

        arith.truncf(T.vector(shape, new_type), src)

    is simplified to

        arith.truncf(new_type, src)

    because these casts are element-wise ops that cannot change the shape.
    """
    if isinstance(src, ir.Value):
        src_ty = src.type
    else:
        src_ty = type(src).mlir_type
        src = src.ir_value()

    src_elem_ty = element_type(src_ty)

    if src_elem_ty == res_elem_ty:
        return src
    elif is_float_type(src_elem_ty) and is_float_type(res_elem_ty):
        # float-to-float
        return cvtf(src, res_elem_ty, loc=loc, ip=ip)
    elif arith._is_integer_like_type(src_elem_ty) and arith._is_integer_like_type(
        res_elem_ty
    ):
        if src_elem_ty.width >= res_elem_ty.width:
            cast_op = arith.trunci
        else:
            if is_signed:
                cast_op = arith.extsi
            else:
                cast_op = arith.extui

        res_ty = recast_type(src_ty, res_elem_ty)
        return cast_op(res_ty, src, loc=loc, ip=ip)
    elif is_float_type(src_elem_ty) and arith._is_integer_like_type(res_elem_ty):
        return fptoi(src, is_signed, res_elem_ty, loc=loc, ip=ip)
    elif arith._is_integer_like_type(src_elem_ty) and is_float_type(res_elem_ty):
        return itofp(src, is_signed, res_elem_ty, loc=loc, ip=ip)
    else:
        raise DSLRuntimeError(
            f"cast from {src_elem_ty} to {res_elem_ty} is not supported"
        )


@lru_cache_ir()
def const(value, ty=None, *, loc=None, ip=None):
    """
    Generates a dynamic expression for constant values.
    """
    from ..typing import Numeric, NumericMeta
    from ..dsl import is_dynamic_expression, _numpy_type_to_mlir_type

    if isinstance(value, Numeric):
        value = value.value

    # Early return
    if is_dynamic_expression(value) and (
        value.type.isinstance(value.type) or T.bool().isinstance(value.type)
    ):
        return value

    # Assume type
    if ty is None:
        if isinstance(value, float):
            ty = T.f32()
        elif isinstance(value, bool):
            ty = T.bool()
        elif isinstance(value, int):
            ty = T.i32()
        elif isinstance(value, np.ndarray):
            ty = T.vector(*value.shape, _numpy_type_to_mlir_type(value.dtype))
            value = array.array(value.dtype.kind, value.flatten().tolist())
        else:
            raise DSLNotImplemented(f"{type(value)} is not supported")
    elif isinstance(ty, NumericMeta):
        ty = ty.mlir_type
    elif isinstance(ty, ir.Type):
        if ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty):
            elem_ty = ty.element_type
            if isinstance(elem_ty, ir.IntegerType):
                attr = ir.IntegerAttr.get(elem_ty, value)
            else:
                attr = ir.FloatAttr.get(elem_ty, value)
            value = ir.DenseElementsAttr.get_splat(ty, attr)
        elif arith._is_float_type(ty) and isinstance(value, (bool, int)):
            value = float(value)
        elif arith._is_integer_like_type(ty) and isinstance(value, float):
            value = int(value)
    else:
        raise DSLNotImplemented(f"type {ty} is not supported")

    return arith.constant(ty, value, loc=loc, ip=ip)


def _dispatch_to_rhs_r_op(op):
    """Decorator that dispatches to the right-hand side's reverse operation.

    If the other operand is not an ArithValue or is a subclass (more specific)
    of ArithValue, this allows proper method resolution for binary operations.
    """

    def wrapper(self, other, **kwargs):
        if not isinstance(other, ArithValue):
            if not isinstance(other, (int, float, bool)):
                # allows Python to call other.__rmul__ (and the other reflected ops)
                return NotImplemented

        return op(self, other, **kwargs)

    return wrapper


def _binary_op(op):
    """
    Decorator that prepares the 'other' operand of a binary op: Python scalars
    (int/float/bool) are promoted to constants with the same type and
    signedness as `self` before the wrapped operation is called.
    """

    def wrapper(self, other, **kwargs):
        # When we reach this point, `self` has been cast to the base `ArithValue` type
        if isinstance(other, (int, float, bool)):
            other = const(other, self.type).with_signedness(self.signed)

        # Call the original function
        # If a sub-class doesn't implement overloaded arithmetic, cast to the base class
        return op(self, other, **kwargs)

    return wrapper
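# Illustrative, standalone sketch (not part of the original file): the two
# decorators above lean on the standard Python binary-operator protocol.
# Returning NotImplemented from the left operand's __add__ makes the
# interpreter fall back to the right operand's __radd__, so foreign types can
# still participate; Python also tries a subclass's reflected method first,
# which is how a more specific subclass wins method resolution.
class Left:
    def __add__(self, other):
        if not isinstance(other, (Left, int, float, bool)):
            return NotImplemented            # defer to other.__radd__
        return "Left.__add__"

class SpecialLeft(Left):
    def __radd__(self, other):
        return "SpecialLeft.__radd__"

class Foreign:
    def __radd__(self, other):
        return "Foreign.__radd__"

assert Left() + 1 == "Left.__add__"                       # scalars handled in place
assert Left() + Foreign() == "Foreign.__radd__"           # NotImplemented -> reflected op
assert Left() + SpecialLeft() == "SpecialLeft.__radd__"   # subclass is tried first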
# Operator overloading
@ir.register_value_caster(ir.Float4E2M1FNType.static_typeid)
@ir.register_value_caster(ir.Float6E2M3FNType.static_typeid)
@ir.register_value_caster(ir.Float6E3M2FNType.static_typeid)
@ir.register_value_caster(ir.Float8E4M3FNType.static_typeid)
@ir.register_value_caster(ir.Float8E4M3B11FNUZType.static_typeid)
@ir.register_value_caster(ir.Float8E5M2Type.static_typeid)
@ir.register_value_caster(ir.Float8E4M3Type.static_typeid)
@ir.register_value_caster(ir.Float8E8M0FNUType.static_typeid)
@ir.register_value_caster(ir.BF16Type.static_typeid)
@ir.register_value_caster(ir.F16Type.static_typeid)
@ir.register_value_caster(ir.FloatTF32Type.static_typeid)
@ir.register_value_caster(ir.F32Type.static_typeid)
@ir.register_value_caster(ir.F64Type.static_typeid)
@ir.register_value_caster(ir.IntegerType.static_typeid)
@ir.register_value_caster(ir.VectorType.static_typeid)
@ir.register_value_caster(ir.RankedTensorType.static_typeid)
class ArithValue(ir.Value):
    """Overloads operators for MLIR's Arith dialect binary operations."""

    def __init__(self, v, signed: Union[bool, None] = None):
        if isinstance(v, int):
            v = arith.constant(self.type, v)
        super().__init__(v)

        elem_ty = element_type(self.type)
        self.is_float = arith._is_float_type(elem_ty)
        # The arith dialect considers `1` in `i1` to be `-1`; treat i1 as unsigned in the DSL
        self.signed = signed and elem_ty.width > 1

    def with_signedness(self, signed: Union[bool, None]):
        return type(self)(self, signed)

    def __neg__(self, *, loc=None, ip=None):
        if self.type == T.bool():
            raise TypeError(
                "Negation, the operator `-`, is not supported for the boolean type"
            )

        if self.is_float:
            return arith.negf(self, loc=loc, ip=ip)
        else:
            c0 = arith.constant(self.type, 0, loc=loc, ip=ip)
            return arith.subi(c0, self, loc=loc, ip=ip)

    @_binary_op
    def __pow__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float and other.is_float:
            return math.powf(self, other, loc=loc, ip=ip)
        elif self.is_float and not other.is_float:
            return math.fpowi(self, other, loc=loc, ip=ip)
        elif not self.is_float and other.is_float:
            lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip)
            rhs = cvtf(other, T.f32(), loc=loc, ip=ip)
            return math.powf(lhs, rhs, loc=loc, ip=ip)
        elif not self.is_float and not other.is_float:
            return math.ipowi(self, other, loc=loc, ip=ip)
        else:
            raise DSLNotImplemented(f"Unsupported '{self} ** {other}'")

    @_binary_op
    def __rpow__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__pow__(self, loc=loc, ip=ip)

    # arith operators

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __add__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.addf(self, other, loc=loc, ip=ip)
        else:
            return arith.addi(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __sub__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.subf(self, other, loc=loc, ip=ip)
        else:
            return arith.subi(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __mul__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.mulf(self, other, loc=loc, ip=ip)
        else:
            return arith.muli(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __truediv__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.divf(self, other, loc=loc, ip=ip)
        else:
            lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip)
            rhs = itofp(other, other.signed, T.f32(), loc=loc, ip=ip)
            return arith.divf(lhs, rhs, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __floordiv__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            q = arith.divf(self, other, loc=loc, ip=ip)
            return math.floor(q, loc=loc, ip=ip)
        elif self.signed:
            return arith.floordivsi(self, other, loc=loc, ip=ip)
        else:
            return arith.divui(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.remf(self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.remsi(self, other, loc=loc, ip=ip)
        else:
            return arith.remui(self, other, loc=loc, ip=ip)

    @_binary_op
    def __radd__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__add__(self, loc=loc, ip=ip)

    @_binary_op
    def __rsub__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__sub__(self, loc=loc, ip=ip)

    @_binary_op
    def __rmul__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__mul__(self, loc=loc, ip=ip)

    @_binary_op
    def __rtruediv__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__truediv__(self, loc=loc, ip=ip)

    @_binary_op
    def __rfloordiv__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__floordiv__(self, loc=loc, ip=ip)

    @_binary_op
    def __rmod__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__mod__(self, loc=loc, ip=ip)

    # Comparison operators (comparisons don't have right-hand-side variants)
    @_dispatch_to_rhs_r_op
    @_binary_op
    def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __le__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __eq__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OEQ, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.eq, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __ne__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            # In Python, bool(float("nan")) is True, so use an unordered comparison here
            return arith.cmpf(arith.CmpFPredicate.UNE, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.ne, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OGT, self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip)

    # Unary operators
    def __invert__(self, *, loc=None, ip=None) -> "ArithValue":
        c_all_ones = arith.constant(self.type, -1, loc=loc, ip=ip)
        return arith.xori(self, c_all_ones, loc=loc, ip=ip)

    # Bitwise operations
    @_dispatch_to_rhs_r_op
    @_binary_op
    def __and__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.andi(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __or__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.ori(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __xor__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.xori(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.signed:
            return arith.shrsi(self, other, loc=loc, ip=ip)
        else:
            return arith.shrui(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __lshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.shli(self, other, loc=loc, ip=ip)

    @_binary_op
    def __rand__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.andi(other, self, loc=loc, ip=ip)

    @_binary_op
    def __ror__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.ori(other, self, loc=loc, ip=ip)

    @_binary_op
    def __rxor__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.xori(other, self, loc=loc, ip=ip)

    @_binary_op
    def __rrshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__rshift__(self, loc=loc, ip=ip)

    @_binary_op
    def __rlshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__lshift__(self, loc=loc, ip=ip)

    def __hash__(self):
        # __eq__ is overloaded above to build IR, so hashability must be restored explicitly
        return super().__hash__()

    def __str__(self):
        return super().__str__().replace(ir.Value.__name__, ArithValue.__name__)

    def __repr__(self):
        return self.__str__()
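# Illustrative, standalone sketch (not part of the original file): ArithValue
# overrides __eq__ to emit a comparison op instead of returning a Python bool,
# so it must re-declare __hash__ explicitly.  Python sets __hash__ to None on
# any class that defines __eq__ without also defining __hash__, as this plain
# Python example shows.
class EqOnly:
    def __eq__(self, other):
        return "symbolic-compare"

class EqAndHash:
    def __eq__(self, other):
        return "symbolic-compare"
    __hash__ = object.__hash__            # same effect as ArithValue.__hash__

assert EqOnly.__hash__ is None            # instances of EqOnly are unhashable
assert isinstance(hash(EqAndHash()), int)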
def _min(lhs, rhs, *, loc=None, ip=None):
    """
    This function provides a unified interface for building arith min

    Assuming the operands have the same type
    """
    from ..dsl import is_dynamic_expression

    if not is_dynamic_expression(lhs):
        if not is_dynamic_expression(rhs):
            return min(lhs, rhs)
        else:
            lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip)
    else:
        if not is_dynamic_expression(rhs):
            rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)

    if arith._is_integer_like_type(lhs.type):
        if lhs.signed:
            return arith.minsi(lhs, rhs, loc=loc, ip=ip)
        else:
            return arith.minui(lhs, rhs, loc=loc, ip=ip)
    else:
        return arith.minimumf(lhs, rhs, loc=loc, ip=ip)


def _max(lhs, rhs, *, loc=None, ip=None):
    """
    This function provides a unified interface for building arith max

    Assuming the operands have the same type
    """
    from ..dsl import is_dynamic_expression

    if not is_dynamic_expression(lhs):
        if not is_dynamic_expression(rhs):
            return max(lhs, rhs)
        else:
            lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip)
    else:
        if not is_dynamic_expression(rhs):
            rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)

    if arith._is_integer_like_type(lhs.type):
        if lhs.signed:
            return arith.maxsi(lhs, rhs, loc=loc, ip=ip)
        else:
            return arith.maxui(lhs, rhs, loc=loc, ip=ip)
    else:
        return arith.maximumf(lhs, rhs, loc=loc, ip=ip)
python/CuTeDSL/base_dsl/_mlir_helpers/gpu.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

"""
This module provides MLIR GPU Dialect helper functions
"""


from ..._mlir import ir
from ..._mlir.dialects import gpu, arith, scf
from ..._mlir.extras import types as T

from ..common import *

# =============================================================================
# GPU Dialect Helper functions
# =============================================================================


def create_async_token():
    token_ty = gpu.AsyncTokenType.get()
    token = gpu.wait(token_ty, [])
    return token


def printf(fmt, *args, threadNumber=-1):
    """Generate a gpu.printf op, optionally predicated on threadNumber."""
    type_formats = []
    for arg in args:
        ty_format = None
        if ir.IndexType.isinstance(arg.type):
            ty_format = "%llu"
        if ir.IntegerType.isinstance(arg.type):
            width = ir.IntegerType(arg.type).width
            if width == 64:
                ty_format = "%llu"
            elif width == 32:
                ty_format = "%d"
            elif width == 1:
                ty_format = "%i"
        if ir.F32Type.isinstance(arg.type):
            ty_format = "%f"
        if ty_format is None:
            raise DSLNotImplemented(arg.type)
        type_formats.append(ty_format)
    if threadNumber == -1:
        gpu.printf(fmt.format(*type_formats) + "\n", args)
    if threadNumber != -1:
        tidx = gpu.thread_id(gpu.Dimension.x)
        predicate = arith.cmpi(
            arith.CmpIPredicate.eq, tidx, arith.constant(T.index(), threadNumber)
        )
        if_op = scf.IfOp(predicate)
        with ir.InsertionPoint(if_op.then_block):
            gpu.printf(fmt.format(*type_formats) + "\n", args)
            scf.yield_([])
python/CuTeDSL/base_dsl/_mlir_helpers/lru_cache_ir.py (new file, 76 lines)
@@ -0,0 +1,76 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

"""
This module provides @lru_cache_ir.
It extends functools.lru_cache with IR Context awareness.

Example usage:
    from lru_cache_ir import lru_cache_ir

    @lru_cache_ir(maxsize=128, typed=False)
    def make_layout(...):
        ...

"""


from functools import lru_cache, wraps

from ..._mlir import ir  # type: ignore


def get_ir_context(func):
    """
    Return the context for the given func called under ir.
    Currently the context includes the MLIRContext and the InsertionPoint.
    """
    try:
        if ir:
            return (ir.Context.current, ir.InsertionPoint.current)
        else:
            return None
    except ValueError:
        return None


def lru_cache_ir(maxsize=128, typed=True):
    """
    Applies an LRU cache to a given function, with awareness of the IR context.

    Usage is similar to functools.lru_cache; the current IR context is derived
    via `get_ir_context` and folded into the cache key.

    :param maxsize: Max cache size, same as functools.lru_cache
    :param typed: Whether params are type-sensitive, defaults to True as IR is type-sensitive
    """

    def decorator(func):
        # Use functools.lru_cache with a custom wrapper to control the key generation
        @lru_cache(maxsize=maxsize, typed=typed)
        def cached_func(context, *args, **kwargs):
            return func(*args, **kwargs)

        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                # Call the cached function with the context
                return cached_func(get_ir_context(func), *args, **kwargs)
            except (RuntimeError, TypeError):
                return func(*args, **kwargs)

        # Expose cache-related methods for introspection
        wrapper.cache_clear = cached_func.cache_clear
        wrapper.cache_info = cached_func.cache_info
        return wrapper

    return decorator
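# Illustrative, standalone sketch (not part of the original file): the idea
# behind lru_cache_ir is that an extra context value is folded into the cache
# key, so results cached under one IR context are never reused under another.
# The same mechanism in plain Python, with the context supplied by a stand-in
# callable instead of get_ir_context:
from functools import lru_cache, wraps

def cache_per_context(get_context, maxsize=128):
    def decorator(func):
        @lru_cache(maxsize=maxsize)
        def cached(context, *args):
            return func(*args)

        @wraps(func)
        def wrapper(*args):
            return cached(get_context(), *args)

        wrapper.cache_info = cached.cache_info
        return wrapper

    return decorator

_current = {"ctx": "A"}

@cache_per_context(lambda: _current["ctx"])
def square(x):
    return x * x

square(3); square(3)
assert square.cache_info().hits == 1      # same context: second call is a cache hit
_current["ctx"] = "B"
square(3)
assert square.cache_info().misses == 2    # new context: same argument recomputes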
python/CuTeDSL/base_dsl/_mlir_helpers/op.py (new file, 34 lines)
@@ -0,0 +1,34 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides MLIR's OP helper functions
|
||||
"""
|
||||
|
||||
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
from ..._mlir import ir
|
||||
|
||||
|
||||
def dsl_user_op(opFunc):
|
||||
@wraps(opFunc)
|
||||
def wrapper(*args, **kwargs):
|
||||
loc = kwargs.pop("loc", None)
|
||||
if loc is None:
|
||||
frame = inspect.currentframe().f_back
|
||||
file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0)
|
||||
loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc)
|
||||
res_or_list = opFunc(*args, **kwargs, loc=loc)
|
||||
return res_or_list
|
||||
|
||||
return wrapper
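# Illustrative sketch (hypothetical helper, not part of the original file):
# any op wrapper decorated with @dsl_user_op accepts an optional `loc` keyword.
# When the caller omits it, the decorator builds a name location from the
# caller's function name wrapping a file/line location of the call site.
#
#     @dsl_user_op
#     def my_op(x, *, loc=None):
#         ...  # emit MLIR ops here, forwarding `loc` for source tracking
#
#     my_op(x)                      # loc inferred from this call site
#     my_op(x, loc=explicit_loc)    # an explicit ir.Location takes precedence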
|
||||
584
python/CuTeDSL/base_dsl/ast_helpers.py
Normal file
@@ -0,0 +1,584 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides helper functions that are invoked by preprocessor-generated code.
|
||||
The preprocessor reads through Python's AST and rewrites the input code to call these helpers.
|
||||
"""
|
||||
|
||||
from typing import Callable, Iterator, Optional, overload
|
||||
|
||||
from .utils.logger import log
|
||||
from .common import *
|
||||
|
||||
from ._mlir_helpers.arith import ArithValue
|
||||
|
||||
class Executor:
|
||||
"""
|
||||
The Executor class handles dynamic and compile-time (constexpr) execution
|
||||
of "for" loops and "if-else-elif" statements.
|
||||
|
||||
Methods:
|
||||
set_functions: Assigns the functions for checking loop bounds and
|
||||
conditional evaluation.
|
||||
|
||||
for_dynamic: Generates an MLIR `for` op
|
||||
for_constexpr: Executes a for loop at JIT compile-time
|
||||
for_execute: Decides whether to execute the loop at compile time or generate an MLIR `for` op, based on the provided bounds.
|
||||
|
||||
if_dynamic: Generates an MLIR `if` op
|
||||
if_constexpr: Executes an if statement at JIT compile time via the Python interpreter
|
||||
if_execute: Decides whether to execute the if statement at compile time or generate an MLIR `if` op, based on the predicate.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._is_dynamic_expression = None
|
||||
self._loop_execute_range_dynamic = None
|
||||
self._if_dynamic = None
|
||||
self._while_dynamic = None
|
||||
|
||||
def set_functions(
|
||||
self,
|
||||
is_dynamic_expression: Callable,
|
||||
loop_execute_range_dynamic: Callable,
|
||||
if_dynamic: Callable,
|
||||
while_dynamic: Callable,
|
||||
):
|
||||
self._is_dynamic_expression = is_dynamic_expression
|
||||
self._loop_execute_range_dynamic = loop_execute_range_dynamic
|
||||
self._if_dynamic = if_dynamic
|
||||
self._while_dynamic = while_dynamic
|
||||
|
||||
@staticmethod
|
||||
def convert_to_list(x):
|
||||
"""This function is used to convert x to a list.
|
||||
If x is None, return an empty list.
|
||||
If x is not a list, return a list containing x.
|
||||
Otherwise, return x itself.
|
||||
"""
|
||||
if x is None:
|
||||
return []
|
||||
if not isinstance(x, list):
|
||||
return [x]
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def converge_ret_val(res):
|
||||
"""This function is used to converge res (the return value) of the function.
|
||||
If res is None, return None.
|
||||
If res is a list and has only one element, return the element.
|
||||
Otherwise, return res itself.
|
||||
"""
|
||||
if res is None:
|
||||
return res
|
||||
elif isinstance(res, list) and len(res) == 1:
|
||||
return res[0]
|
||||
return res
|
||||
|
||||
def for_dynamic(
|
||||
self,
|
||||
func: Callable,
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
used_args: list,
|
||||
iter_args: list,
|
||||
iter_arg_names: list,
|
||||
unroll=-1,
|
||||
unroll_full=False,
|
||||
):
|
||||
log().info("start [%s] stop [%s] step [%s]", start, stop, step)
|
||||
return self._loop_execute_range_dynamic(
|
||||
func,
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
used_args,
|
||||
iter_args,
|
||||
iter_arg_names,
|
||||
unroll,
|
||||
unroll_full,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def for_constexpr(
|
||||
func: Callable,
|
||||
start: int,
|
||||
stop: int,
|
||||
step: int,
|
||||
used_args: list,
|
||||
iter_args: list,
|
||||
):
|
||||
log().info("start [%s] stop [%s] step [%s]", start, stop, step)
|
||||
loop_results = iter_args
|
||||
log().debug("iter_args [%s]", iter_args)
|
||||
for i in range(start, stop, step):
|
||||
log().debug("i [%s] iter_args [%s]", i, iter_args)
|
||||
loop_results = func(i, *used_args, *loop_results)
|
||||
log().debug("loop_results [%s]", loop_results)
|
||||
if loop_results is None:
|
||||
loop_results = []
|
||||
if not isinstance(loop_results, list):
|
||||
loop_results = [loop_results]
|
||||
|
||||
log().debug("done loop_results [%s]", loop_results)
|
||||
return Executor.converge_ret_val(loop_results)
|
||||
|
||||
def for_execute(
|
||||
self,
|
||||
func,
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
used_args=[],
|
||||
iter_args=[],
|
||||
iter_arg_names=[],
|
||||
unroll=-1,
|
||||
unroll_full=False,
|
||||
is_range_constexpr=None,
|
||||
):
|
||||
assert (
|
||||
self._loop_execute_range_dynamic and self._is_dynamic_expression
|
||||
), "Functions must be set before execution."
|
||||
log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
|
||||
any_dynamic_expression = (
|
||||
self._is_dynamic_expression(start)
|
||||
or self._is_dynamic_expression(stop)
|
||||
or self._is_dynamic_expression(step)
|
||||
)
|
||||
|
||||
if is_range_constexpr is None:
|
||||
if not any_dynamic_expression:
|
||||
return self.for_constexpr(func, start, stop, step, used_args, iter_args)
|
||||
else:
|
||||
return self.for_dynamic(
|
||||
func,
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
used_args,
|
||||
iter_args,
|
||||
iter_arg_names,
|
||||
unroll,
|
||||
unroll_full,
|
||||
)
|
||||
|
||||
# Ensure bounds are compile-time constants for constexpr execution
|
||||
if is_range_constexpr:
|
||||
if any_dynamic_expression:
|
||||
raise DSLRuntimeError(
|
||||
"Loop bounds must be constexpr (compile-time constants)"
|
||||
)
|
||||
return self.for_constexpr(func, start, stop, step, used_args, iter_args)
|
||||
|
||||
# MLIR generation
|
||||
return self.for_dynamic(
|
||||
func,
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
used_args,
|
||||
iter_args,
|
||||
iter_arg_names,
|
||||
unroll,
|
||||
unroll_full,
|
||||
)
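# Illustrative sketch (assumes the DSL has registered its hooks via
# `executor.set_functions(...)`; `body` is hypothetical). With constexpr bounds
# the loop runs in the Python interpreter; with dynamic (IR Value) bounds the
# same call emits an MLIR `for` op through `_loop_execute_range_dynamic`.
#
#     def body(i, acc):
#         return acc + i
#
#     executor.for_execute(body, 0, 4, 1, iter_args=[0])
#     # constexpr bounds: evaluates 0 + 1 + 2 + 3 == 6 at compile time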
|
||||
|
||||
def if_dynamic(
|
||||
self,
|
||||
pred,
|
||||
then_block: Callable,
|
||||
else_block: Optional[Callable] = None,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
yield_arg_names=[],
|
||||
):
|
||||
return self._if_dynamic(
|
||||
pred, then_block, else_block, used_args, yield_args, yield_arg_names
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def if_constexpr(
|
||||
pred,
|
||||
then_block: Callable,
|
||||
else_block: Optional[Callable] = None,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
):
|
||||
if pred:
|
||||
log().debug(" running then block [%s]", yield_args)
|
||||
res = then_block(*used_args, *yield_args)
|
||||
log().debug("result [%s]", res)
|
||||
return Executor.converge_ret_val(res)
|
||||
elif else_block is not None:
|
||||
log().debug("running else [%s]", yield_args)
|
||||
res = else_block(*used_args, *yield_args)
|
||||
log().debug("result [%s]", res)
|
||||
return Executor.converge_ret_val(res)
|
||||
|
||||
def if_execute(
|
||||
self,
|
||||
pred,
|
||||
then_block: Callable,
|
||||
else_block: Optional[Callable] = None,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
yield_arg_names=[],
|
||||
if_constexpr=None,
|
||||
):
|
||||
assert (
|
||||
self._if_dynamic and self._is_dynamic_expression
|
||||
), "Functions must be set before execution."
|
||||
|
||||
is_if_constexpr = not self._is_dynamic_expression(pred)
|
||||
if if_constexpr is None:
|
||||
if is_if_constexpr:
|
||||
return self.if_constexpr(
|
||||
pred, then_block, else_block, used_args, yield_args
|
||||
)
|
||||
else:
|
||||
return self.if_dynamic(
|
||||
pred, then_block, else_block, used_args, yield_args, yield_arg_names
|
||||
)
|
||||
|
||||
# Ensure the predicate is a compile-time constant for constexpr execution
|
||||
if if_constexpr:
|
||||
if not is_if_constexpr:
|
||||
raise DSLRuntimeError(
|
||||
"If predicate must be constexpr (compile-time constants)"
|
||||
)
|
||||
return self.if_constexpr(
|
||||
pred, then_block, else_block, used_args, yield_args
|
||||
)
|
||||
|
||||
# MLIR generation
|
||||
return self.if_dynamic(
|
||||
pred, then_block, else_block, used_args, yield_args, yield_arg_names
|
||||
)
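# Illustrative sketch (assumes hooks are registered via `set_functions`; the
# lambdas are hypothetical blocks): with a constexpr predicate only the selected
# branch runs in the interpreter; a dynamic predicate emits an MLIR `if` op.
#
#     executor.if_execute(
#         True,
#         then_block=lambda: "then",
#         else_block=lambda: "else",
#     )
#     # constexpr predicate: returns "then" without emitting any IR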
|
||||
|
||||
def while_dynamic(
|
||||
self,
|
||||
while_before_block: Callable,
|
||||
while_after_block: Callable,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
yield_arg_names=[],
|
||||
):
|
||||
return self._while_dynamic(
|
||||
while_before_block,
|
||||
while_after_block,
|
||||
used_args,
|
||||
yield_args,
|
||||
yield_arg_names,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def while_constexpr(
|
||||
while_before_block,
|
||||
while_after_block,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
):
|
||||
log().debug(
|
||||
"while_constexpr begin %s", while_before_block.__qualname__
|
||||
)
|
||||
cond, loop_results = while_before_block(*used_args, *yield_args)
|
||||
while cond:
|
||||
loop_results = Executor.convert_to_list(loop_results)
|
||||
log().debug(
|
||||
"calling while_after [%s], [%s]",
|
||||
used_args,
|
||||
loop_results,
|
||||
)
|
||||
loop_results = while_after_block(*used_args, *loop_results)
|
||||
log().debug(
|
||||
"while after [%s]", loop_results
|
||||
)
|
||||
loop_results = Executor.convert_to_list(loop_results)
|
||||
log().debug(
|
||||
"calling while_before [%s], [%s]",
|
||||
used_args,
|
||||
loop_results,
|
||||
)
|
||||
cond, loop_results = while_before_block(*used_args, *loop_results)
|
||||
log().debug(
|
||||
"while_before cond, results [%s], [%s]",
|
||||
cond,
|
||||
loop_results,
|
||||
)
|
||||
|
||||
log().debug(
|
||||
"while_constexpr results %s", loop_results
|
||||
)
|
||||
return Executor.converge_ret_val(loop_results)
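# Illustrative sketch (hypothetical blocks): mirroring MLIR's `scf.while`, the
# "before" block returns a (condition, results) pair and the "after" block
# produces the next iteration's results.
#
#     def before(i):
#         return i < 4, i          # condition, value carried into the body
#
#     def after(i):
#         return i + 1             # next iteration argument
#
#     Executor.while_constexpr(before, after, yield_args=[0])
#     # -> 4, fully evaluated by the Python interpreter at compile time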
|
||||
|
||||
def while_execute(
|
||||
self,
|
||||
pred,
|
||||
while_before_block: Callable,
|
||||
while_after_block: Callable,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
yield_arg_names=[],
|
||||
while_constexpr=None,
|
||||
):
|
||||
assert (
|
||||
self._while_dynamic and self._is_dynamic_expression
|
||||
), "Functions must be set before execution."
|
||||
|
||||
is_while_constexpr = not self._is_dynamic_expression(pred)
|
||||
|
||||
# Ensure the predicate is a compile-time constant for constexpr execution
|
||||
if while_constexpr:
|
||||
if not is_while_constexpr:
|
||||
raise DSLRuntimeError(
|
||||
"While predicate must be constexpr (compile-time constants)"
|
||||
)
|
||||
return self.while_constexpr(
|
||||
while_before_block, while_after_block, used_args, yield_args
|
||||
)
|
||||
|
||||
# MLIR generation
|
||||
return self.while_dynamic(
|
||||
while_before_block,
|
||||
while_after_block,
|
||||
used_args,
|
||||
yield_args,
|
||||
yield_arg_names,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Decorator
|
||||
# =============================================================================
|
||||
|
||||
executor = Executor()
|
||||
|
||||
|
||||
def loop_selector(
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
used_args=[],
|
||||
iter_args=[],
|
||||
iter_arg_names=[],
|
||||
unroll=-1,
|
||||
unroll_full=False,
|
||||
constexpr=None,
|
||||
):
|
||||
log().info(
|
||||
"start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] constexpr [%s]",
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
used_args,
|
||||
iter_args,
|
||||
unroll,
|
||||
unroll_full,
|
||||
constexpr,
|
||||
)
|
||||
from .typing import Integer, Numeric
|
||||
|
||||
def _maybe_upcast(value):
|
||||
if isinstance(value, Integer):
|
||||
value = value.ir_value()
|
||||
|
||||
return value
|
||||
|
||||
start = _maybe_upcast(start)
|
||||
stop = _maybe_upcast(stop)
|
||||
step = _maybe_upcast(step)
|
||||
|
||||
def ir_loop(func):
|
||||
return executor.for_execute(
|
||||
func,
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
used_args,
|
||||
iter_args,
|
||||
iter_arg_names,
|
||||
unroll,
|
||||
unroll_full,
|
||||
constexpr,
|
||||
)
|
||||
|
||||
return ir_loop
|
||||
|
||||
|
||||
def if_selector(pred, used_args=[], yield_args=[]):
|
||||
log().info("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
|
||||
# Handle Numeric types here?
|
||||
|
||||
from .typing import Numeric
|
||||
|
||||
if isinstance(pred, Numeric):
|
||||
pred = pred.value
|
||||
|
||||
def ir_loop(func):
|
||||
return func(pred, *used_args, *yield_args)
|
||||
|
||||
return ir_loop
|
||||
|
||||
|
||||
def while_selector(pred, used_args=[], yield_args=[]):
|
||||
def ir_while_loop(func):
|
||||
return func(pred, *used_args, *yield_args)
|
||||
|
||||
return ir_while_loop
|
||||
|
||||
|
||||
def while_executor(
|
||||
pred,
|
||||
while_before_block: Callable,
|
||||
while_after_block: Callable,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
yield_arg_names=[],
|
||||
constexpr=None,
|
||||
):
|
||||
return executor.while_execute(
|
||||
pred,
|
||||
while_before_block,
|
||||
while_after_block,
|
||||
used_args,
|
||||
yield_args,
|
||||
yield_arg_names,
|
||||
constexpr,
|
||||
)
|
||||
|
||||
|
||||
def if_executor(
|
||||
pred,
|
||||
then_block: Callable,
|
||||
else_block: Optional[Callable] = None,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
yield_arg_names=[],
|
||||
constexpr=None,
|
||||
):
|
||||
return executor.if_execute(
|
||||
pred, then_block, else_block, used_args, yield_args, yield_arg_names, constexpr
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Range
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class range_dynamic:
|
||||
@overload
|
||||
def __new__(cls, stop, unroll=0, unroll_full=False):
|
||||
pass
|
||||
|
||||
@overload
|
||||
def __new__(cls, start, stop, step, unroll=0, unroll_full=False):
|
||||
pass
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
raise DSLRuntimeError("range_dynamic should be always preprocessed to IR")
|
||||
|
||||
|
||||
class range_constexpr:
|
||||
def __init__(self, *args):
|
||||
if len(args) == 1:
|
||||
self.start = 0
|
||||
self.stop = args[0]
|
||||
self.step = 1
|
||||
elif len(args) == 2:
|
||||
self.start, self.stop = args
|
||||
self.step = 1
|
||||
elif len(args) == 3:
|
||||
self.start, self.stop, self.step = args
|
||||
else:
|
||||
raise DSLRuntimeError(
|
||||
"range_constexpr supports up to 3 arguments (start, stop, step)"
|
||||
)
|
||||
# Ensure the arguments are compile-time constants (if required)
|
||||
for arg_name, arg_value in [
|
||||
("step", self.step),
|
||||
("start", self.start),
|
||||
("stop", self.stop),
|
||||
]:
|
||||
if executor._is_dynamic_expression(arg_value):
|
||||
raise DSLRuntimeError(
|
||||
f"`range_constexpr` requires `constexpr` (non-IR Values) for all arguments, "
|
||||
f"but `{arg_name}` is not. If the arguments are dynamic, use `range`; the DSL "
|
||||
f"will handle them during runtime. ",
|
||||
suggestion="Use `range` instead of `range_constexpr`.",
|
||||
)
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
current = self.start
|
||||
while current < self.stop:
|
||||
yield current
|
||||
current += self.step
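# Illustrative sketch (assumes the DSL has registered `_is_dynamic_expression`
# on the module-level executor): range_constexpr behaves like Python's range but
# rejects dynamic bounds, while range_dynamic is only a marker that the
# preprocessor rewrites into IR.
#
#     total = 0
#     for i in range_constexpr(0, 8, 2):   # interpreted: yields 0, 2, 4, 6
#         total += i                       # total == 12 after the loop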
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# If expressions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def const_expr(expression):
|
||||
if executor._is_dynamic_expression(expression):
|
||||
raise DSLRuntimeError(
|
||||
f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
|
||||
context={
|
||||
"const_expr": "Accepts only constexpr (compile-time constant)",
|
||||
"If your expression depends on dynamic values": "Avoid marking it as `const_expr()`",
|
||||
"If the expression could be either dynamic or constexpr": "Omit explicit `const_expr()` marker; the DSL will infer the correct handling automatically",
|
||||
},
|
||||
)
|
||||
return expression
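# Illustrative sketch: const_expr passes compile-time constants through
# unchanged and raises DSLRuntimeError when handed a dynamic (IR Value)
# expression, so it serves as an explicit "must be constexpr" marker.
#
#     tile = const_expr(8 * 2)        # plain Python int: returns 16
#     # const_expr(dynamic_ir_value)  # would raise DSLRuntimeError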
|
||||
|
||||
|
||||
def dynamic_expr(expression):
|
||||
raise DSLRuntimeError("dynamic_expr should be always preprocessed to IR")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Assertion & casting
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def assert_executor(test, msg=None):
|
||||
from .typing import Numeric
|
||||
|
||||
fail = False
|
||||
# Implicit convert dynamic expression to bool is not allowed
|
||||
# So here explicitly do a None check
|
||||
if test is not None and executor._is_dynamic_expression(test):
|
||||
if isinstance(test, Numeric):
|
||||
try:
|
||||
test = test.to(bool)
|
||||
except:
|
||||
fail = True
|
||||
else:
|
||||
fail = True
|
||||
|
||||
if not fail:
|
||||
assert test, msg
|
||||
else:
|
||||
raise DSLRuntimeError(
|
||||
"Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
|
||||
suggestion = "Please replace with runtime assert."
|
||||
)
|
||||
|
||||
|
||||
def bool_cast(value):
|
||||
if executor._is_dynamic_expression(value):
|
||||
raise DSLRuntimeError(
|
||||
"Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
|
||||
suggestion = "Please explicitly convert to boolean with expressions like comparision."
|
||||
)
|
||||
return bool(value)
|
||||
1459
python/CuTeDSL/base_dsl/ast_preprocessor.py
Normal file
File diff suppressed because it is too large
154
python/CuTeDSL/base_dsl/cache_helpers.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides jit cache load/dump helper functions
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import random
|
||||
import tempfile
|
||||
import pwd
|
||||
import time
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
|
||||
from .utils.logger import log
|
||||
from .jit_executor import JitExecutor
|
||||
|
||||
from .._mlir import ir
|
||||
|
||||
# =============================================================================
|
||||
# Jit Cache Helper functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_current_user():
|
||||
# Try to get the user from the environment variable first
|
||||
user = os.getenv("USER") or os.getenv("USERNAME")
|
||||
if not user:
|
||||
# Fallback for Unix-like systems
|
||||
user = pwd.getpwuid(os.getuid()).pw_name
|
||||
return user
|
||||
|
||||
|
||||
try:
|
||||
default_generated_ir_path = f"/tmp/{get_current_user()}/cutlass_python_cache/"
|
||||
except Exception as e:
|
||||
# If all else fails, provide a default fallback path
|
||||
default_generated_ir_path = "/tmp/cutlass_python_cache/"
|
||||
print(f"Could not determine user, using default path. Error: {e}")
|
||||
|
||||
|
||||
def load_ir(file, asBytecode=False):
|
||||
"""Load generated IR from a file."""
|
||||
assert "mlir" in file
|
||||
func_name = file.split(".mlir")[0].split("dsl_")[-1]
|
||||
with ir.Context() as ctx:
|
||||
with open(file, "rb" if asBytecode else "r") as f:
|
||||
module = ir.Module.parse(f.read())
|
||||
|
||||
return func_name, module
|
||||
|
||||
|
||||
def make_unique_filename(fpath: Path, new_ext: str = None) -> Path:
|
||||
"""Generate a unique filename with an optional new extension."""
|
||||
random_part = random.randint(0, 999999)
|
||||
timestamp = time.time()
|
||||
hash_input = f"{fpath}_{timestamp}_{random_part}".encode()
|
||||
hash_code = hashlib.md5(hash_input).hexdigest()[:16] # Shorter hash for readability
|
||||
stem_with_hash = f"{fpath.stem}_{hash_code}"
|
||||
return fpath.with_name(stem_with_hash).with_suffix(new_ext or fpath.suffix)
|
||||
|
||||
|
||||
def save_ir(
|
||||
dsl_name: str,
|
||||
module: object,
|
||||
fname: str,
|
||||
isTemp: bool = False,
|
||||
asBytecode: bool = False,
|
||||
) -> str:
|
||||
"""Save generated IR to a file."""
|
||||
initial_name = f"{dsl_name.lower()}_{fname}.mlir"
|
||||
save_path = Path(tempfile.gettempdir() if isTemp else os.getcwd())
|
||||
save_fname = save_path / initial_name
|
||||
# Random ID to avoid any collisions
|
||||
rnd_id = str(uuid.uuid4())
|
||||
pid = os.getpid()
|
||||
# use temp dir to be robust against program interruptions
|
||||
temp_dir = os.path.join(save_path, f"tmp.pid_{pid}_{rnd_id}")
|
||||
# If the process exits abnormally, it may leave a temporary folder behind that must be removed manually.
|
||||
os.makedirs(temp_dir, exist_ok=False)
|
||||
temp_fname = os.path.join(temp_dir, initial_name)
|
||||
|
||||
if asBytecode:
|
||||
with open(temp_fname, "wb") as f:
|
||||
module.operation.write_bytecode(f)
|
||||
else:
|
||||
with open(temp_fname, "w") as f:
|
||||
print(module, file=f)
|
||||
# os.replace is guaranteed to be atomic on POSIX systems if it succeeds
|
||||
# so filepath cannot see a partial write
|
||||
os.replace(temp_fname, save_fname)
|
||||
os.removedirs(temp_dir)
|
||||
log().debug("Generated IR saved into %s", save_fname)
|
||||
return save_fname
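# Illustrative call (hypothetical arguments): the module is first written into a
# per-process temporary directory and then published with os.replace, so readers
# of the cache directory never observe a partially written .mlir file.
#
#     path = save_ir("cutedsl", module, "my_kernel", isTemp=True, asBytecode=True)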
|
||||
|
||||
|
||||
def check_func_name(jit_cache, func_name):
|
||||
if func_name not in jit_cache:
|
||||
jit_cache[func_name] = JitExecutor(None, None, None, None, None, None)
|
||||
return jit_cache
|
||||
|
||||
|
||||
def load_cache_from_path(dsl_name, cache_limit, path=default_generated_ir_path):
|
||||
"""Load cache from a directory path."""
|
||||
if not os.path.exists(path):
|
||||
return dict()
|
||||
files = os.listdir(path)
|
||||
jit_cache = dict()
|
||||
try:
|
||||
for idx, file in enumerate(files):
|
||||
if idx >= int(cache_limit):
|
||||
break
|
||||
# identify dsl prefix
|
||||
if not file.startswith(f"{dsl_name.lower()}"):
|
||||
continue
|
||||
if ".mlir" in file:
|
||||
func_name, ir_module = load_ir(
|
||||
os.path.join(path, file), asBytecode=True
|
||||
)
|
||||
jit_cache = check_func_name(jit_cache, func_name)
|
||||
jit_cache[func_name].ir_module = ir_module
|
||||
except Exception as e:
|
||||
print(f"{dsl_name} failed with loading generated IR cache.", e)
|
||||
jit_cache = dict()
|
||||
return jit_cache
|
||||
|
||||
|
||||
def dump_cache_to_path(
|
||||
dsl_name, jit_cache, cache_limit, path=default_generated_ir_path
|
||||
):
|
||||
log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache))
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
original_path = os.getcwd()
|
||||
try:
|
||||
os.chdir(path)
|
||||
for idx, [key, value] in enumerate(jit_cache.items()):
|
||||
if idx >= int(cache_limit):
|
||||
break
|
||||
save_ir(dsl_name, value.ir_module, key, asBytecode=True)
|
||||
except Exception as e:
|
||||
print(f"{dsl_name} failed with caching generated IR", e)
|
||||
finally:
|
||||
os.chdir(original_path)
|
||||
268
python/CuTeDSL/base_dsl/common.py
Normal file
@@ -0,0 +1,268 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Iterable, Optional, Union
|
||||
|
||||
"""
|
||||
This module provides the DSL exception classes shared by all dialects.
|
||||
"""
|
||||
|
||||
|
||||
# Add color codes at the top of the file after imports
|
||||
class Colors:
|
||||
"""ANSI color codes for error messages"""
|
||||
|
||||
RED = "\033[91m"
|
||||
YELLOW = "\033[93m"
|
||||
BLUE = "\033[94m"
|
||||
GREEN = "\033[92m"
|
||||
BOLD = "\033[1m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DSL Exceptions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class DSLBaseError(Exception):
|
||||
"""
|
||||
Base exception for DSL-related errors.
|
||||
Provides optional contextual metadata to aid in debugging.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
line: Optional[int] = None,
|
||||
snippet: Optional[str] = None,
|
||||
filename: Optional[str] = None,
|
||||
error_code: Optional[Union[str, int]] = None,
|
||||
context: Optional[Union[Dict[str, Any], str]] = None,
|
||||
suggestion: Optional[str] = None,
|
||||
cause: Optional[BaseException] = None,
|
||||
) -> None:
|
||||
self.message = message
|
||||
self.line = line
|
||||
self.filename = filename
|
||||
self.snippet = snippet
|
||||
self.error_code = error_code
|
||||
self.context = context
|
||||
self.suggestion = suggestion
|
||||
self.cause = cause
|
||||
|
||||
super().__init__(self._format_message())
|
||||
|
||||
def _format_message(self):
|
||||
"""
|
||||
Formats the complete error message with available metadata.
|
||||
Override this in subclasses if you want to change formatting logic.
|
||||
"""
|
||||
parts = [f"{self.__class__.__name__}: {self.message}"]
|
||||
|
||||
if self.error_code is not None:
|
||||
parts.append(f"{Colors.BOLD}Error Code:{Colors.RESET} {self.error_code}\n")
|
||||
|
||||
if self.line is not None:
|
||||
parts.append(f" Line: {self.line}")
|
||||
|
||||
if self.filename is not None:
|
||||
parts.append(f" File: {self.filename}")
|
||||
|
||||
if self.snippet:
|
||||
# Optionally truncate long snippets for readability
|
||||
parts.append(f" Snippet: \n {self.snippet}")
|
||||
|
||||
if self.cause:
|
||||
parts.append(f" Caused exception: {self.cause}")
|
||||
|
||||
if self.context:
|
||||
if isinstance(self.context, dict):
|
||||
parts.append(f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET}\n")
|
||||
for key, value in self.context.items():
|
||||
parts.append(f" {key}: {value}")
|
||||
else:
|
||||
parts.append(
|
||||
f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET} {self.context}"
|
||||
)
|
||||
|
||||
if self.suggestion:
|
||||
parts.append(f"{Colors.GREEN}💡 Suggestions:{Colors.RESET}")
|
||||
if isinstance(self.suggestion, (list, tuple)):
|
||||
for suggestion in self.suggestion:
|
||||
parts.append(f" {Colors.GREEN}{suggestion}{Colors.RESET}")
|
||||
else:
|
||||
parts.append(f" {self.suggestion}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
class DSLRuntimeError(DSLBaseError):
|
||||
"""
|
||||
Raised when an error occurs during JIT-time code generation in the DSL.
|
||||
"""
|
||||
|
||||
# Inherits all logic from DSLBaseError; override methods if you need
|
||||
# specialized behavior or formatting for runtime errors.
|
||||
pass
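# Illustrative sketch: the optional metadata accepted by DSLBaseError is folded
# into the formatted message, e.g.
#
#     raise DSLRuntimeError(
#         "Loop bounds must be constexpr (compile-time constants)",
#         context={"start": "IR Value", "stop": "IR Value"},
#         suggestion="Use `range` for dynamic bounds.",
#     )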
|
||||
|
||||
|
||||
def _get_friendly_cuda_error_message(error_code, error_name):
|
||||
# Avoid circular dependency
|
||||
from .runtime.cuda import get_device_info
|
||||
|
||||
"""Get a user-friendly error message for common CUDA errors."""
|
||||
# Strip the byte string markers if present
|
||||
if isinstance(error_name, bytes):
|
||||
error_name = error_name.decode("utf-8")
|
||||
elif (
|
||||
isinstance(error_name, str)
|
||||
and error_name.startswith("b'")
|
||||
and error_name.endswith("'")
|
||||
):
|
||||
error_name = error_name[2:-1]
|
||||
|
||||
# Add target architecture info
|
||||
target_arch = os.getenv("CUTE_DSL_ARCH", "unknown")
|
||||
|
||||
error_messages = {
|
||||
"CUDA_ERROR_INVALID_SOURCE": (
|
||||
f"{Colors.RED}❌ Failed to load CUDA kernel - likely architecture mismatch.{Colors.RESET}\n\n"
|
||||
),
|
||||
"CUDA_ERROR_NO_BINARY_FOR_GPU": (
|
||||
f"{Colors.RED}❌ CUDA kernel not compatible with your GPU.{Colors.RESET}\n\n"
|
||||
),
|
||||
"CUDA_ERROR_OUT_OF_MEMORY": (
|
||||
f"{Colors.RED}💾 CUDA out of memory error.{Colors.RESET}\n\n"
|
||||
),
|
||||
"CUDA_ERROR_INVALID_DEVICE": (
|
||||
f"{Colors.RED}❌ Invalid CUDA device.{Colors.RESET}\n\n"
|
||||
),
|
||||
"CUDA_ERROR_NOT_INITIALIZED": (
|
||||
f"{Colors.RED}❌ CUDA context not initialized.{Colors.RESET}\n\n"
|
||||
),
|
||||
"CUDA_ERROR_INVALID_VALUE": (
|
||||
f"{Colors.RED}⚠️ Invalid parameter passed to CUDA operation.{Colors.RESET}\n\n"
|
||||
f"{Colors.YELLOW}This is likely a bug - please report it with:{Colors.RESET}"
|
||||
),
|
||||
}
|
||||
|
||||
error_suggestions = {
|
||||
"CUDA_ERROR_INVALID_SOURCE": (
|
||||
f"1. Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
f"2. Clear the compilation cache and regenerate the kernel",
|
||||
f"3. Check CUDA toolkit installation",
|
||||
),
|
||||
"CUDA_ERROR_NO_BINARY_FOR_GPU": (
|
||||
f"Set env CUTE_DSL_ARCH to match your GPU architecture",
|
||||
),
|
||||
"CUDA_ERROR_OUT_OF_MEMORY": (
|
||||
f"1. Reduce batch size",
|
||||
f"2. Reduce model size",
|
||||
f"3. Free unused GPU memory",
|
||||
),
|
||||
"CUDA_ERROR_INVALID_DEVICE": (
|
||||
f"1. Check if CUDA device is properly initialized",
|
||||
f"2. Verify GPU is detected: nvidia-smi",
|
||||
f"3. Check CUDA_VISIBLE_DEVICES environment variable",
|
||||
),
|
||||
"CUDA_ERROR_NOT_INITIALIZED": (
|
||||
f"1. Check CUDA driver installation",
|
||||
f"2. call `cuda.cuInit(0)` before any other CUDA operation",
|
||||
f"3. Run nvidia-smi to confirm GPU status",
|
||||
),
|
||||
"CUDA_ERROR_INVALID_VALUE": (
|
||||
f"1. Your GPU model",
|
||||
f"2. SM ARCH setting",
|
||||
f"3. Steps to reproduce",
|
||||
),
|
||||
}
|
||||
|
||||
message = error_messages.get(
|
||||
error_name, f"{Colors.RED}Unknown CUDA error{Colors.RESET}"
|
||||
)
|
||||
|
||||
# Add debug information
|
||||
debug_info = f"\n- {Colors.BOLD}Error name: {error_name}\n"
|
||||
debug_info += f"- CUDA_TOOLKIT_PATH: {os.getenv('CUDA_TOOLKIT_PATH', 'not set')}\n"
|
||||
debug_info += (
|
||||
f"- Target SM ARCH: {os.getenv('CUTE_DSL_ARCH', 'not set')}{Colors.RESET}\n"
|
||||
)
|
||||
|
||||
try:
|
||||
# Get GPU information using CUDA Python API
|
||||
debug_info += f"\n{Colors.BLUE}📊 GPU Information:{Colors.RESET}\n"
|
||||
gpu_info = get_device_info()
|
||||
debug_info += gpu_info.pretty_str()
|
||||
|
||||
if target_arch and gpu_info.compatible_archs:
|
||||
debug_info += f"\n{Colors.BOLD}Compatibility Check:{Colors.RESET}\n"
|
||||
|
||||
if target_arch not in gpu_info.compatible_archs:
|
||||
debug_info += (
|
||||
f"{Colors.RED}❌ Error: Target SM ARCH {target_arch} is not compatible\n"
|
||||
f"💡 Please use one of SM ARCHs: "
|
||||
f"{Colors.GREEN}{', '.join(gpu_info.compatible_archs or [])}{Colors.RESET}\n"
|
||||
)
|
||||
elif target_arch != gpu_info.sm_arch:
|
||||
debug_info += (
|
||||
f"{Colors.YELLOW}⚠️ Warning: Using compatible but non-optimal architecture\n"
|
||||
f"• Current: {target_arch}\n"
|
||||
f"• Recommended: {Colors.GREEN}{gpu_info.sm_arch}{Colors.RESET} (native)\n"
|
||||
)
|
||||
else:
|
||||
debug_info += f"{Colors.GREEN}✓ Using optimal architecture: {gpu_info.sm_arch}{Colors.RESET}\n"
|
||||
|
||||
except Exception as e:
|
||||
debug_info += (
|
||||
f"\n{Colors.YELLOW}ℹ️ Could not retrieve GPU info: {str(e)}{Colors.RESET}"
|
||||
)
|
||||
|
||||
return message, debug_info, error_suggestions.get(error_name, "")
|
||||
|
||||
|
||||
class DSLCudaRuntimeError(DSLBaseError):
|
||||
"""
|
||||
Raised when an error occurs during CUDA runtime code generation in the DSL.
|
||||
"""
|
||||
|
||||
# Inherits all logic from DSLBaseError; override methods if you need
|
||||
# specialized behavior or formatting for runtime errors.
|
||||
def __init__(self, error_code, error_name) -> None:
|
||||
self._error_code = error_code
|
||||
self._error_name = error_name
|
||||
message, debug_info, suggestion = _get_friendly_cuda_error_message(
|
||||
error_code, error_name
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
message, error_code=error_code, context=debug_info, suggestion=suggestion
|
||||
)
|
||||
|
||||
|
||||
class DSLAstPreprocessorError(DSLBaseError):
|
||||
"""
|
||||
Raised when an error occurs during AST preprocessing or visiting in the DSL.
|
||||
"""
|
||||
|
||||
# Same approach: You could override _format_message if you want
|
||||
# to emphasize AST node details or anything specific to preprocessing.
|
||||
pass
|
||||
|
||||
|
||||
class DSLNotImplemented(DSLBaseError):
|
||||
"""
|
||||
Raised when a feature of the DSL is not implemented yet.
|
||||
"""
|
||||
|
||||
# Useful for stubs in your DSL that you plan to implement in the future.
|
||||
pass
|
||||
221
python/CuTeDSL/base_dsl/compiler.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides a class that compiles generated IR using MLIR's PassManager
|
||||
and executes it using MLIR's ExecutionEngine.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Optional, Tuple
|
||||
import os
|
||||
import sys
|
||||
import inspect
|
||||
from .common import DSLRuntimeError
|
||||
|
||||
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(_SCRIPT_PATH)
|
||||
|
||||
from .._mlir import ir
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Compiler Class
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CompilationError(RuntimeError):
|
||||
"""Custom error class for compilation failures"""
|
||||
|
||||
# Add ANSI color codes
|
||||
RED = "\033[91m"
|
||||
YELLOW = "\033[93m"
|
||||
BLUE = "\033[94m"
|
||||
GREEN = "\033[92m"
|
||||
BOLD = "\033[1m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
nvvm_error: Optional[str] = None,
|
||||
ir_context: Optional[str] = None,
|
||||
cuda_toolkit: Optional[str] = None,
|
||||
arch: Optional[str] = None,
|
||||
):
|
||||
self.message = message
self.nvvm_error = nvvm_error
|
||||
self.ir_context = ir_context
|
||||
self.cuda_toolkit = cuda_toolkit
|
||||
self.arch = arch
|
||||
# Call parent with formatted error to avoid showing class name
|
||||
super().__init__("") # Empty string to avoid class name
|
||||
# Store formatted error for str() representation
|
||||
self._formatted_error = self._format_error()
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Override string representation to avoid showing class name"""
|
||||
return self._formatted_error
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Override repr representation to avoid showing class name"""
|
||||
return self._formatted_error
|
||||
|
||||
def _format_error(self) -> str:
|
||||
if not self.nvvm_error:
|
||||
return self.message
|
||||
|
||||
return f"""NVVM Compilation Error:
|
||||
----------------------
|
||||
|
||||
{self.BLUE}⚙️ Current Settings:{self.RESET}
|
||||
{self.BOLD}- CUDA Toolkit Path: {self.cuda_toolkit or "Not Set"}
|
||||
- Target Architecture: {self.arch}{self.RESET}
|
||||
|
||||
IR Context (truncated):
|
||||
{self.ir_context}
|
||||
|
||||
{self.YELLOW}💡 Possible Solutions:{self.RESET}
|
||||
{self.GREEN}1. Check if CUDA_TOOLKIT_PATH is set correctly
|
||||
2. Verify target architecture ({self.arch}) is supported by your CUDA toolkit
|
||||
3. Make sure CUDA toolkit version matches the target architecture requirements{self.RESET}"""
|
||||
|
||||
|
||||
class Compiler:
|
||||
"""Compiler class for compiling and building MLIR modules."""
|
||||
|
||||
def __init__(self, passmanager, execution_engine):
|
||||
self.passmanager = passmanager
|
||||
self.execution_engine = execution_engine
|
||||
|
||||
def __call__(self, module):
|
||||
"""Convenience application method."""
|
||||
self.compile(module)
|
||||
|
||||
def _process_error(self, error_msg: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Process error message to extract NVVM error and IR context"""
|
||||
nvvm_error = None
|
||||
ir_msg = ""
|
||||
|
||||
if "NVVM_ERROR" in error_msg:
|
||||
# Extract the specific NVVM error
|
||||
nvvm_error = (
|
||||
error_msg.split("libNVVM extra log:")[1].strip()
|
||||
if "libNVVM extra log:" in error_msg
|
||||
else error_msg
|
||||
)
|
||||
|
||||
# Extract IR context
|
||||
if "see current operation:" in error_msg:
|
||||
# Get the IR section
|
||||
ir_section = error_msg.split("see current operation:")[1].strip()
|
||||
# Remove duplicate IR section
|
||||
ir_section = ir_section.split("error: unknown: Failed translating")[
|
||||
0
|
||||
].strip()
|
||||
|
||||
# Get first few lines and last few lines of the IR
|
||||
ir_lines = ir_section.split("\n")
|
||||
if len(ir_lines) > 10:
|
||||
ir_msg = "\n".join(ir_lines[:5] + [" ..."] + ir_lines[-5:])
|
||||
else:
|
||||
ir_msg = ir_section
|
||||
|
||||
return nvvm_error, ir_msg
|
||||
|
||||
def compile(
|
||||
self,
|
||||
module,
|
||||
pipeline: str,
|
||||
cuda_toolkit: str = "",
|
||||
arch: str = "",
|
||||
enable_verifier=False,
|
||||
):
|
||||
"""Compiles the module by invoking the pipeline."""
|
||||
try:
|
||||
pm = self.passmanager.PassManager.parse(pipeline)
|
||||
pm.enable_verifier(enable_verifier)
|
||||
pm.run(module.operation)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
nvvm_error, ir_msg = self._process_error(error_msg)
|
||||
|
||||
if nvvm_error:
|
||||
raise CompilationError(
|
||||
error_msg,
|
||||
nvvm_error=nvvm_error,
|
||||
ir_context=ir_msg,
|
||||
cuda_toolkit=cuda_toolkit,
|
||||
arch=arch,
|
||||
) from e
|
||||
raise e
|
||||
|
||||
def jit(self, module, opt_level: int = 2, shared_libs: Sequence[str] = ()):
|
||||
"""Wraps the module in a JIT execution engine."""
|
||||
return self.execution_engine.ExecutionEngine(
|
||||
module, opt_level=opt_level, shared_libs=shared_libs
|
||||
)
|
||||
|
||||
def compile_and_jit(
|
||||
self,
|
||||
module,
|
||||
pipeline: str,
|
||||
shared_libs: Sequence[str] = (),
|
||||
opt_level: int = 2,
|
||||
cuda_toolkit: str = "",
|
||||
arch: str = "",
|
||||
):
|
||||
"""Compiles and jits the module."""
|
||||
self.compile(
|
||||
module,
|
||||
pipeline,
|
||||
cuda_toolkit,
|
||||
arch,
|
||||
)
|
||||
return self.jit(module, opt_level, shared_libs)
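# Illustrative sketch (the pipeline string below is a placeholder, not the
# pipeline used by the DSL): the same Compiler instance can lower a module and
# wrap it in an MLIR execution engine in one call.
#
#     compiler = Compiler(passmanager, execution_engine)   # MLIR binding modules
#     engine = compiler.compile_and_jit(
#         module, pipeline="builtin.module(canonicalize)", opt_level=2
#     )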
|
||||
|
||||
|
||||
def compile(func, *args, **kwargs):
|
||||
if func is None:
|
||||
raise DSLRuntimeError("Function is not set or invalid.")
|
||||
|
||||
if not callable(func):
|
||||
raise DSLRuntimeError("Object is not callable.")
|
||||
|
||||
kwargs["compile_only"] = True
|
||||
kwargs["no_cache"] = True
|
||||
|
||||
if inspect.isfunction(func):
|
||||
# regular function
|
||||
pass
|
||||
elif inspect.ismethod(func):
|
||||
# if it's a method, add the instance to the first argument
|
||||
args = [func.__self__] + list(args)
|
||||
func = func.__func__
|
||||
elif inspect.isclass(type(func)) and hasattr(func, "__call__"):
|
||||
# If it's a class instance, get the class's __call__ method
|
||||
args = [func] + list(args)
|
||||
# Get the actual function from the class definition
|
||||
func = func.__call__.__func__
|
||||
else:
|
||||
raise DSLRuntimeError(
|
||||
"Invalid function type, only function, method and module are supported, but got",
|
||||
func,
|
||||
)
|
||||
|
||||
# If it's a wrapped function created by jit decorator, get the original function
|
||||
if hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
|
||||
if not hasattr(func, "_dsl_object"):
|
||||
raise DSLRuntimeError("Function is not decorated with jit decorator.")
|
||||
|
||||
fcn_ptr = func._dsl_object._preprocess_and_execute(func)
|
||||
return func._dsl_object._func(fcn_ptr, *args, **kwargs)
|
||||
1637
python/CuTeDSL/base_dsl/dsl.py
Normal file
File diff suppressed because it is too large
303
python/CuTeDSL/base_dsl/env_manager.py
Normal file
@@ -0,0 +1,303 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides utilities for the environment variables setup.
|
||||
|
||||
It provides an EnvironmentVarManager, which reads environment variables for the DSL
|
||||
and caches them for efficient access.
|
||||
|
||||
It also provides utilities to automatically setup a subset of environment variables
|
||||
based on heuristics.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from ..base_dsl.runtime.cuda import get_compute_capability_major_minor
|
||||
from .utils.logger import log
|
||||
|
||||
IS_WINDOWS = sys.platform == "win32"
|
||||
CLIB_EXT = ".dll" if IS_WINDOWS else ".so"
|
||||
|
||||
# =============================================================================
|
||||
# Environment Variable Helpers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_str_env_var(var_name, default_value=None):
|
||||
value = os.getenv(var_name)
|
||||
return value if value is not None else default_value
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_bool_env_var(var_name, default_value=False):
|
||||
value = get_str_env_var(var_name)
|
||||
if value is None:
|
||||
return default_value
|
||||
return value not in {"False", "0", ""}
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_int_env_var(var_name, default_value=0):
|
||||
value = get_str_env_var(var_name)
|
||||
return int(value) if value and value.isdigit() else default_value
|
||||
|
||||
|
||||
def detect_gpu_arch(prefix):
|
||||
"""
|
||||
Attempts to detect the machine's GPU architecture.
|
||||
|
||||
Returns:
|
||||
A string representing the GPU architecture (e.g. "sm_70" for compute capability 7.0),
|
||||
or a default value (currently "sm_100a") if the GPU architecture cannot be determined.
|
||||
"""
|
||||
arch = (None, None)
|
||||
try:
|
||||
arch = get_compute_capability_major_minor()
|
||||
except Exception as e:
|
||||
log().info(f"Failed to get CUDA compute capability: {e}")
|
||||
|
||||
if arch == (None, None):
|
||||
# default to compute capability 10.0 (mapped to "sm_100a" below)
|
||||
arch = (10, 0)
|
||||
|
||||
major, minor = arch
|
||||
suffix = ""
|
||||
if major >= 9 and minor >= 0:
|
||||
suffix = "a"
|
||||
elif minor != 0:
|
||||
# e.g. sm_86 belongs to the sm_80 family
|
||||
minor = 0
|
||||
return f"sm_{major}{minor}{suffix}"
|
||||
|
||||
|
||||
def find_libs_in_ancestors(start, target_libs, lib_folder_guesses):
|
||||
"""
|
||||
Search ancestor directories for a candidate library folder containing all required libraries.
|
||||
|
||||
Starting from the given path, this function traverses up through each parent directory.
|
||||
For every ancestor, it checks candidate subdirectories (specified by lib_folder_guesses)
|
||||
for files that match the required library extension (CLIB_EXT). Library file names are
|
||||
canonicalized by removing the "lib" prefix from their stem. If a candidate directory contains
|
||||
all of the required libraries (as specified in target_libs), the function returns a list of
|
||||
absolute paths to these library files.
|
||||
|
||||
Parameters:
|
||||
start (str or Path): The starting directory from which to begin the search.
|
||||
target_libs (iterable of str): A collection of required library names (without the "lib" prefix).
|
||||
lib_folder_guesses (iterable of str): Relative paths from an ancestor directory that may contain the libraries.
|
||||
|
||||
Returns:
|
||||
list[str] or None: A list of resolved paths to the required library files if found; otherwise, None.
|
||||
"""
|
||||
# Traverse through all parent directories of the resolved starting path.
|
||||
for ancestor in Path(start).resolve().parents:
|
||||
# Iterate over each candidate relative directory path.
|
||||
for rel_path in lib_folder_guesses:
|
||||
target_dir = ancestor / rel_path
|
||||
# Skip if the candidate directory does not exist.
|
||||
if not target_dir.is_dir():
|
||||
continue
|
||||
|
||||
# Initialize a list to hold the resolved paths of matching library files.
|
||||
libs_cand = []
|
||||
# Create a set of the remaining libraries we need to find.
|
||||
remaining_libs = set(target_libs)
|
||||
|
||||
# Iterate over all items in the candidate directory.
|
||||
for p in target_dir.iterdir():
|
||||
# Consider only files with the expected library extension.
|
||||
if p.suffix == CLIB_EXT:
|
||||
# Canonicalize the library name by removing the "lib" prefix.
|
||||
lib_name = p.stem.removeprefix("lib")
|
||||
# If this library is required, add its resolved path and mark it as found.
|
||||
if lib_name in remaining_libs:
|
||||
libs_cand.append(str(p.resolve()))
|
||||
remaining_libs.remove(lib_name)
|
||||
|
||||
# If all required libraries have been found, return the list of library paths.
|
||||
if len(remaining_libs) == 0:
|
||||
return libs_cand
|
||||
|
||||
# Return None if no candidate directory contains all required libraries.
|
||||
return None
|
||||
|
||||
|
||||
def _find_cuda_home():
|
||||
"""Find the CUDA installation path using a series of heuristic methods.
|
||||
Methods below are checked in order, and the function returns on first match:
|
||||
1. Checking the environment variables CUDA_HOME and CUDA_PATH.
|
||||
2. Searching for the 'nvcc' compiler in the system PATH and deriving the CUDA path from it.
|
||||
3. Scanning common installation directories based on the operating system.
|
||||
- On Windows systems (when IS_WINDOWS is True), it searches in:
|
||||
C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*
|
||||
- On Unix-like systems, it searches in:
|
||||
/usr/local/cuda*
|
||||
|
||||
Returns:
|
||||
Optional[str]: The absolute CUDA installation path if found; otherwise, None.
|
||||
|
||||
Note:
|
||||
The variable IS_WINDOWS is defined in the module scope.
|
||||
"""
|
||||
# Guess #1
|
||||
cuda_home = get_str_env_var("CUDA_HOME") or get_str_env_var("CUDA_PATH")
|
||||
if cuda_home is None:
|
||||
# Guess #2
|
||||
nvcc_path = shutil.which("nvcc")
|
||||
if nvcc_path is not None:
|
||||
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
|
||||
else:
|
||||
# Guess #3
|
||||
if IS_WINDOWS:
|
||||
glob_pat = "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*"
|
||||
else:
|
||||
glob_pat = "/usr/local/cuda*"
|
||||
cuda_homes = glob.glob(glob_pat)
|
||||
if len(cuda_homes) == 0:
|
||||
cuda_home = ""
|
||||
else:
|
||||
cuda_home = cuda_homes[0]
|
||||
if not os.path.exists(cuda_home):
|
||||
cuda_home = None
|
||||
return cuda_home
|
||||
|
||||
|
||||
def get_cuda_toolkit_path():
|
||||
"""
|
||||
Get cuda_toolkit_path. It returns get_str_env_var('CUDA_TOOLKIT_PATH') if
|
||||
set. Otherwise, attempts to discover a valid CUDA toolkit location and
|
||||
returns it. If not found, returns None.
|
||||
"""
|
||||
# Check if the environment variable is already set, if so, return it immediately.
|
||||
try:
|
||||
cuda_toolkit_path_existing = get_str_env_var("CUDA_TOOLKIT_PATH")
|
||||
if cuda_toolkit_path_existing:
|
||||
return cuda_toolkit_path_existing
|
||||
|
||||
found_cuda_home = _find_cuda_home()
|
||||
if found_cuda_home:
|
||||
return found_cuda_home
|
||||
except Exception as e:
|
||||
log().info("default_env: exception on get_cuda_toolkit_path", e)
|
||||
return None
|
||||
|
||||
|
||||
def get_prefix_dsl_libs(prefix: str):
|
||||
"""
|
||||
Returns get_str_env_var('{prefix}_LIBS') if set.
|
||||
Otherwise, attempts to discover the libraries heuristically and returns them.
|
||||
If not found, return None.
|
||||
"""
|
||||
# Check if the environment variable is already set, if so, return it immediately.
|
||||
try:
|
||||
prefix_libs_existing = get_str_env_var(f"{prefix}_LIBS")
|
||||
if prefix_libs_existing:
|
||||
return prefix_libs_existing
|
||||
|
||||
def get_libs_cand(start):
|
||||
target_libs = {
|
||||
"mlir_c_runner_utils",
|
||||
"mlir_runner_utils",
|
||||
"mlir_cuda_runtime",
|
||||
}
|
||||
lib_folder_guesses = [
|
||||
"lib",
|
||||
]
|
||||
|
||||
libs_cand = find_libs_in_ancestors(start, target_libs, lib_folder_guesses)
|
||||
if libs_cand:
|
||||
dsl_libs = ":".join(libs_cand)
|
||||
return dsl_libs
|
||||
|
||||
return None
|
||||
|
||||
# find from install folder
|
||||
dsl_libs = get_libs_cand(__file__)
|
||||
|
||||
if not dsl_libs:
|
||||
# try to find from build folder structure
|
||||
dsl_libs = get_libs_cand(Path(__file__).parent.parent.resolve())
|
||||
|
||||
return dsl_libs
|
||||
|
||||
except Exception as e:
|
||||
log().info(f"default_env: exception on get_prefix_dsl_libs", e)
|
||||
return None
|
||||
|
||||
|
||||
class EnvironmentVarManager:
|
||||
"""Manages environment variables for configuration options.
|
||||
|
||||
Printing options:
|
||||
- [DSL_NAME]_LOG_TO_CONSOLE: Print logging to stderr (default: False)
|
||||
- [DSL_NAME]_PRINT_AFTER_PREPROCESSOR: Print after preprocess (default: False)
|
||||
- [DSL_NAME]_PRINT_IR: Print generated IR (default: False)
|
||||
- [DSL_NAME]_FILTER_STACKTRACE: Filter internal stacktrace (default: True)
|
||||
File options:
|
||||
- [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False)
|
||||
- [DSL_NAME]_LOG_TO_FILE: Store all logging into a file, excluding COMPILE_LOGS (default: False)
|
||||
Other options:
|
||||
- [DSL_NAME]_LOG_LEVEL: Logging level to set, for LOG_TO_CONSOLE or LOG_TO_FILE (default: 1).
|
||||
- [DSL_NAME]_DRYRUN: Generates IR only (default: False)
|
||||
- [DSL_NAME]_ARCH: GPU architecture (default: "sm_100")
|
||||
- [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False)
|
||||
- [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False)
|
||||
- [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False)
|
||||
- [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False)
|
||||
- [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000)
|
||||
- [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None)
|
||||
- [DSL_NAME]_NO_SOURCE_LOCATION: Generate source location (default: False)
|
||||
"""
|
||||
|
||||
def __init__(self, prefix="DSL"):
|
||||
self.prefix = prefix # change if needed
|
||||
|
||||
# Printing options
|
||||
self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False)
|
||||
self.print_after_preprocessor = get_bool_env_var(
|
||||
f"{prefix}_PRINT_AFTER_PREPROCESSOR", False
|
||||
)
|
||||
self.printIR = get_bool_env_var(f"{prefix}_PRINT_IR", False)
|
||||
self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True)
|
||||
# File options
|
||||
self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False)
|
||||
self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False)
|
||||
# Other options
|
||||
self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1)
|
||||
self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False)
|
||||
self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix))
|
||||
self.warnings_as_errors = get_bool_env_var(
|
||||
f"{prefix}_WARNINGS_AS_ERRORS", False
|
||||
)
|
||||
self.warnings_ignore = get_bool_env_var(f"{prefix}_WARNINGS_IGNORE", False)
|
||||
self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False)
|
||||
self.disable_file_caching = get_bool_env_var(
|
||||
f"{prefix}_DISABLE_FILE_CACHING", False
|
||||
)
|
||||
self.file_caching_capacity = get_int_env_var(
|
||||
f"{prefix}_FILE_CACHING_CAPACITY", 1000
|
||||
)
|
||||
self.generate_source_location = not get_bool_env_var(
|
||||
f"{prefix}_NO_SOURCE_LOCATION", False
|
||||
)
|
||||
# set cuda
|
||||
self.cuda_toolkit = get_cuda_toolkit_path()
|
||||
|
||||
# set mlir shared libraries
|
||||
self.shared_libs = get_prefix_dsl_libs(prefix)
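# Illustrative sketch: with prefix "CUTE_DSL" (the prefix referenced by the
# error messages elsewhere in this package), the manager reads variables such
# as CUTE_DSL_ARCH and CUTE_DSL_PRINT_IR.
#
#     env = EnvironmentVarManager(prefix="CUTE_DSL")
#     if env.printIR:
#         ...  # dump the generated IR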
|
||||
301
python/CuTeDSL/base_dsl/jit_executor.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides jit executor related classes
|
||||
"""
|
||||
import io
|
||||
import inspect
|
||||
import ctypes
|
||||
import numpy as np
|
||||
from typing import get_origin
|
||||
|
||||
# Local modules imports
|
||||
from .utils.timer import timer
|
||||
from .utils.logger import log
|
||||
from .common import DSLRuntimeError
|
||||
from .runtime import cuda as cuda_helpers
|
||||
from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr
|
||||
from .typing import get_c_pointers
|
||||
from . import typing as t
|
||||
|
||||
# MLIR modules imports
|
||||
from .._mlir import ir
|
||||
|
||||
|
||||
class CudaSingleModule:
|
||||
def __init__(self, cuda_module, kernel_ptr):
|
||||
self.cuda_module = cuda_module
|
||||
self.kernel_ptr = kernel_ptr
|
||||
|
||||
|
||||
class CudaModules:
|
||||
def __init__(self, modules, args):
|
||||
# list of CudaSingleModule
|
||||
self.modules = modules
|
||||
# extra kernel ptr arguments for launch
|
||||
self.args = args
|
||||
|
||||
|
||||
class JitExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
dsl,
|
||||
engine,
|
||||
capi_func,
|
||||
ir_module,
|
||||
args_spec,
|
||||
function_name,
|
||||
cuda_modules: CudaModules = None,
|
||||
jit_time_profiling=False,
|
||||
):
|
||||
self.dsl = dsl
|
||||
self.engine = engine
|
||||
self.capi_func = capi_func
|
||||
self.ir_module = ir_module
|
||||
self.args_spec = args_spec
|
||||
self.function_name = function_name
|
||||
if args_spec is not None:
|
||||
self.args_spec = self.filter_runtime_arg_spec(args_spec)
|
||||
# cuda kernels
|
||||
self.cuda_modules = cuda_modules
|
||||
self.jit_time_profiling = jit_time_profiling
|
||||
|
||||
def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec):
|
||||
runtime_args = []
|
||||
runtime_annotations = {}
|
||||
runtime_defaults = []
|
||||
|
||||
# Calculate the offset where defaults start in the original args
|
||||
if arg_spec.defaults:
|
||||
defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults)
|
||||
else:
|
||||
defaults_start_idx = len(arg_spec.args)
|
||||
|
||||
# Filter arguments and maintain their properties
|
||||
for i, arg_name in enumerate(arg_spec.args):
|
||||
arg_type = arg_spec.annotations.get(arg_name, None)
|
||||
|
||||
# Skip compile-time arguments
|
||||
if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
|
||||
continue
|
||||
|
||||
# Keep runtime arguments
|
||||
runtime_args.append(arg_name)
|
||||
if arg_name in arg_spec.annotations:
|
||||
runtime_annotations[arg_name] = arg_type
|
||||
|
||||
# Keep corresponding default if it exists
|
||||
if i >= defaults_start_idx:
|
||||
default_idx = i - defaults_start_idx
|
||||
runtime_defaults.append(arg_spec.defaults[default_idx])
|
||||
|
||||
# Filter kwonlyargs and their defaults
|
||||
runtime_kwonlyargs = []
|
||||
runtime_kwonlydefaults = {}
|
||||
|
||||
if arg_spec.kwonlyargs:
|
||||
for kwarg in arg_spec.kwonlyargs:
|
||||
arg_type = arg_spec.annotations.get(kwarg, None)
|
||||
|
||||
# Apply same filtering logic
|
||||
if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name):
|
||||
continue
|
||||
|
||||
runtime_kwonlyargs.append(kwarg)
|
||||
if kwarg in arg_spec.annotations:
|
||||
runtime_annotations[kwarg] = arg_type
|
||||
if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults:
|
||||
runtime_kwonlydefaults[kwarg] = arg_spec.kwonlydefaults[kwarg]
|
||||
|
||||
# Convert runtime_defaults to tuple if not empty (as expected by FullArgSpec)
|
||||
runtime_defaults = tuple(runtime_defaults) if runtime_defaults else None
|
||||
|
||||
return inspect.FullArgSpec(
|
||||
args=runtime_args,
|
||||
varargs=arg_spec.varargs, # Keep original varargs
|
||||
varkw=arg_spec.varkw, # Keep original varkw
|
||||
defaults=runtime_defaults,
|
||||
kwonlyargs=runtime_kwonlyargs,
|
||||
kwonlydefaults=runtime_kwonlydefaults if runtime_kwonlydefaults else None,
|
||||
annotations=runtime_annotations,
|
||||
)
|
||||
|
||||
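The filtering above strips compile-time (constexpr) parameters out of the Python signature so that only runtime arguments remain. As a rough standalone sketch of the same idea, using a hypothetical Constexpr marker class in place of the DSL's own annotation machinery:

import inspect

class Constexpr:  # hypothetical stand-in for the DSL's compile-time marker
    pass

def example(a: Constexpr, b: int, c: float = 1.0):
    pass

spec = inspect.getfullargspec(example)
runtime_args = [
    name
    for name in spec.args
    if not (isinstance(spec.annotations.get(name), type)
            and issubclass(spec.annotations.get(name), Constexpr))
]
print(runtime_args)  # ['b', 'c'] -- 'a' is compile-time only and gets dropped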
def __del__(self):
|
||||
if self.cuda_modules:
|
||||
cuda_modules = [module.cuda_module for module in self.cuda_modules.modules]
|
||||
for module in set(cuda_modules):
|
||||
cuda_helpers.unload_cubin_module(module)
|
||||
|
||||
def generate_execution_args(self, args, kwargs, args_spec: inspect.FullArgSpec):
|
||||
"""
|
||||
This function is a pruned version of `generate_mlir_function_types` that generates only the execution args,
|
||||
avoiding the need for an MLIR context.
|
||||
"""
|
||||
|
||||
# args/kwargs must match arg_specs
|
||||
# No canonicalization of args/kwargs to avoid extra latency
|
||||
if len(args) != len(args_spec.args) or len(kwargs) != len(args_spec.kwonlyargs):
|
||||
raise DSLRuntimeError(
|
||||
"input args/kwargs length does not match runtime function signature!",
|
||||
context={
|
||||
"input args length": len(args),
|
||||
"input kwargs length": len(kwargs),
|
||||
"function signature args length": len(args_spec.args),
|
||||
"function signature kwonlyargs length": len(args_spec.kwonlyargs),
|
||||
},
|
||||
)
|
||||
|
||||
exe_args = []
|
||||
input_args = [*args, *kwargs.values()]
|
||||
input_arg_names = [*args_spec.args, *args_spec.kwonlyargs]
|
||||
for i, arg in enumerate(input_args):
|
||||
arg_type = args_spec.annotations.get(input_arg_names[i], None)
|
||||
|
||||
# Implicit cast to NumericMeta
|
||||
if isinstance(arg_type, t.NumericMeta):
|
||||
arg = t.cast(arg, arg_type)
|
||||
|
||||
# If not any known type, try registered adapter to do the conversion
|
||||
adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
|
||||
adapted_arg = adapter(arg) if adapter else arg
|
||||
exe_args.extend(get_c_pointers(adapted_arg))
|
||||
|
||||
return exe_args
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
exe_args = self.generate_execution_args(args, kwargs, self.args_spec)
|
||||
|
||||
self.run_compiled_program(exe_args)
|
||||
|
||||
# Assume each execution arg has type `c_void_p` to reduce the overhead of `ctypes.cast`.
|
||||
def get_invoke_packed_args(self, exe_args):
|
||||
if self.cuda_modules:
|
||||
exe_args += self.cuda_modules.args
|
||||
packed_args = (ctypes.c_void_p * len(exe_args))()
|
||||
for argNum in range(len(exe_args)):
|
||||
packed_args[argNum] = exe_args[argNum]
|
||||
return packed_args
|
||||
|
||||
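For reference, the packing step above amounts to building one contiguous void* array with ctypes. A minimal, self-contained sketch (the values here are arbitrary host scalars standing in for JIT execution arguments):

import ctypes

x = ctypes.c_int(7)
y = ctypes.c_float(3.5)
z = ctypes.c_double(2.0)

# Each execution argument is already a c_void_p pointing at host data.
exe_args = [ctypes.c_void_p(ctypes.addressof(v)) for v in (x, y, z)]

# Pack into a contiguous void* array, as get_invoke_packed_args does above.
packed = (ctypes.c_void_p * len(exe_args))()
for i, p in enumerate(exe_args):
    packed[i] = p

print([hex(p) for p in packed])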
def run_compiled_program(self, exe_args):
|
||||
if self.jit_time_profiling:
|
||||
profiler = timer(enable=True)
|
||||
try:
|
||||
packed_args = profiler(self.get_invoke_packed_args)(exe_args)
|
||||
profiler(self.capi_func)(packed_args)
|
||||
except Exception as e:
|
||||
raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e)
|
||||
else:
|
||||
try:
|
||||
packed_args = self.get_invoke_packed_args(exe_args)
|
||||
self.capi_func(packed_args)
|
||||
except Exception as e:
|
||||
raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e)
|
||||
|
||||
def update_jit_cuda_modules(self, kernel_symbols):
|
||||
# Preload CUDA modules from the compiled cubins embedded in the IR and store them on the executor's cuda_modules.
|
||||
if len(kernel_symbols) > 0:
|
||||
extra_args = []
|
||||
module = self.ir_module
|
||||
cuda_kernel_cache = dict()
|
||||
cuda_driver_version = cuda_helpers.get_driver_version()
|
||||
for sym in kernel_symbols:
|
||||
if sym not in cuda_kernel_cache:
|
||||
log().debug(f"Loading CUDA module for symbol: {sym}")
|
||||
|
||||
# load cuda module/get function pointer from module and cache
|
||||
def walk_callback(sym, func_sym, cubin_data):
|
||||
cubin_module = cuda_helpers.load_cubin_module_data(cubin_data)
|
||||
kernel_ptr = cuda_helpers.get_kernel_function(
|
||||
cubin_module, func_sym
|
||||
)
|
||||
# Enable non-portable cluster size for CUDA version 11.8 or higher.
|
||||
if cuda_driver_version >= 11080:
|
||||
cuda_helpers.set_kernel_attribute(
|
||||
kernel_ptr,
|
||||
cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED,
|
||||
1,
|
||||
)
|
||||
cuda_kernel_cache[sym] = CudaSingleModule(
|
||||
cubin_module, kernel_ptr
|
||||
)
|
||||
|
||||
self.walk_module_and_get_cubin_data(module, sym, walk_callback)
|
||||
else:
|
||||
log().debug(f"Symbol {sym} already in cache")
|
||||
# Only append a kernel pointer if a cubin was found for this symbol (the kernel may be empty).
|
||||
if sym in cuda_kernel_cache:
|
||||
extra_args.append(
|
||||
ctypes.c_void_p(cuda_kernel_cache[sym].kernel_ptr.getPtr())
|
||||
)
|
||||
# store to the jit result if jit result is cached.
|
||||
self.cuda_modules = CudaModules(cuda_kernel_cache.values(), extra_args)
|
||||
|
||||
return self
|
||||
|
||||
def _get_escaped_cubin_bytes(self, cubin_data):
|
||||
"""This function escapes cubin data from mlir raw bytecode to executable binary bytes"""
|
||||
|
||||
def ishex(inp):
|
||||
return (
|
||||
inp in range(0x30, 0x3A)
|
||||
or inp in range(0x61, 0x67)
|
||||
or inp in range(0x41, 0x47)
|
||||
)
|
||||
|
||||
converted = bytearray()
|
||||
idx = 0
|
||||
while idx < len(cubin_data):
|
||||
# escape the original bytes
|
||||
if cubin_data[idx] == 0x5C:
|
||||
# if data of idx is b'\\'
|
||||
if ishex(cubin_data[idx + 1]) and ishex(cubin_data[idx + 2]):
|
||||
converted += bytearray.fromhex(
|
||||
cubin_data[idx + 1 : idx + 3].decode()
|
||||
)
|
||||
idx += 3
|
||||
elif cubin_data[idx + 1] == 0x5C:
|
||||
converted.append(cubin_data[idx])
|
||||
idx += 2
|
||||
else:
|
||||
# no escape, directly write
|
||||
converted.append(cubin_data[idx])
|
||||
idx += 1
|
||||
return bytes(converted)
|
||||
|
||||
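To make the escaping rules concrete, here is a standalone sketch that mirrors the method above on a tiny byte string: a backslash followed by two hex digits decodes to one byte, a doubled backslash decodes to a single backslash, and every other byte passes through unchanged.

HEX_DIGITS = set(b"0123456789abcdefABCDEF")

def unescape(data: bytes) -> bytes:
    out = bytearray()
    i = 0
    while i < len(data):
        if data[i] == 0x5C and i + 2 < len(data) and data[i + 1] in HEX_DIGITS and data[i + 2] in HEX_DIGITS:
            out += bytearray.fromhex(data[i + 1 : i + 3].decode())
            i += 3
        elif data[i] == 0x5C and i + 1 < len(data) and data[i + 1] == 0x5C:
            out.append(0x5C)
            i += 2
        else:
            out.append(data[i])
            i += 1
    return bytes(out)

print(unescape(b"ELF\\7f\\\\data"))  # b'ELF\x7f\\data'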
def walk_module_and_get_cubin_data(self, module, sym, callback):
|
||||
"""This function is used to walk gpu binary op, extract the cubin inside, and process cubin data with callback."""
|
||||
|
||||
def walk_gpu_binary_op(op):
|
||||
if op.name != "gpu.binary":
|
||||
return ir.WalkResult.ADVANCE
|
||||
s = io.BytesIO()
|
||||
op.write_bytecode(s)
|
||||
cubin_data = s.getvalue()
|
||||
if sym.encode() not in cubin_data:
|
||||
return ir.WalkResult.ADVANCE
|
||||
|
||||
if (
|
||||
"kernels" != op.opview.sym_name.value
|
||||
and sym != op.opview.sym_name.value
|
||||
):
|
||||
return ir.WalkResult.ADVANCE
|
||||
# The kernel's function symbol (used by gpu.launch_func) matches the symbol name in the MLIR module
|
||||
func_sym = sym
|
||||
if sym == op.opview.sym_name.value and not sym.endswith("_kernel"):
|
||||
func_sym = sym.rsplit("_", 1)[0]
|
||||
|
||||
cubin_data = cubin_data.split(b'bin = "')[1].split(b'">')[0]
|
||||
cubin_data = self._get_escaped_cubin_bytes(cubin_data)
|
||||
callback(sym, func_sym, cubin_data)
|
||||
return ir.WalkResult.ADVANCE
|
||||
|
||||
module.operation.walk(walk_gpu_binary_op)
|
||||
29
python/CuTeDSL/base_dsl/runtime/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides runtime utility functions that are needed for
|
||||
the DSL.
|
||||
"""
|
||||
|
||||
from . import device_tensor
|
||||
from . import dlpack_types
|
||||
from . import cuda
|
||||
from . import tensor_descriptor
|
||||
from . import jit_arg_adapters
|
||||
|
||||
__all__ = [
|
||||
"device_tensor",
|
||||
"dlpack_types",
|
||||
"cuda",
|
||||
"tensor_descriptor",
|
||||
"jit_arg_adapters",
|
||||
]
|
||||
470
python/CuTeDSL/base_dsl/runtime/cuda.py
Normal file
@@ -0,0 +1,470 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides CUDA Python helper functions
|
||||
"""
|
||||
|
||||
|
||||
from functools import lru_cache
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
import numpy as np
|
||||
import os
|
||||
import ctypes
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
|
||||
# MLIR imports
|
||||
from ..._mlir import ir
|
||||
from ..._mlir.dialects import gpu
|
||||
|
||||
# Local module imports
|
||||
from ..utils.logger import log as _log
|
||||
from ..common import *
|
||||
from .jit_arg_adapters import JitArgAdapterRegistry
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Utils
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _cudaGetErrorEnum(error):
|
||||
if isinstance(error, cuda.CUresult):
|
||||
err, name = cuda.cuGetErrorName(error)
|
||||
return name if err == cuda.CUresult.CUDA_SUCCESS else "<unknown>"
|
||||
elif isinstance(error, nvrtc.nvrtcResult):
|
||||
return nvrtc.nvrtcGetErrorString(error)[1]
|
||||
else:
|
||||
raise DSLRuntimeError("Unknown error type: {}".format(error))
|
||||
|
||||
|
||||
def _get_gpu_arch_info(major, minor):
|
||||
"""Get GPU architecture information and compatibility details."""
|
||||
gpu_arch_map = {
|
||||
(7, 0): ("Volta", "sm_70", ["sm_70"]), # V100
|
||||
(7, 5): ("Turing", "sm_75", ["sm_75"]), # RTX 20 Series, Quadro RTX
|
||||
(8, 0): ("Ampere", "sm_80", ["sm_80"]), # A100
|
||||
(8, 6): ("Ampere", "sm_86", ["sm_86", "sm_80"]), # RTX 30 Series
|
||||
(8, 9): ("Ada", "sm_89", ["sm_89", "sm_86"]), # RTX 40 Series
|
||||
(8, 7): ("Ampere", "sm_87", ["sm_87", "sm_86", "sm_80"]), # A10, A40
|
||||
(9, 0): ("Hopper", "sm_90a", ["sm_90a"]), # H100
|
||||
(10, 0): ("Blackwell", "sm_100a", ["sm_100a"]), # B200
|
||||
}
|
||||
return gpu_arch_map.get(
|
||||
(major, minor), ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"])
|
||||
)
|
||||
|
||||
|
||||
def get_compute_capability_major_minor(device_id: int = 0):
|
||||
"""
|
||||
Returns the compute capability of the CUDA device as a tuple of (major, minor).
|
||||
For example: (8, 0) for Ampere, (9, 0) for Hopper, (10, 0) for Blackwell.
|
||||
Returns None on failure.
|
||||
"""
|
||||
try:
|
||||
checkCudaErrors(cuda.cuInit(0))
|
||||
device = checkCudaErrors(cuda.cuDeviceGet(device_id))
|
||||
major = checkCudaErrors(
|
||||
cuda.cuDeviceGetAttribute(
|
||||
cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
|
||||
device,
|
||||
)
|
||||
)
|
||||
minor = checkCudaErrors(
|
||||
cuda.cuDeviceGetAttribute(
|
||||
cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
|
||||
device,
|
||||
)
|
||||
)
|
||||
return major, minor
|
||||
except RuntimeError as e:
|
||||
_log().info(f"Failed to get CUDA compute capability: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceInfo:
|
||||
"""Data class to store CUDA device information."""
|
||||
|
||||
device_count: int = 0
|
||||
current_device: int = 0
|
||||
device_name: Optional[str] = None
|
||||
major_version: Optional[int] = None
|
||||
minor_version: Optional[int] = None
|
||||
arch_name: Optional[str] = None
|
||||
sm_arch: Optional[str] = None
|
||||
compatible_archs: Optional[List[str]] = None
|
||||
memory_gb: Optional[float] = None
|
||||
target_arch: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
initialization_failed: bool = False
|
||||
|
||||
def pretty_str(self) -> str:
|
||||
"""
|
||||
Convert DeviceInfo to a formatted string for display.
|
||||
"""
|
||||
info = ""
|
||||
|
||||
if self.initialization_failed:
|
||||
return f"{Colors.BOLD}- CUDA initialization failed{Colors.RESET}"
|
||||
|
||||
if self.error_message:
|
||||
return f"{Colors.BOLD}- Failed to get GPU info: {self.error_message}{Colors.RESET}"
|
||||
|
||||
if self.device_count > 0:
|
||||
info += f"{Colors.BOLD}- CUDA devices available: {self.device_count} (current: {self.current_device})\n"
|
||||
|
||||
if self.major_version is not None and self.minor_version is not None:
|
||||
info += f"- Architecture: {Colors.BLUE}{self.arch_name}{Colors.RESET} ({Colors.GREEN}{self.sm_arch}{Colors.RESET})\n"
|
||||
info += f"- Compatible SM archs: {Colors.GREEN}{', '.join(self.compatible_archs or [])}{Colors.RESET}\n"
|
||||
|
||||
if self.memory_gb is not None:
|
||||
info += f"- Total Memory: {Colors.BLUE}{self.memory_gb:.2f} GB{Colors.RESET}\n"
|
||||
|
||||
else:
|
||||
info += f"- Compute capability: unknown\n"
|
||||
info += f"- SM arch: unknown{Colors.RESET}\n"
|
||||
else:
|
||||
info += f"- No devices available\n"
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def get_device_info() -> DeviceInfo:
|
||||
"""
|
||||
Get detailed information about CUDA devices.
|
||||
Returns a DeviceInfo dataclass with device information.
|
||||
"""
|
||||
device_info = DeviceInfo()
|
||||
|
||||
# Initialize CUDA if not already initialized
|
||||
try:
|
||||
result = cuda.cuInit(0)
|
||||
if result[0].value: # Check for error
|
||||
device_info.initialization_failed = True
|
||||
return device_info
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Get device count
|
||||
result = cuda.cuDeviceGetCount()
|
||||
device_info.device_count = result[1] if result[0].value == 0 else 0
|
||||
|
||||
if device_info.device_count > 0:
|
||||
# Get current device
|
||||
try:
|
||||
result = cuda.cuCtxGetDevice()
|
||||
if result[0].value == 0:
|
||||
device_info.current_device = result[1]
|
||||
except:
|
||||
pass
|
||||
|
||||
# Get device name
|
||||
try:
|
||||
name_result = cuda.cuDeviceGetName(100, device_info.current_device)
|
||||
if name_result[0].value == 0:
|
||||
device_info.device_name = name_result[1]
|
||||
except:
|
||||
pass
|
||||
|
||||
# Get compute capability and architecture info
|
||||
try:
|
||||
major, minor = get_compute_capability_major_minor(
|
||||
device_info.current_device
|
||||
)
|
||||
|
||||
# Check if we successfully got the compute capability
|
||||
if major is not None and minor is not None:
|
||||
device_info.major_version = major
|
||||
device_info.minor_version = minor
|
||||
|
||||
arch_name, sm_arch, compatible_archs = _get_gpu_arch_info(
|
||||
device_info.major_version, device_info.minor_version
|
||||
)
|
||||
|
||||
device_info.arch_name = arch_name
|
||||
device_info.sm_arch = sm_arch
|
||||
device_info.compatible_archs = compatible_archs
|
||||
|
||||
# Get memory info
|
||||
try:
|
||||
total_mem = cuda.cuDeviceGetAttribute(
|
||||
cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_TOTAL_MEMORY,
|
||||
device_info.current_device,
|
||||
)
|
||||
if total_mem[0].value == 0:
|
||||
device_info.memory_gb = total_mem[1] / (
|
||||
1024 * 1024 * 1024
|
||||
) # Convert to GB
|
||||
except:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
pass # Compute capability info will remain None
|
||||
|
||||
except Exception as e:
|
||||
device_info.error_message = str(e)
|
||||
|
||||
return device_info
|
||||
|
||||
|
||||
def checkCudaErrors(result):
|
||||
"""Check CUDA errors and provide detailed error messages."""
|
||||
if result[0].value:
|
||||
error_code = result[0].value
|
||||
error_name = _cudaGetErrorEnum(result[0])
|
||||
|
||||
raise DSLCudaRuntimeError(error_code, error_name)
|
||||
|
||||
if len(result) == 1:
|
||||
return None
|
||||
elif len(result) == 2:
|
||||
return result[1]
|
||||
else:
|
||||
return result[1:]
|
||||
|
||||
|
||||
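The helpers below all rely on the cuda-python binding convention that every driver call returns a tuple whose first element is the CUresult status, followed by any payload. A minimal sketch of that convention (requires the cuda-python package and an NVIDIA driver):

import cuda.bindings.driver as cuda

(err,) = cuda.cuInit(0)                  # no payload: a 1-tuple with just the status
assert err == cuda.CUresult.CUDA_SUCCESS

err, count = cuda.cuDeviceGetCount()     # one payload value: a 2-tuple
assert err == cuda.CUresult.CUDA_SUCCESS
print(f"{count} CUDA device(s) visible")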
# =============================================================================
|
||||
# Driver Helpers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def initialize_cuda_context(device_id: int = 0, flags: int = 0):
|
||||
"""
|
||||
Initializes the CUDA context for a specified device.
|
||||
"""
|
||||
# Initialize CUDA Driver API
|
||||
_log().info(f"cuInit {flags}")
|
||||
checkCudaErrors(cuda.cuInit(flags))
|
||||
# Retrieve handle for device
|
||||
_log().info(f"cuDeviceGet {device_id}")
|
||||
cuDevice = checkCudaErrors(cuda.cuDeviceGet(device_id))
|
||||
_log().info(f"{cuDevice} <-- cuDeviceGet")
|
||||
# Create context
|
||||
_log().info(f"cuCtxCreate {0} {cuDevice}")
|
||||
context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice))
|
||||
_log().info(f"{context} <-- cuCtxCreate")
|
||||
|
||||
return context
|
||||
|
||||
|
||||
def load_cubin_module(cubin_file):
|
||||
"""
|
||||
Loads a CUBIN file and returns the module.
|
||||
"""
|
||||
# Load CUBIN file as binary data
|
||||
_log().info(f"read cubin {cubin_file}")
|
||||
with open(cubin_file, "rb") as f:
|
||||
cubin_data = f.read()
|
||||
# Load module data
|
||||
_log().info(f"cuModuleLoadData {np.char.array(cubin_data).ctypes.data}")
|
||||
module = checkCudaErrors(
|
||||
cuda.cuModuleLoadData(np.char.array(cubin_data).ctypes.data)
|
||||
)
|
||||
return module
|
||||
|
||||
|
||||
def unload_cubin_module(module):
|
||||
"""
|
||||
Unloads a CUBIN module.
|
||||
"""
|
||||
_log().info(f"cuModuleUnload {module}")
|
||||
checkCudaErrors(cuda.cuModuleUnload(module))
|
||||
|
||||
|
||||
def load_cubin_module_data(cubin_data):
|
||||
"""
|
||||
Loads a CUBIN from data and returns the module.
|
||||
"""
|
||||
# Load module data
|
||||
_log().info(f"cuModuleLoadData {np.char.array(cubin_data).ctypes.data}")
|
||||
module = checkCudaErrors(
|
||||
cuda.cuModuleLoadData(np.char.array(cubin_data).ctypes.data)
|
||||
)
|
||||
return module
|
||||
|
||||
|
||||
def get_kernel_function(module, kernel_name):
|
||||
"""
|
||||
Retrieves the kernel function from the module.
|
||||
"""
|
||||
_log().info(f"cuModuleGetFunction {module} {kernel_name}")
|
||||
kernel = checkCudaErrors(
|
||||
cuda.cuModuleGetFunction(module, bytes(kernel_name, "utf-8"))
|
||||
)
|
||||
_log().info(f"{kernel} <-- cuModuleGetFunction")
|
||||
return kernel
|
||||
|
||||
|
||||
def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size=0, kernel_args=None):
|
||||
"""
|
||||
Launches the CUDA kernel.
|
||||
"""
|
||||
_log().info(
|
||||
f"cuLaunchKernel {kernel} grid={grid_dims} blocks={block_dims} smem_size={smem_size} stream={stream} {kernel_args}"
|
||||
)
|
||||
checkCudaErrors(
|
||||
cuda.cuLaunchKernel(
|
||||
kernel,
|
||||
grid_dims[0],
|
||||
grid_dims[1],
|
||||
grid_dims[2],
|
||||
block_dims[0],
|
||||
block_dims[1],
|
||||
block_dims[2],
|
||||
smem_size, # Shared memory size
|
||||
stream,
|
||||
kernel_args,
|
||||
0, # Extra parameters
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def stream_sync(stream):
|
||||
"""
|
||||
Synchronizes the CUDA stream.
|
||||
"""
|
||||
_log().info(f"cuStreamSynchronize {stream}")
|
||||
checkCudaErrors(cuda.cuStreamSynchronize(stream))
|
||||
|
||||
|
||||
def stream_create(id=0):
|
||||
"""
|
||||
Creates the CUDA stream.
|
||||
"""
|
||||
_log().info(f"cuStreamCreate {id}")
|
||||
stream = checkCudaErrors(cuda.cuStreamCreate(id))
|
||||
_log().info(f"{stream} <-- cuStreamCreate")
|
||||
return stream
|
||||
|
||||
|
||||
def stream_destroy(stream):
|
||||
"""
|
||||
Destroys the CUDA stream.
|
||||
"""
|
||||
_log().info(f"cuStreamDestroy {stream}")
|
||||
checkCudaErrors(cuda.cuStreamDestroy(stream))
|
||||
|
||||
|
||||
def context_destroy(context):
|
||||
"""
|
||||
Destroys the CUDA context.
|
||||
"""
|
||||
_log().info(f"cuCtxDestroy {context}")
|
||||
checkCudaErrors(cuda.cuCtxDestroy(context))
|
||||
|
||||
|
||||
def allocate(size_in_bytes: int, stream=None):
|
||||
"""
|
||||
Allocate device memory of the given size in bytes.
|
||||
"""
|
||||
_log().info("Allocate size_in_bytes=[%s] stream=[%s]", size_in_bytes, stream)
|
||||
if stream is None:
|
||||
device_memory = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes))
|
||||
else:
|
||||
device_memory = checkCudaErrors(cuda.cuMemAllocAsync(size_in_bytes, stream))
|
||||
_log().info("Allocated [%s]", device_memory)
|
||||
return device_memory
|
||||
|
||||
|
||||
def deallocate(device_pointer, stream=None):
|
||||
"""
|
||||
Deallocate the specified device memory pointer.
|
||||
"""
|
||||
_log().info(
|
||||
"Deallocate device_pointer=[%s] stream=[%s]", hex(int(device_pointer)), stream
|
||||
)
|
||||
if stream is None:
|
||||
checkCudaErrors(cuda.cuMemFree(device_pointer))
|
||||
else:
|
||||
checkCudaErrors(cuda.cuMemFreeAsync(device_pointer, stream))
|
||||
|
||||
|
||||
def memcpy_h2d(host_pointer, device_pointer, size_in_bytes, stream=None):
|
||||
"""
|
||||
Copy data from host to device memory.
|
||||
"""
|
||||
_log().info(
|
||||
"Copy host-to-device host_pointer[%s] device_ptr=[%s] size_in_bytes=[%s] stream=[%s]",
|
||||
hex(host_pointer),
|
||||
hex(int(device_pointer)),
|
||||
size_in_bytes,
|
||||
stream,
|
||||
)
|
||||
if stream is None:
|
||||
checkCudaErrors(cuda.cuMemcpyHtoD(device_pointer, host_pointer, size_in_bytes))
|
||||
else:
|
||||
checkCudaErrors(
|
||||
cuda.cuMemcpyHtoDAsync(device_pointer, host_pointer, size_in_bytes, stream)
|
||||
)
|
||||
|
||||
|
||||
def memcpy_d2h(host_pointer, device_pointer, size_in_bytes, stream=None):
|
||||
"""
|
||||
Copy data from device to host memory.
|
||||
"""
|
||||
_log().info(
|
||||
"Copy device-host-to device_pointer=[%s] host_pointer[%s] size_in_bytes=[%s] stream=[%s]",
|
||||
hex(int(device_pointer)),
|
||||
hex(host_pointer),
|
||||
size_in_bytes,
|
||||
stream,
|
||||
)
|
||||
if stream is None:
|
||||
checkCudaErrors(cuda.cuMemcpyDtoH(host_pointer, device_pointer, size_in_bytes))
|
||||
else:
|
||||
checkCudaErrors(
|
||||
cuda.cuMemcpyDtoHAsync(host_pointer, device_pointer, size_in_bytes, stream)
|
||||
)
|
||||
|
||||
|
||||
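Putting the allocation and copy helpers together, the typical pattern is: allocate, copy host-to-device, run work, copy device-to-host, free. The sketch below shows the same round trip directly against the cuda-python driver API (it assumes a visible GPU and the cuda-python package; error handling is collapsed into a tiny checker following the same convention as checkCudaErrors above):

import numpy as np
import cuda.bindings.driver as cuda

def check(result):
    # Index 0 is the CUresult status; the rest is the payload.
    if result[0] != cuda.CUresult.CUDA_SUCCESS:
        raise RuntimeError(result[0])
    return result[1] if len(result) == 2 else result[1:]

check(cuda.cuInit(0))
dev = check(cuda.cuDeviceGet(0))
ctx = check(cuda.cuCtxCreate(0, dev))

host = np.arange(16, dtype=np.float32)
dptr = check(cuda.cuMemAlloc(host.nbytes))
check(cuda.cuMemcpyHtoD(dptr, host.ctypes.data, host.nbytes))

out = np.empty_like(host)
check(cuda.cuMemcpyDtoH(out.ctypes.data, dptr, host.nbytes))
assert np.array_equal(host, out)

check(cuda.cuMemFree(dptr))
check(cuda.cuCtxDestroy(ctx))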
def default_stream():
|
||||
return cuda.CUstream(0)
|
||||
|
||||
|
||||
def get_driver_version():
|
||||
"""
|
||||
Returns the CUDA driver version.
|
||||
"""
|
||||
return checkCudaErrors(cuda.cuDriverGetVersion())
|
||||
|
||||
|
||||
def set_kernel_attribute(kernel, attribute, value):
|
||||
"""
|
||||
Sets a CUDA kernel attribute.
|
||||
"""
|
||||
return checkCudaErrors(cuda.cuFuncSetAttribute(kernel, attribute, value))
|
||||
|
||||
|
||||
@JitArgAdapterRegistry.register_jit_arg_adapter(cuda.CUstream)
|
||||
class StreamAdapter:
|
||||
"""
|
||||
Convert a CUDA stream to a stream representation for JIT arg generation.
|
||||
"""
|
||||
|
||||
def __init__(self, arg):
|
||||
self._arg = arg
|
||||
self._c_pointer = ctypes.cast(self._arg.getPtr(), ctypes.c_void_p)
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
assert len(values) == 1
|
||||
return values[0]
|
||||
|
||||
def __c_pointers__(self):
|
||||
return [self._c_pointer]
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
return [gpu.AsyncTokenType.get()]
|
||||
121
python/CuTeDSL/base_dsl/runtime/device_tensor.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
import copy
|
||||
|
||||
from . import cuda as cuda_helpers
|
||||
from .tensor_descriptor import *
|
||||
from ..common import *
|
||||
|
||||
|
||||
def allocate(tensor: TensorDescriptor, stream=None):
|
||||
"""
|
||||
Allocates GPU memory
|
||||
"""
|
||||
if tensor._check_is_managed_by_framework():
|
||||
raise DSLRuntimeError(
|
||||
"GPU tensors are managed by the framework and cannot be modified."
|
||||
)
|
||||
if not tensor.device_pointer is None:
|
||||
raise DSLRuntimeError("Tensor is already allocated on the device.")
|
||||
|
||||
tensor.device_pointer = cuda_helpers.allocate(tensor.size_in_bytes, stream)
|
||||
|
||||
log().info("Allocate done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
|
||||
|
||||
|
||||
def deallocate(tensor: TensorDescriptor, stream=None):
|
||||
"""
|
||||
Deallocates GPU memory
|
||||
"""
|
||||
if tensor._check_is_managed_by_framework():
|
||||
raise DSLRuntimeError(
|
||||
"GPU tensors are managed by the framework and cannot be modified."
|
||||
)
|
||||
if tensor.device_pointer is None:
|
||||
raise DSLRuntimeError("Tensor is not allocated on the device.")
|
||||
|
||||
log().info(
|
||||
"Deallocating done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer
|
||||
)
|
||||
|
||||
cuda_helpers.deallocate(tensor.device_pointer, stream)
|
||||
tensor.device_pointer = None
|
||||
|
||||
|
||||
def copy_to_gpu(tensor: TensorDescriptor, do_allocate=True, stream=None):
|
||||
"""
|
||||
Copies data from host memory to the GPU memory.
|
||||
If do_allocate is True, it first calls allocate
|
||||
"""
|
||||
log().info("copyin tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
|
||||
if do_allocate:
|
||||
allocate(tensor, stream)
|
||||
cuda_helpers.memcpy_h2d(
|
||||
tensor.data_ptr, tensor.device_pointer, tensor.size_in_bytes, stream
|
||||
)
|
||||
log().info("copyin done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
|
||||
return tensor
|
||||
|
||||
|
||||
def copy_from_gpu(tensor: TensorDescriptor, do_deallocate=True, stream=None):
|
||||
"""
|
||||
Copies data from GPU memory back to the host.
|
||||
If do_deallocate is True, it calls deallocate
|
||||
"""
|
||||
log().info("copyout tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
|
||||
if tensor._check_is_managed_by_framework():
|
||||
raise DSLRuntimeError(
|
||||
"GPU tensors are managed by the framework and cannot be modified."
|
||||
)
|
||||
if tensor.device_pointer is None:
|
||||
raise DSLRuntimeError("Tensor is not allocated on the device.")
|
||||
|
||||
cuda_helpers.memcpy_d2h(
|
||||
tensor.data_ptr, tensor.device_pointer, tensor.size_in_bytes, stream
|
||||
)
|
||||
if do_deallocate:
|
||||
deallocate(tensor, stream)
|
||||
log().info("copyout done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
|
||||
|
||||
|
||||
def to_gpu(tensor, stream=None) -> TensorDescriptor:
|
||||
"""
|
||||
Copies the tensor from host memory to GPU memory.
|
||||
"""
|
||||
if isinstance(tensor, TensorDescriptor):
|
||||
new_tensor = copy.copy(tensor)
|
||||
copy_to_gpu(new_tensor, stream=stream)
|
||||
return new_tensor
|
||||
|
||||
if TensorDescriptor.can_transformed_to_dlpack(tensor):
|
||||
new_tensor = TensorDescriptor(tensor)
|
||||
copy_to_gpu(new_tensor, stream=stream)
|
||||
return new_tensor
|
||||
|
||||
raise DSLRuntimeError("Unsupported type")
|
||||
|
||||
|
||||
def from_gpu(tensor, stream=None) -> TensorDescriptor:
|
||||
"""
|
||||
Copies the tensor from GPU memory back to host memory.
|
||||
"""
|
||||
if isinstance(tensor, TensorDescriptor):
|
||||
new_tensor = copy.copy(tensor)
|
||||
copy_from_gpu(new_tensor, stream=stream)
|
||||
return new_tensor
|
||||
|
||||
if TensorDescriptor.can_transformed_to_dlpack(tensor):
|
||||
new_tensor = TensorDescriptor(tensor)
|
||||
copy_from_gpu(new_tensor, stream=stream)
|
||||
return new_tensor
|
||||
|
||||
raise DSLRuntimeError("Unsupported type")
|
||||
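A minimal usage sketch for the helpers above; the import path is assumed here (it depends on how the CuTeDSL package is installed) and a CUDA device must be present:

import numpy as np
from base_dsl.runtime import device_tensor  # assumed import path, adjust to your install

host = np.ones((4, 4), dtype=np.float32)

gpu_tensor = device_tensor.to_gpu(host)   # wraps via DLPack, allocates and copies H2D
# ... launch kernels that consume gpu_tensor ...
device_tensor.copy_from_gpu(gpu_tensor)   # copies D2H and frees the device allocation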
76
python/CuTeDSL/base_dsl/runtime/dlpack_types.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides helper structs for dlpack.
|
||||
DLPack is an open standard for in-memory tensor structures, enabling
|
||||
seamless sharing of tensors across different frameworks.
|
||||
Learn more at: https://github.com/dmlc/dlpack
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
import enum
|
||||
|
||||
|
||||
class DLDeviceType(enum.IntEnum):
|
||||
"""Enums for device types based on the DLPack specification."""
|
||||
|
||||
kDLCPU = 1
|
||||
kDLGPU = 2
|
||||
kDLCPUPinned = 3
|
||||
|
||||
|
||||
class DLDataTypeCode:
|
||||
"""Enums for data type codes based on the DLPack specification.
|
||||
|
||||
see https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h
|
||||
"""
|
||||
|
||||
kDLInt = 0
|
||||
kDLUInt = 1
|
||||
kDLFloat = 2
|
||||
kDLOpaqueHandle = 3
|
||||
kDLBfloat = 4
|
||||
kDLComplex = 5
|
||||
kDLBool = 6
|
||||
|
||||
|
||||
class DLDevice(ctypes.Structure):
|
||||
"""Structure representing the device information in DLPack."""
|
||||
|
||||
_fields_ = [
|
||||
("device_type", ctypes.c_int), # kDLCPU, kDLGPU, etc.
|
||||
("device_id", ctypes.c_int), # Device ID (e.g., GPU ID)
|
||||
]
|
||||
|
||||
|
||||
class DLDataType(ctypes.Structure):
|
||||
"""Structure representing the data type in DLPack."""
|
||||
|
||||
_fields_ = [
|
||||
("code", ctypes.c_uint8), # Data type code (e.g., kDLFloat)
|
||||
("bits", ctypes.c_uint8), # Number of bits per value
|
||||
("lanes", ctypes.c_uint16), # Number of lanes
|
||||
]
|
||||
|
||||
|
||||
class DLTensor(ctypes.Structure):
|
||||
"""Structure representing the DLTensor in DLPack."""
|
||||
|
||||
_fields_ = [
|
||||
("data", ctypes.c_void_p), # Pointer to tensor data
|
||||
("device", DLDevice), # Device info
|
||||
("ndim", ctypes.c_int), # Number of dimensions
|
||||
("dtype", DLDataType), # Data type
|
||||
("shape", ctypes.POINTER(ctypes.c_int64)), # Shape of tensor
|
||||
("strides", ctypes.POINTER(ctypes.c_int64)), # Strides of tensor
|
||||
("byte_offset", ctypes.c_uint64), # Byte offset to tensor data
|
||||
]
|
||||
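A small sketch of how the structures defined above describe a tensor; the values follow the DLPack specification for a 2x3 row-major float32 tensor on the CPU (the data pointer is left null, so this is metadata only):

import ctypes

shape_arr = (ctypes.c_int64 * 2)(2, 3)
strides_arr = (ctypes.c_int64 * 2)(3, 1)   # strides are in elements, per DLPack

desc = DLTensor(
    data=None,                             # no backing storage in this sketch
    device=DLDevice(device_type=DLDeviceType.kDLCPU, device_id=0),
    ndim=2,
    dtype=DLDataType(code=DLDataTypeCode.kDLFloat, bits=32, lanes=1),
    shape=ctypes.cast(shape_arr, ctypes.POINTER(ctypes.c_int64)),
    strides=ctypes.cast(strides_arr, ctypes.POINTER(ctypes.c_int64)),
    byte_offset=0,
)
print(desc.ndim, desc.dtype.bits, desc.shape[0:2])   # 2 32 [2, 3]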
188
python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides runtime utilities for JIT argument conversion in DSL.
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from typing import get_origin
|
||||
|
||||
# Local modules imports
|
||||
from ..common import DSLRuntimeError
|
||||
from ..typing import (
|
||||
Constexpr,
|
||||
Int32,
|
||||
Float32,
|
||||
Boolean,
|
||||
)
|
||||
|
||||
|
||||
def is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func):
|
||||
"""
|
||||
Check if the argument spec is a constexpr.
|
||||
"""
|
||||
|
||||
def _is_reserved_python_func_arg(arg_index, arg_name, func):
|
||||
"""
|
||||
Check if the argument is a reserved python function argument.
|
||||
"""
|
||||
|
||||
if arg_index != 0:
|
||||
return False
|
||||
|
||||
if arg_name == "self":
|
||||
return True
|
||||
|
||||
is_classmethod = isinstance(func, classmethod) or (
|
||||
hasattr(func, "__func__") and isinstance(func.__func__, classmethod)
|
||||
)
|
||||
return arg_name == "cls" and is_classmethod
|
||||
|
||||
return (
|
||||
_is_reserved_python_func_arg(arg_index, arg_name, owning_func)
|
||||
or (isinstance(arg_spec, type) and issubclass(arg_spec, Constexpr))
|
||||
or (get_origin(arg_spec) is Constexpr)
|
||||
)
|
||||
|
||||
|
||||
def is_argument_constexpr(arg, arg_spec, arg_name, arg_index, owning_func):
|
||||
"""
|
||||
Check if the argument is a constexpr.
|
||||
"""
|
||||
|
||||
def _is_type_argument(arg, arg_annotation):
|
||||
"""
|
||||
Check if the argument is a type argument like Type[X]
|
||||
"""
|
||||
|
||||
return isinstance(arg, type) and (
|
||||
arg_annotation is None or get_origin(arg_annotation) is type
|
||||
)
|
||||
|
||||
return (
|
||||
is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func)
|
||||
or _is_type_argument(arg, arg_spec)
|
||||
or arg is None
|
||||
)
|
||||
|
||||
|
||||
class JitArgAdapterRegistry:
|
||||
"""
|
||||
A registry to keep track of the JIT argument adapters.
|
||||
|
||||
An adapter is a callable that converts a Python type to a type with following protocols supported:
|
||||
- JitArgument
|
||||
- DynamicExpression
|
||||
The converted type can then be further processed by DSL to generate arguments for JIT functions.
|
||||
"""
|
||||
|
||||
# A dictionary with key=type and value=callable
|
||||
jit_arg_adapter_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register_jit_arg_adapter(cls, *dargs, **dkwargs):
|
||||
"""
|
||||
Register a JIT argument adapter callable
|
||||
|
||||
This can be used as a decorator on any callable like:
|
||||
|
||||
@register_jit_arg_adapter(my_py_type)
|
||||
def my_adapter_for_my_py_type(arg):
|
||||
...
|
||||
|
||||
@register_jit_arg_adapter(my_py_type)
|
||||
class MyAdapterForMyPythonType:
|
||||
...
|
||||
|
||||
The adapters are registered per type. If a type is already registered, an error will be raised.
|
||||
"""
|
||||
|
||||
def decorator(*dargs, **dkwargs):
|
||||
darg_python_ty = dargs[0]
|
||||
|
||||
@wraps(darg_python_ty)
|
||||
def wrapper(*args, **kwargs):
|
||||
if len(args) != 1 or not callable(args[0]):
|
||||
raise DSLRuntimeError(
|
||||
"a callable must be provided for registering JIT argument adapter"
|
||||
)
|
||||
adapter = args[0]
|
||||
|
||||
if darg_python_ty in cls.jit_arg_adapter_registry:
|
||||
raise DSLRuntimeError(
|
||||
f"JIT argument adapter for {darg_python_ty} is already registered!",
|
||||
context={
|
||||
"Registered adapter": cls.jit_arg_adapter_registry[
|
||||
darg_python_ty
|
||||
],
|
||||
"Adapter to be registered": adapter,
|
||||
},
|
||||
)
|
||||
cls.jit_arg_adapter_registry[darg_python_ty] = adapter
|
||||
return adapter
|
||||
|
||||
return wrapper
|
||||
|
||||
if len(dargs) > 0:
|
||||
return decorator(*dargs, **dkwargs)
|
||||
else:
|
||||
raise DSLRuntimeError(
|
||||
"a Python type must be provided for registering JIT argument adapter"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_registered_adapter(cls, ty):
|
||||
"""
|
||||
Get the registered JIT argument adapter for the given type.
|
||||
"""
|
||||
return cls.jit_arg_adapter_registry.get(ty, None)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# JIT Argument Adapters
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@JitArgAdapterRegistry.register_jit_arg_adapter(int)
|
||||
@JitArgAdapterRegistry.register_jit_arg_adapter(float)
|
||||
@JitArgAdapterRegistry.register_jit_arg_adapter(bool)
|
||||
def _convert_python_scalar(arg):
|
||||
"""
|
||||
Convert a Python scalar to a DSL type.
|
||||
"""
|
||||
conversion_map = {
|
||||
int: Int32,
|
||||
float: Float32,
|
||||
bool: Boolean,
|
||||
}
|
||||
return conversion_map.get(type(arg))(arg)
|
||||
|
||||
|
||||
@JitArgAdapterRegistry.register_jit_arg_adapter(tuple)
|
||||
@JitArgAdapterRegistry.register_jit_arg_adapter(list)
|
||||
def _convert_python_sequence(arg):
|
||||
"""
|
||||
Go through each element in the sequence and convert it to a type that can be
|
||||
further processed by DSL to generate the corresponding JIT argument(s).
|
||||
"""
|
||||
adapted_arg = []
|
||||
for elem in arg:
|
||||
adapter = JitArgAdapterRegistry.get_registered_adapter(type(elem))
|
||||
if adapter is not None:
|
||||
converted_elem = adapter(elem)
|
||||
adapted_arg.append(converted_elem)
|
||||
else:
|
||||
# If no registered adapter is found, just return the original element
|
||||
adapted_arg.append(elem)
|
||||
|
||||
assert len(adapted_arg) == len(arg)
|
||||
return type(arg)(adapted_arg)
|
||||
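As a sketch of how a user-defined type could plug into this registry (the Celsius wrapper below is purely hypothetical), the decorator registers an adapter that lowers the wrapper to a scalar type the DSL already knows how to pass:

class Celsius:                       # hypothetical user type
    def __init__(self, value):
        self.value = value

@JitArgAdapterRegistry.register_jit_arg_adapter(Celsius)
def _adapt_celsius(arg):
    # Represent the wrapper as a plain Float32, which is already handled by
    # the DSL (see the int/float/bool adapters above).
    return Float32(arg.value)

Once registered, a Celsius argument passed to a JIT-compiled function is routed through this adapter just like the built-in scalar adapters above.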
201
python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
# Helpers
|
||||
import itertools, operator
|
||||
import ctypes
|
||||
from . import dlpack_types as _dpack
|
||||
from .dlpack_runtime import (
|
||||
dlpack_to_tensor_desc,
|
||||
get_tensor_desc_data_ptr,
|
||||
get_tensor_desc_is_in_device,
|
||||
get_tensor_desc_element_type,
|
||||
get_tensor_desc_shape,
|
||||
get_tensor_desc_stride,
|
||||
get_tensor_desc_element_size_in_bytes,
|
||||
get_tensor_desc_ndim,
|
||||
get_tensor_desc_dtype_code,
|
||||
get_tensor_desc_dtype_bits,
|
||||
get_tensor_desc_device_type,
|
||||
get_tensor_desc_device_id,
|
||||
)
|
||||
|
||||
from ..utils.logger import log
|
||||
from ..common import *
|
||||
from ..typing import (
|
||||
Boolean,
|
||||
Float8E5M2,
|
||||
Int64,
|
||||
Int32,
|
||||
Int16,
|
||||
Int8,
|
||||
Uint64,
|
||||
Uint32,
|
||||
Uint16,
|
||||
Uint8,
|
||||
Float64,
|
||||
Float32,
|
||||
Float16,
|
||||
BFloat16,
|
||||
)
|
||||
|
||||
|
||||
class TensorDescriptor:
|
||||
def __init__(self, tensor):
|
||||
"""Initialize with a tensor that supports the DLPack protocol.
|
||||
|
||||
Args:
|
||||
tensor: Any tensor object that implements __dlpack__ and __dlpack_device__
|
||||
"""
|
||||
|
||||
self.tensor = tensor
|
||||
self._capsule = dlpack_to_tensor_desc(tensor)
|
||||
|
||||
self.data_ptr = get_tensor_desc_data_ptr(self._capsule)
|
||||
self.device_type = get_tensor_desc_device_type(self._capsule)
|
||||
self.device_type = _dpack.DLDeviceType(self.device_type)
|
||||
|
||||
if self.device_type == _dpack.DLDeviceType.kDLGPU:
|
||||
self.device_pointer = self.data_ptr
|
||||
elif self.device_type == _dpack.DLDeviceType.kDLCPU:
|
||||
self.device_pointer = None
|
||||
else:
|
||||
raise DSLRuntimeError(
|
||||
f"DLPack device type is not supported {self.dl_tensor.device.device_type}"
|
||||
)
|
||||
|
||||
log().info("TensorDescriptor is created = [%s]", self)
|
||||
|
||||
@staticmethod
|
||||
def can_transformed_to_dlpack(dl_tensor):
|
||||
if not hasattr(dl_tensor, "__dlpack__") or not hasattr(
|
||||
dl_tensor, "__dlpack_device__"
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_in_device(self):
|
||||
"""Check if the tensor is stored on a device."""
|
||||
return not self.device_pointer is None
|
||||
|
||||
@property
|
||||
def device_id(self):
|
||||
"""Return device id where tensor resides."""
|
||||
if self.is_in_device:
|
||||
return get_tensor_desc_device_id(self._capsule)
|
||||
return -1
|
||||
|
||||
@property
|
||||
def element_type(self):
|
||||
"""Return the corresponding Python type based on DLPack dtype metadata."""
|
||||
str_element_type = get_tensor_desc_element_type(self._capsule)
|
||||
dtype_map = {
|
||||
# bool is 8bit from numpy and torch
|
||||
"Bool": Boolean,
|
||||
"Int64": Int64,
|
||||
"Int32": Int32,
|
||||
"Int16": Int16,
|
||||
"Int8": Int8,
|
||||
"UInt64": Uint64,
|
||||
"UInt32": Uint32,
|
||||
"UInt16": Uint16,
|
||||
"UInt8": Uint8,
|
||||
"Float64": Float64,
|
||||
"Float32": Float32,
|
||||
"Float16": Float16,
|
||||
"BFloat16": BFloat16,
|
||||
"Float8E5M2": Float8E5M2,
|
||||
}
|
||||
|
||||
if str_element_type not in dtype_map:
|
||||
raise KeyError(
|
||||
f"Unsupported element type in dlpack: '{str_element_type}'. Supported types are: {list(dtype_map.keys())}"
|
||||
)
|
||||
|
||||
return dtype_map[str_element_type]
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""Return the shape of the tensor."""
|
||||
return get_tensor_desc_shape(self._capsule)
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
"""Return the rank of the tensor."""
|
||||
return get_tensor_desc_ndim(self._capsule)
|
||||
|
||||
@property
|
||||
def strides(self):
|
||||
"""Return the rank of the tensor."""
|
||||
return get_tensor_desc_stride(self._capsule)
|
||||
|
||||
@property
|
||||
def element_size_in_bytes(self):
|
||||
"""Calculate the element size in bytes of the DLPack tensor."""
|
||||
return get_tensor_desc_element_size_in_bytes(self._capsule)
|
||||
|
||||
@property
|
||||
def size_in_bytes(self):
|
||||
"""Calculate the total size in bytes of the DLPack tensor."""
|
||||
# Calculate the number of elements using the shape
|
||||
ndim = get_tensor_desc_ndim(self._capsule)
|
||||
shape = get_tensor_desc_shape(self._capsule)
|
||||
num_elements = 1
|
||||
for i in range(ndim):
|
||||
num_elements *= shape[i]
|
||||
|
||||
# Total bytes
|
||||
total_bytes = self.element_size_in_bytes * num_elements
|
||||
return total_bytes
|
||||
|
||||
def __str__(self):
|
||||
"""Return a compact string representation of the device_tensor with a tensor prefix."""
|
||||
# Extract shape
|
||||
shape = "x".join(map(str, self.shape))
|
||||
|
||||
# Extract dtype
|
||||
dtype_code = get_tensor_desc_dtype_code(self._capsule)
|
||||
dtype_bits = get_tensor_desc_dtype_bits(self._capsule)
|
||||
dtype = (
|
||||
f"i{dtype_bits}"
|
||||
if dtype_code == _dpack.DLDataTypeCode.kDLInt
|
||||
else f"f{dtype_bits}"
|
||||
)
|
||||
|
||||
# Extract device
|
||||
device_type = "cpu" if not self.is_in_device else "gpu"
|
||||
|
||||
return f"tensor<{shape}x{dtype}>_{device_type}"
|
||||
|
||||
def _check_is_managed_by_framework(self):
|
||||
"""
|
||||
Return True if the tensor is managed by the framework (e.g., a GPU tensor).
|
||||
Callers use this check to reject modifications of framework-managed tensors.
|
||||
"""
|
||||
return self.device_type == _dpack.DLDeviceType.kDLGPU
|
||||
|
||||
|
||||
def from_tensor(tensor) -> TensorDescriptor:
|
||||
"""Create a TensorDescriptor from a tensor object."""
|
||||
return TensorDescriptor(tensor)
|
||||
|
||||
|
||||
def to_tensor(tensor_descriptor: TensorDescriptor):
|
||||
"""Return tensor object from tensor descriptor."""
|
||||
return tensor_descriptor.tensor
|
||||
|
||||
|
||||
def is_tensor_descriptor(maybe_tensor_descriptor) -> bool:
|
||||
"""Check if the object is a TensorDescriptor."""
|
||||
return isinstance(
|
||||
maybe_tensor_descriptor, TensorDescriptor
|
||||
) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
|
||||
1897
python/CuTeDSL/base_dsl/typing.py
Normal file
File diff suppressed because it is too large
19
python/CuTeDSL/base_dsl/utils/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from . import stacktrace
|
||||
from . import logger
|
||||
from . import timer
|
||||
__all__ = [
|
||||
"logger",
|
||||
"timer",
|
||||
"stacktrace",
|
||||
]
|
||||
80
python/CuTeDSL/base_dsl/utils/logger.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides logging helper functions
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
logger = None
|
||||
|
||||
|
||||
def log():
|
||||
return logger
|
||||
|
||||
|
||||
def setup_log(
|
||||
name, log_to_console=False, log_to_file=False, log_file_path=None, log_level=1
|
||||
):
|
||||
"""Set up and configure a logger with console and/or file handlers.
|
||||
|
||||
:param name: Name of the logger to create
|
||||
:type name: str
|
||||
:param log_to_console: Whether to enable logging to console, defaults to False
|
||||
:type log_to_console: bool, optional
|
||||
:param log_to_file: Whether to enable logging to file, defaults to False
|
||||
:type log_to_file: bool, optional
|
||||
:param log_file_path: Path to the log file, required if log_to_file is True
|
||||
:type log_file_path: str, optional
|
||||
:param log_level: Logging level to set, defaults to 1
|
||||
:type log_level: int, optional
|
||||
:raises ValueError: If log_to_file is True but log_file_path is not provided
|
||||
:return: Configured logger instance
|
||||
:rtype: logging.Logger
|
||||
"""
|
||||
# Create a custom logger
|
||||
global logger
|
||||
logger = logging.getLogger(name)
|
||||
if log_to_console or log_to_file:
|
||||
logger.setLevel(log_level)
|
||||
else:
|
||||
logger.setLevel(logging.NOTSET)
|
||||
|
||||
# Clear existing handlers to prevent duplicate logs
|
||||
if logger.hasHandlers():
|
||||
logger.handlers.clear()
|
||||
|
||||
# Define formatter
|
||||
formatter = logging.Formatter(
|
||||
f"%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s] - %(message)s"
|
||||
)
|
||||
|
||||
# Add console handler if enabled
|
||||
if log_to_console:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(log_level)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# Add file handler if enabled
|
||||
if log_to_file:
|
||||
if not log_file_path:
|
||||
raise ValueError("log_file_path must be provided when enable_file is True")
|
||||
file_handler = logging.FileHandler(log_file_path)
|
||||
file_handler.setLevel(log_level)
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
logger = setup_log("generic")
|
||||
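A usage sketch for the setup above; the import path is assumed here (it depends on how the package is installed):

import logging
from base_dsl.utils.logger import setup_log, log   # assumed import path

setup_log("cutlass_dsl", log_to_console=True, log_level=logging.DEBUG)
log().debug("compilation started")
# e.g. 2025-01-01 00:00:00,000 - cutlass_dsl - DEBUG - [<module>] - compilation started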
165
python/CuTeDSL/base_dsl/utils/stacktrace.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides stacktrace helper functions
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def walk_to_top_module(start_path):
|
||||
"""
|
||||
Walk up from the start_path to find the top-level Python module.
|
||||
|
||||
:param start_path: The path to start from.
|
||||
:return: The path of the top-level module.
|
||||
"""
|
||||
current_path = start_path
|
||||
|
||||
while True:
|
||||
# Check if we are at the root directory
|
||||
if os.path.dirname(current_path) == current_path:
|
||||
break
|
||||
|
||||
# Check for __init__.py
|
||||
init_file_path = os.path.join(current_path, "__init__.py")
|
||||
if os.path.isfile(init_file_path):
|
||||
# If __init__.py exists, move up one level
|
||||
current_path = os.path.dirname(current_path)
|
||||
else:
|
||||
# If no __init__.py, we are not in a module; stop
|
||||
break
|
||||
|
||||
# If we reached the root without finding a module, return None
|
||||
if os.path.dirname(current_path) == current_path and not os.path.isfile(
|
||||
os.path.join(current_path, "__init__.py")
|
||||
):
|
||||
return None
|
||||
|
||||
# Return the path of the top-level module
|
||||
return current_path
|
||||
|
||||
|
||||
def _filter_internal_frames(traceback, internal_path):
|
||||
"""
|
||||
Filter out stack frames from the traceback that belong to the specified module path.
|
||||
|
||||
This function removes stack frames from the traceback whose file paths start with
|
||||
the given prefix_path, effectively hiding internal implementation details from
|
||||
the error traceback shown to users.
|
||||
"""
|
||||
iter_prev = None
|
||||
iter_tb = traceback
|
||||
while iter_tb is not None:
|
||||
if os.path.abspath(iter_tb.tb_frame.f_code.co_filename).startswith(
|
||||
internal_path
|
||||
):
|
||||
if iter_tb.tb_next:
|
||||
if iter_prev:
|
||||
iter_prev.tb_next = iter_tb.tb_next
|
||||
else:
|
||||
traceback = iter_tb.tb_next
|
||||
else:
|
||||
iter_prev = iter_tb
|
||||
iter_tb = iter_tb.tb_next
|
||||
return traceback
|
||||
|
||||
|
||||
_generated_function_names = re.compile(
|
||||
r"^(loop_body|while_region|while_before_block|while_after_block|if_region|then_block|else_block|elif_region)_\d+$"
|
||||
)
|
||||
|
||||
|
||||
def _filter_duplicated_frames(traceback):
|
||||
"""
|
||||
Filter out duplicated stack frames from the traceback.
|
||||
The function filters out consecutive frames that are in the same file and have the same line number.
|
||||
In a sequence of consecutive frames, the logic prefers to keep the non-generated frame or the last frame.
|
||||
"""
|
||||
iter_prev = None
|
||||
iter_tb = traceback
|
||||
while iter_tb is not None:
|
||||
skip_current = False
|
||||
skip_next = False
|
||||
if iter_tb.tb_next:
|
||||
current_filename = os.path.abspath(iter_tb.tb_frame.f_code.co_filename)
|
||||
next_filename = os.path.abspath(iter_tb.tb_next.tb_frame.f_code.co_filename)
|
||||
# if in the same file, check if the line number is the same
|
||||
if current_filename == next_filename:
|
||||
current_lineno = iter_tb.tb_lineno
|
||||
next_lineno = iter_tb.tb_next.tb_lineno
|
||||
if current_lineno == next_lineno:
|
||||
# Same file and line number, check name, if current is generated, skip current, otherwise skip next
|
||||
name = iter_tb.tb_frame.f_code.co_name
|
||||
is_generated = bool(_generated_function_names.match(name))
|
||||
if is_generated:
|
||||
# Skip current
|
||||
skip_current = True
|
||||
else:
|
||||
# Skip next if it's generated, otherwise keep both
|
||||
next_name = iter_tb.tb_next.tb_frame.f_code.co_name
|
||||
skip_next = bool(_generated_function_names.match(next_name))
|
||||
if skip_current:
|
||||
if iter_prev:
|
||||
iter_prev.tb_next = iter_tb.tb_next
|
||||
else:
|
||||
traceback = iter_tb.tb_next
|
||||
elif skip_next:
|
||||
# if next is last frame, don't skip
|
||||
if iter_tb.tb_next.tb_next:
|
||||
iter_tb.tb_next = iter_tb.tb_next.tb_next
|
||||
iter_prev = iter_tb
|
||||
else:
|
||||
iter_prev = iter_tb
|
||||
iter_tb = iter_tb.tb_next
|
||||
|
||||
return traceback
|
||||
|
||||
|
||||
def filter_stackframe(traceback, prefix_path):
|
||||
"""
|
||||
Filter out stack frames from the traceback that belong to the specified module path.
|
||||
|
||||
This function removes stack frames from the traceback whose file paths start with
|
||||
the given prefix_path, effectively hiding internal implementation details from
|
||||
the error traceback shown to users.
|
||||
|
||||
:param traceback: The traceback object to filter.
|
||||
:param prefix_path: The path prefix to filter out from the traceback.
|
||||
:return: The filtered traceback with internal frames removed.
|
||||
"""
|
||||
# Step 1: filter internal frames
|
||||
traceback = _filter_internal_frames(traceback, prefix_path)
|
||||
|
||||
# Step 2: consolidate duplicated frames
|
||||
return _filter_duplicated_frames(traceback)
|
||||
|
||||
|
||||
def filter_exception(value, module_dir):
|
||||
"""
|
||||
Filter out internal implementation details from exception traceback.
|
||||
|
||||
This function recursively processes an exception and its cause chain,
|
||||
removing stack frames that belong to the specified module directory.
|
||||
This helps to present cleaner error messages to users by hiding
|
||||
implementation details.
|
||||
|
||||
:param value: The exception object to filter.
|
||||
:param module_dir: The module directory path to filter out from tracebacks.
|
||||
:return: The filtered exception with internal frames removed.
|
||||
"""
|
||||
if hasattr(value, "__cause__") and value.__cause__:
|
||||
filter_exception(value.__cause__, module_dir)
|
||||
|
||||
if hasattr(value, "__traceback__"):
|
||||
filter_stackframe(value.__traceback__, module_dir)
|
||||
56
python/CuTeDSL/base_dsl/utils/timer.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides a timing helper functions
|
||||
"""
|
||||
from functools import wraps
|
||||
|
||||
from .logger import log
|
||||
|
||||
|
||||
# TODO: revisit this part when mlir timing manager is ready for pybind.
|
||||
def timer(*dargs, **kwargs):
|
||||
enable = kwargs.get("enable", True)
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def func_wrapper(*args, **kwargs):
|
||||
if not enable:
|
||||
return func(*args, **kwargs)
|
||||
from time import time
|
||||
|
||||
start = time()
|
||||
result = func(*args, **kwargs)
|
||||
end = time()
|
||||
|
||||
# Convert time from seconds to us
|
||||
spend_us = (end - start) * 1e6
|
||||
|
||||
# Determine the function type and format the log message
|
||||
if hasattr(func, "__name__"):
|
||||
func_name = func.__name__
|
||||
log_message = f"[JIT-TIMER] Function: {func_name} | Execution Time: {spend_us:.2f} µs"
|
||||
elif "CFunctionType" in str(type(func)):
|
||||
log_message = f"[JIT-TIMER] C API Function: {str(func)} | Execution Time: {spend_us:.2f} µs"
|
||||
else:
|
||||
log_message = f"[JIT-TIMER] Anonymous Function | Execution Time: {spend_us:.2f} µs"
|
||||
|
||||
log().info(log_message)
|
||||
|
||||
return result
|
||||
|
||||
return func_wrapper
|
||||
|
||||
if len(dargs) == 1 and callable(dargs[0]):
|
||||
return decorator(dargs[0])
|
||||
else:
|
||||
return decorator
|
||||
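A brief usage sketch of the decorator above; both the bare and the parameterized forms go through the `dargs` check at the bottom. The import path is an assumption based on the file location, and logging must be configured for the `[JIT-TIMER]` lines to appear:

```python
from base_dsl.utils.timer import timer  # assumed import path


@timer
def build_ir():
    ...


@timer(enable=False)  # decorator is applied but timing (and logging) is skipped
def launch():
    ...
```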
python/CuTeDSL/cutlass/__init__.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from .cutlass_dsl import (
    Constexpr,
    as_numeric,
    min,
    max,
    and_,
    or_,
    all_,
    any_,
    not_,
    select_,
    # Control-flow without AST pre-processor
    if_generate,
    for_generate,
    LoopUnroll,
    while_generate,
    yield_out,
    # Control-flow with AST pre-processor
    range_constexpr,
    range_dynamic,
    const_expr,
    dynamic_expr,
    # Data types
    dtype,  # Provides conversions to types inheriting from NumericType
    DSLRuntimeError,
    JitArgAdapterRegistry,
    # Construction utilities for user-defined classes
    extract_mlir_values,
    new_from_mlir_values,
)

from .cute.typing import *

# Utilities not belonging to CuTe
from . import utils as utils

# Used as internal symbol
from . import cutlass_dsl as _dsl

# Aliases
LaunchConfig = _dsl.BaseDSL.LaunchConfig
register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter
gpu = _dsl.cutlass_gpu
cuda = _dsl.cuda_helpers
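A minimal sketch of how the control-flow helpers re-exported above are intended to be used inside a JIT-compiled function; the kernel body and the `@cute.jit` context are placeholders, and the exact tracing semantics may differ:

```python
import cutlass
import cutlass.cute as cute


@cute.jit
def scale(n: cutlass.Constexpr):
    # range_constexpr unrolls at trace time because n is a compile-time constant;
    # const_expr marks a predicate that is resolved during tracing as well.
    for i in cutlass.range_constexpr(n):
        if cutlass.const_expr(i == 0):
            cute.printf("first iteration")
```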
python/CuTeDSL/cutlass/cute/__init__.py (new file, 310 lines)
@@ -0,0 +1,310 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
# Use the auto-generated enum AddressSpace
|
||||
from cutlass._mlir.dialects.cute import AddressSpace
|
||||
|
||||
# Explicitly import types that might be directly used by other modules.
# This is a workaround for generating documentation with Sphinx: because Sphinx
# processes each module in isolation, it cannot rely on symbols re-exported via
# wildcard imports (from .typing import *) the way Python does at runtime.
|
||||
from .typing import (
|
||||
Shape,
|
||||
Stride,
|
||||
IntTuple,
|
||||
Coord,
|
||||
Tile,
|
||||
XTuple,
|
||||
Tiler,
|
||||
Layout,
|
||||
Pointer,
|
||||
Tensor,
|
||||
)
|
||||
|
||||
# Import everything else
|
||||
from .typing import *
|
||||
|
||||
from .core import (
|
||||
assume,
|
||||
is_integer,
|
||||
is_int_tuple,
|
||||
is_static,
|
||||
size,
|
||||
has_underscore,
|
||||
slice_,
|
||||
make_ptr,
|
||||
make_layout,
|
||||
recast_layout,
|
||||
make_fragment_like,
|
||||
depth,
|
||||
rank,
|
||||
flatten_to_tuple,
|
||||
flatten,
|
||||
unflatten,
|
||||
product,
|
||||
product_like,
|
||||
shape,
|
||||
size_in_bytes,
|
||||
make_identity_layout,
|
||||
make_ordered_layout,
|
||||
make_composed_layout,
|
||||
make_layout_tv,
|
||||
make_swizzle,
|
||||
recast_ptr,
|
||||
make_tensor,
|
||||
make_identity_tensor,
|
||||
make_fragment,
|
||||
recast_tensor,
|
||||
get,
|
||||
select,
|
||||
front,
|
||||
is_major,
|
||||
find,
|
||||
coalesce,
|
||||
group_modes,
|
||||
cosize,
|
||||
dice,
|
||||
product_each,
|
||||
prepend,
|
||||
append,
|
||||
prepend_ones,
|
||||
append_ones,
|
||||
ceil_div,
|
||||
slice_and_offset,
|
||||
crd2idx,
|
||||
domain_offset,
|
||||
elem_less,
|
||||
transform_leaf,
|
||||
filter_zeros,
|
||||
filter,
|
||||
tile_to_shape,
|
||||
shape_div,
|
||||
composition,
|
||||
complement,
|
||||
right_inverse,
|
||||
left_inverse,
|
||||
max_common_layout,
|
||||
max_common_vector,
|
||||
logical_product,
|
||||
zipped_product,
|
||||
tiled_product,
|
||||
flat_product,
|
||||
raked_product,
|
||||
blocked_product,
|
||||
flat_divide,
|
||||
logical_divide,
|
||||
zipped_divide,
|
||||
tiled_divide,
|
||||
local_partition,
|
||||
local_tile,
|
||||
printf,
|
||||
print_tensor,
|
||||
# tiled mma/tiled copy
|
||||
make_mma_atom,
|
||||
make_tiled_mma,
|
||||
make_copy_atom,
|
||||
make_tiled_copy_tv,
|
||||
make_tiled_copy,
|
||||
make_tiled_copy_S,
|
||||
make_tiled_copy_D,
|
||||
make_tiled_copy_C_atom,
|
||||
basic_copy,
|
||||
basic_copy_if,
|
||||
autovec_copy,
|
||||
copy,
|
||||
gemm,
|
||||
# Wrapper classes
|
||||
ComposedLayout,
|
||||
Swizzle,
|
||||
E,
|
||||
Atom,
|
||||
MmaAtom,
|
||||
CopyAtom,
|
||||
TiledCopy,
|
||||
TiledMma,
|
||||
TensorSSA,
|
||||
ReductionOp,
|
||||
full,
|
||||
full_like,
|
||||
empty_like,
|
||||
ones_like,
|
||||
zeros_like,
|
||||
where,
|
||||
any_,
|
||||
all_,
|
||||
# User defined struct
|
||||
struct,
|
||||
pretty_str,
|
||||
make_layout_image_mask,
|
||||
repeat_like,
|
||||
round_up,
|
||||
is_congruent,
|
||||
is_weakly_congruent,
|
||||
ScaledBasis,
|
||||
get_divisibility,
|
||||
Ratio,
|
||||
)
|
||||
|
||||
from . import arch
|
||||
from . import nvgpu
|
||||
from . import testing
|
||||
from . import runtime
|
||||
|
||||
# Export all math ops without "math."
|
||||
from .math import *
|
||||
|
||||
# Used as internal symbol
|
||||
from .. import cutlass_dsl as _dsl
|
||||
|
||||
# Aliases
|
||||
jit = _dsl.CuTeDSL.jit
|
||||
kernel = _dsl.CuTeDSL.kernel
|
||||
register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter
|
||||
compile = _dsl.compile
|
||||
|
||||
# Explicitly export all symbols for documentation generation
|
||||
__all__ = [
|
||||
# Core types
|
||||
"AddressSpace",
|
||||
"Tensor",
|
||||
"Layout",
|
||||
"ComposedLayout",
|
||||
"Swizzle",
|
||||
"E",
|
||||
"Atom",
|
||||
"MmaAtom",
|
||||
"CopyAtom",
|
||||
"TiledCopy",
|
||||
"TiledMma",
|
||||
"TensorSSA",
|
||||
# Basic utility functions
|
||||
"assume",
|
||||
"is_integer",
|
||||
"is_int_tuple",
|
||||
"is_static",
|
||||
"size",
|
||||
"has_underscore",
|
||||
"slice_",
|
||||
"depth",
|
||||
"rank",
|
||||
"shape",
|
||||
"printf",
|
||||
"print_tensor",
|
||||
"pretty_str",
|
||||
# Layout functions
|
||||
"make_layout",
|
||||
"recast_layout",
|
||||
"make_identity_layout",
|
||||
"make_ordered_layout",
|
||||
"make_composed_layout",
|
||||
"make_layout_tv",
|
||||
"make_layout_image_mask",
|
||||
# Tensor functions
|
||||
"make_ptr",
|
||||
"make_tensor",
|
||||
"make_identity_tensor",
|
||||
"make_fragment",
|
||||
"make_fragment_like",
|
||||
"recast_ptr",
|
||||
"recast_tensor",
|
||||
# Tensor manipulation
|
||||
"get",
|
||||
"select",
|
||||
"front",
|
||||
"is_major",
|
||||
"find",
|
||||
"coalesce",
|
||||
"group_modes",
|
||||
"cosize",
|
||||
"size_in_bytes",
|
||||
# Tuple operations
|
||||
"flatten_to_tuple",
|
||||
"flatten",
|
||||
"product",
|
||||
"product_like",
|
||||
"product_each",
|
||||
"prepend",
|
||||
"append",
|
||||
"prepend_ones",
|
||||
"append_ones",
|
||||
# Math operations
|
||||
"ceil_div",
|
||||
"round_up",
|
||||
# Layout operations
|
||||
"slice_and_offset",
|
||||
"crd2idx",
|
||||
"domain_offset",
|
||||
"elem_less",
|
||||
"filter_zeros",
|
||||
"filter",
|
||||
"tile_to_shape",
|
||||
"shape_div",
|
||||
"dice",
|
||||
# Layout algebra
|
||||
"composition",
|
||||
"complement",
|
||||
"right_inverse",
|
||||
"left_inverse",
|
||||
"max_common_layout",
|
||||
"max_common_vector",
|
||||
"is_congruent",
|
||||
"is_weakly_congruent",
|
||||
# Product operations
|
||||
"logical_product",
|
||||
"zipped_product",
|
||||
"tiled_product",
|
||||
"flat_product",
|
||||
"raked_product",
|
||||
"blocked_product",
|
||||
# Division operations
|
||||
"flat_divide",
|
||||
"logical_divide",
|
||||
"zipped_divide",
|
||||
"tiled_divide",
|
||||
"local_partition",
|
||||
"local_tile",
|
||||
# MMA and Copy operations
|
||||
"make_mma_atom",
|
||||
"make_tiled_mma",
|
||||
"make_copy_atom",
|
||||
"make_tiled_copy_tv",
|
||||
"make_tiled_copy",
|
||||
"make_tiled_copy_C_atom",
|
||||
"basic_copy",
|
||||
"basic_copy_if",
|
||||
"autovec_copy",
|
||||
"copy",
|
||||
"gemm",
|
||||
# Tensor creation
|
||||
"full",
|
||||
"full_like",
|
||||
"empty_like",
|
||||
"ones_like",
|
||||
"zeros_like",
|
||||
"where",
|
||||
"any_",
|
||||
"all_",
|
||||
"repeat_like",
|
||||
"ScaledBasis",
|
||||
# User defined struct
|
||||
"struct",
|
||||
# Modules
|
||||
"arch",
|
||||
"nvgpu",
|
||||
"testing",
|
||||
"runtime",
|
||||
# Decorators and code generation
|
||||
"jit",
|
||||
"kernel",
|
||||
"register_jit_arg_adapter",
|
||||
"compile",
|
||||
]
|
||||
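A small, hedged sketch exercising a few of the layout helpers exported above; shapes and strides are arbitrary, and printing details may differ:

```python
import cutlass.cute as cute


@cute.jit
def layout_demo():
    layout = cute.make_layout((4, 8), stride=(8, 1))  # a 4x8 row-major layout
    tiled = cute.logical_divide(layout, (2, 2))       # tile it into 2x2 blocks
    cute.printf("{} -> {}", cute.size(layout), cute.size(tiled))
```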
python/CuTeDSL/cutlass/cute/arch/__init__.py (new file, 98 lines)
@@ -0,0 +1,98 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from .elect import *
|
||||
from .mbar import *
|
||||
from .nvvm_wrappers import *
|
||||
from .smem import *
|
||||
from .tmem import *
|
||||
|
||||
# __all__ is required here for documentation generation
|
||||
__all__ = [
|
||||
#
|
||||
# elect.py
|
||||
#
|
||||
"make_warp_uniform",
|
||||
"elect_one",
|
||||
#
|
||||
# mbar.py
|
||||
#
|
||||
"mbarrier_init_arrive_cnt",
|
||||
"mbarrier_init_fence",
|
||||
"mbarrier_init_tx_bytes",
|
||||
"mbarrier_wait",
|
||||
"mbarrier_try_wait",
|
||||
"conditional_mbarrier_try_wait",
|
||||
"mbarrier_arrive",
|
||||
#
|
||||
# nvvm_wrappers.py
|
||||
#
|
||||
"lane_idx",
|
||||
"warp_idx",
|
||||
"thread_idx",
|
||||
"block_dim",
|
||||
"block_idx",
|
||||
"grid_dim",
|
||||
"cluster_idx",
|
||||
"cluster_dim",
|
||||
"block_in_cluster_idx",
|
||||
"block_in_cluster_dim",
|
||||
"block_idx_in_cluster",
|
||||
"shuffle_sync",
|
||||
"shuffle_sync_up",
|
||||
"shuffle_sync_down",
|
||||
"shuffle_sync_bfly",
|
||||
"barrier",
|
||||
"sync_threads",
|
||||
"sync_warp",
|
||||
"fence_acq_rel_cta",
|
||||
"fence_acq_rel_cluster",
|
||||
"fence_acq_rel_gpu",
|
||||
"fence_acq_rel_sys",
|
||||
"cp_async_commit_group",
|
||||
"cp_async_wait_group",
|
||||
"cp_async_bulk_commit_group",
|
||||
"cp_async_bulk_wait_group",
|
||||
"cluster_wait",
|
||||
"cluster_arrive",
|
||||
"cluster_arrive_relaxed",
|
||||
"fence_proxy",
|
||||
"vote_ballot_sync",
|
||||
"popc",
|
||||
"fence_view_async_tmem_load",
|
||||
"fence_view_async_tmem_store",
|
||||
"warpgroup_reg_alloc",
|
||||
"warpgroup_reg_dealloc",
|
||||
"fma_packed_f32x2",
|
||||
"mul_packed_f32x2",
|
||||
"add_packed_f32x2",
|
||||
"fmax",
|
||||
"rcp_approx",
|
||||
"exp2",
|
||||
# Constants
|
||||
"WARP_SIZE",
|
||||
# Forward from auto-generated nvvm python
|
||||
"ProxyKind",
|
||||
"SharedSpace",
|
||||
"RoundingModeKind",
|
||||
#
|
||||
# smem.py
|
||||
#
|
||||
"alloc_smem",
|
||||
"get_dyn_smem",
|
||||
#
|
||||
# tmem.py
|
||||
#
|
||||
"retrieve_tmem_ptr",
|
||||
"alloc_tmem",
|
||||
"relinquish_tmem_alloc_permit",
|
||||
"dealloc_tmem",
|
||||
]
|
||||
python/CuTeDSL/cutlass/cute/arch/elect.py (new file, 75 lines)
@@ -0,0 +1,75 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from cutlass.cutlass_dsl import CuTeDSL, T, dsl_user_op

import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir.dialects import nvvm, scf
from cutlass._mlir import ir

from ..typing import Int, Int32
from ...impl_utils import check_value_in


@dsl_user_op
def make_warp_uniform(value: Int, *, loc=None, ip=None) -> Int32:
    """
    Creates a warp-uniform value from the given integer input.

    :param value: The integer to make warp uniform.
    :type value: Int
    :return: The warp-uniform value equal to the input.
    :rtype: Int32
    """
    return Int32(
        _cute_nvgpu_ir.arch_make_warp_uniform(
            Int32(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
        )
    )


class IfOpRegion:
    """
    A context manager for populating a block of an ``scf.if`` op.
    Automatically inserts ``scf.yield([])`` when exiting the context.
    """

    def __init__(self, block, *, loc=None, ip=None):
        self.block = block
        self.insert_point = ir.InsertionPoint(self.block)
        self.loc = loc
        self.ip = ip

    def __enter__(self):
        self.insert_point.__enter__()
        return self.block.arguments

    def __exit__(self, exc_type, exc_value, traceback):
        scf.yield_([], loc=self.loc, ip=self.ip)
        self.insert_point.__exit__(exc_type, exc_value, traceback)


@dsl_user_op
def elect_one(*, loc=None, ip=None) -> IfOpRegion:
    """
    Elects one thread within a warp.

    .. code-block:: python

        with elect_one():
            # Only one thread in the warp executes the code in this context
            pass
    """
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
    is_thread_leader = nvvm.elect_sync(T.bool())
    if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip)
    return IfOpRegion(if_op.then_block, loc=loc, ip=ip)
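A hedged sketch of how the two helpers above are typically combined inside a kernel body; `warp_idx` and `mbarrier_init_arrive_cnt` come from the sibling `nvvm_wrappers` and `mbar` modules below, and `mbar` is a placeholder pointer:

```python
# Make the warp index explicitly warp-uniform before using it for branching.
warp = make_warp_uniform(warp_idx())

with elect_one():
    # Only one thread per warp performs the one-time initialization.
    mbarrier_init_arrive_cnt(mbar, 1)
```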
python/CuTeDSL/cutlass/cute/arch/mbar.py (new file, 208 lines)
@@ -0,0 +1,208 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op
|
||||
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass._mlir import ir
|
||||
|
||||
from ..typing import Pointer, Int, Boolean, Int32
|
||||
from ...impl_utils import check_value_in
|
||||
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# Mbarrier management utilities
|
||||
#
|
||||
####################################################################################################
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def mbarrier_init_arrive_cnt(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Initializes an mbarrier with the specified thread arrival count.
|
||||
|
||||
:param mbar_ptr: A pointer to the mbarrier in SMEM
|
||||
:type mbar_ptr: Pointer
|
||||
:param cnt: The arrival count of the mbarrier
|
||||
:type cnt: Int
|
||||
"""
|
||||
nvvm.mbarrier_init_shared(
|
||||
mbar_ptr.llvm_ptr, Int32(cnt).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def mbarrier_init_fence(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
A fence operation that applies to the mbarrier initializations.
|
||||
"""
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
|
||||
nvvm.fence_mbarrier_init(loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def mbarrier_init_tx_bytes(
|
||||
mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None
|
||||
) -> None:
|
||||
"""
|
||||
Initializes an mbarrier with the specified number of transaction bytes.
|
||||
|
||||
:param mbar_ptr: A pointer to the mbarrier in SMEM
|
||||
:type mbar_ptr: Pointer
|
||||
:param bytes: The number of transaction bytes
|
||||
:type bytes: Int
|
||||
:param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to
|
||||
the mbarrier is converted to a remote address in the peer CTA's
|
||||
SMEM.
|
||||
"""
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
|
||||
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
if peer_cta_rank_in_cluster is not None:
|
||||
mbar_llvm_ptr = nvvm.mapa_shared_cluster(
|
||||
mbar_llvm_ptr.type,
|
||||
mbar_llvm_ptr,
|
||||
Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
space = nvvm.MBarrierSpaceKind.CLUSTER
|
||||
else:
|
||||
space = nvvm.MBarrierSpaceKind.CTA
|
||||
|
||||
nvvm.mbarrier_txn(
|
||||
mbar_llvm_ptr,
|
||||
Int32(bytes).ir_value(loc=loc, ip=ip),
|
||||
kind=nvvm.MBarrierTxnKind.ARRIVE_EXPECT_TX,
|
||||
space=space,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Waits on an mbarrier with a specified phase.
|
||||
|
||||
:param mbar_ptr: A pointer to the mbarrier in SMEM
|
||||
:type mbar_ptr: Pointer
|
||||
:param phase: The phase to wait for (either 0 or 1)
|
||||
:type phase: Int
|
||||
"""
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
|
||||
|
||||
timeout_ns = 10000000
|
||||
# This NVVM op is a spin loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX instruction.
# The timeout in ns applies only to the underlying try_wait; this call itself is blocking.
|
||||
nvvm.mbarrier_try_wait_parity_shared(
|
||||
mbar_ptr.llvm_ptr,
|
||||
Int32(phase).ir_value(loc=loc, ip=ip),
|
||||
Int32(timeout_ns).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Boolean:
|
||||
"""
|
||||
Attempts to wait on an mbarrier with a specified phase in a non-blocking fashion.
|
||||
|
||||
:param mbar_ptr: A pointer to the mbarrier in SMEM
|
||||
:type mbar_ptr: Pointer
|
||||
:param phase: The phase to wait for (either 0 or 1)
|
||||
:type phase: Int
|
||||
:return: A boolean value indicating whether the wait operation was successful
|
||||
:rtype: Boolean
|
||||
"""
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
|
||||
|
||||
return Boolean(
|
||||
nvvm.mbarrier_wait_parity(
|
||||
T.bool(),
|
||||
mbar_ptr.llvm_ptr,
|
||||
Int32(phase).ir_value(loc=loc, ip=ip),
|
||||
nvvm.MBarrierWaitKind.TRY,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def conditional_mbarrier_try_wait(
|
||||
cond, mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None
|
||||
) -> Boolean:
|
||||
"""
|
||||
Conditionally attempts to wait on an mbarrier with a specified phase in a non-blocking fashion.
|
||||
|
||||
:param cond: A boolean predicate
|
||||
:param mbar_ptr: A pointer to the mbarrier in SMEM
|
||||
:type mbar_ptr: Pointer
|
||||
:param phase: The phase to wait for (either 0 or 1)
|
||||
:type phase: Int
|
||||
:return: A boolean value indicating whether the wait operation was successful
|
||||
:rtype: Boolean
|
||||
"""
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
|
||||
return if_generate(
|
||||
cond,
|
||||
lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip),
|
||||
lambda: Boolean(True).ir_value(loc=loc, ip=ip),
|
||||
None,
|
||||
[Boolean],
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def mbarrier_arrive(
|
||||
mbar_ptr: Pointer, peer_cta_rank_in_cluster: Int = None, *, loc=None, ip=None
|
||||
) -> None:
|
||||
"""
|
||||
Arrives on an mbarrier.
|
||||
|
||||
:param mbar_ptr: A pointer to the mbarrier in SMEM
|
||||
:type mbar_ptr: Pointer
|
||||
:param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to
|
||||
the mbarrier is converted to a remote address in the peer CTA's
|
||||
SMEM.
|
||||
"""
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
if peer_cta_rank_in_cluster is not None:
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
|
||||
|
||||
mbar_llvm_ptr = nvvm.mapa_shared_cluster(
|
||||
mbar_llvm_ptr.type,
|
||||
mbar_llvm_ptr,
|
||||
Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
space = nvvm.MBarrierSpaceKind.CLUSTER
|
||||
else:
|
||||
space = nvvm.MBarrierSpaceKind.CTA
|
||||
|
||||
nvvm.mbarrier_txn(
|
||||
mbar_llvm_ptr,
|
||||
Int32(1).ir_value(loc=loc, ip=ip),
|
||||
kind=nvvm.MBarrierTxnKind.ARRIVE,
|
||||
space=space,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
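Putting the helpers above together, a hedged sketch of a typical initialization-and-wait sequence; `elect_one` comes from `elect.py`, and `mbar` and `tma_bytes` are placeholders whose values depend on the kernel:

```python
# One thread initializes the barrier, then all threads fence the initialization.
with elect_one():
    mbarrier_init_arrive_cnt(mbar, 1)
mbarrier_init_fence()

# Producer side: announce the expected TMA transaction size for this stage.
with elect_one():
    mbarrier_init_tx_bytes(mbar, tma_bytes)

# Consumer side: block until the transaction for phase 0 has completed.
mbarrier_wait(mbar, 0)
```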
python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py (new file, 547 lines)
@@ -0,0 +1,547 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple, Union, Callable
|
||||
|
||||
from cutlass.cutlass_dsl import T, dsl_user_op
|
||||
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import llvm, nvvm, vector
|
||||
|
||||
# Forward nvvm enums
|
||||
from cutlass._mlir.dialects.nvvm import (
|
||||
ProxyKind,
|
||||
SharedSpace,
|
||||
Tcgen05WaitKind,
|
||||
SetMaxRegisterAction,
|
||||
RoundingModeKind,
|
||||
)
|
||||
|
||||
from ..typing import Int, Boolean, Int32, Float32, Numeric, as_numeric
|
||||
|
||||
WARP_SIZE = 32
|
||||
FULL_MASK = 0xFFFFFFFF
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def lane_idx(*, loc=None, ip=None) -> Int32:
|
||||
"""
|
||||
Returns the lane index of the current thread within the warp.
|
||||
"""
|
||||
return Int32(nvvm.read_ptx_sreg_laneid(T.i32(), loc=loc, ip=ip))
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def warp_idx(*, loc=None, ip=None) -> Int32:
|
||||
"""
|
||||
Returns the warp index within a CTA.
|
||||
"""
|
||||
warp_size = 32
|
||||
tid_x = Int32(nvvm.read_ptx_sreg_tid_x(T.i32(), loc=loc, ip=ip))
|
||||
tid_y = Int32(nvvm.read_ptx_sreg_tid_y(T.i32(), loc=loc, ip=ip))
|
||||
tid_z = Int32(nvvm.read_ptx_sreg_tid_z(T.i32(), loc=loc, ip=ip))
|
||||
ntid_x = Int32(nvvm.read_ptx_sreg_ntid_x(T.i32(), loc=loc, ip=ip))
|
||||
ntid_y = Int32(nvvm.read_ptx_sreg_ntid_y(T.i32(), loc=loc, ip=ip))
|
||||
tid = tid_x + tid_y * ntid_x + tid_z * ntid_x * ntid_y
|
||||
return tid // warp_size
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def thread_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
|
||||
"""
|
||||
Returns the thread index within a CTA.
|
||||
"""
|
||||
return (
|
||||
Int32(nvvm.read_ptx_sreg_tid_x(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_tid_y(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_tid_z(T.i32(), loc=loc, ip=ip)),
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def block_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
|
||||
"""
|
||||
Returns the number of threads in each dimension of the CTA.
|
||||
"""
|
||||
return (
|
||||
Int32(nvvm.read_ptx_sreg_ntid_x(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_ntid_y(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_ntid_z(T.i32(), loc=loc, ip=ip)),
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def block_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
|
||||
"""
|
||||
Returns the CTA identifier within a grid.
|
||||
"""
|
||||
return (
|
||||
Int32(nvvm.read_ptx_sreg_ctaid_x(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_ctaid_y(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_ctaid_z(T.i32(), loc=loc, ip=ip)),
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def grid_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
|
||||
"""
|
||||
Returns the number of CTAs in each dimension of the grid.
|
||||
"""
|
||||
return (
|
||||
Int32(nvvm.read_ptx_sreg_nctaid_x(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_nctaid_y(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_nctaid_z(T.i32(), loc=loc, ip=ip)),
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
|
||||
"""
|
||||
Returns the cluster identifier within a grid.
|
||||
"""
|
||||
return (
|
||||
Int32(nvvm.read_ptx_sreg_clusterid_x(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_clusterid_y(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_clusterid_z(T.i32(), loc=loc, ip=ip)),
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
|
||||
"""
|
||||
Returns the number of clusters in each dimension of the grid.
|
||||
"""
|
||||
return (
|
||||
Int32(nvvm.read_ptx_sreg_nclusterid_x(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_nclusterid_y(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_nclusterid_z(T.i32(), loc=loc, ip=ip)),
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def block_in_cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
|
||||
"""
|
||||
Returns the CTA index within a cluster across all dimensions.
|
||||
"""
|
||||
return (
|
||||
Int32(nvvm.read_ptx_sreg_cluster_ctaid_x(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_cluster_ctaid_y(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_cluster_ctaid_z(T.i32(), loc=loc, ip=ip)),
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def block_in_cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
|
||||
"""
|
||||
Returns the dimensions of the cluster.
|
||||
"""
|
||||
return (
|
||||
Int32(nvvm.read_ptx_sreg_cluster_nctaid_x(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_cluster_nctaid_y(T.i32(), loc=loc, ip=ip)),
|
||||
Int32(nvvm.read_ptx_sreg_cluster_nctaid_z(T.i32(), loc=loc, ip=ip)),
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def block_idx_in_cluster(*, loc=None, ip=None) -> Int32:
|
||||
"""
|
||||
Returns the linearized identifier of the CTA within the cluster.
|
||||
"""
|
||||
return Int32(nvvm.read_ptx_sreg_cluster_ctarank(T.i32(), loc=loc, ip=ip))
|
||||
|
||||
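As a quick illustration, the special-register wrappers above compose into the usual CUDA-style global thread index; this is a sketch using only the functions defined in this file:

```python
tidx, _, _ = thread_idx()
bidx, _, _ = block_idx()
bdim, _, _ = block_dim()
# Linear thread id across the grid (x dimension only).
global_tid = bidx * bdim + tidx
```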
|
||||
@dsl_user_op
|
||||
def shuffle_sync_op(
|
||||
value: Numeric,
|
||||
offset: Int,
|
||||
mask: Int = FULL_MASK,
|
||||
mask_and_clamp: Int = WARP_SIZE - 1,
|
||||
kind: nvvm.ShflKind = nvvm.ShflKind.idx,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Numeric:
|
||||
"""
|
||||
Shuffles a value within the threads of a warp.
|
||||
|
||||
:param value: The value to shuffle
|
||||
:type value: Numeric
|
||||
:param mask: A mask describing the threads participating in this operation
|
||||
:type mask: Int
|
||||
:param offset: A source lane or a source lane offset depending on kind
|
||||
:type offset: Int
|
||||
:param mask_and_clamp: An integer containing two packed values specifying a mask for logically
|
||||
splitting warps into sub-segments and an upper bound for clamping the
|
||||
source lane index.
|
||||
:type mask_and_clamp: Int
|
||||
:param kind: The kind of shuffle, can be idx, up, down, or bfly
|
||||
:type kind: ShflKind
|
||||
:return: The shuffled value
|
||||
:rtype: Numeric
|
||||
"""
|
||||
if not isinstance(value, Numeric):
|
||||
value = as_numeric(value)
|
||||
return type(value)(
|
||||
nvvm.shfl_sync(
|
||||
type(value).mlir_type,
|
||||
Int32(mask).ir_value(loc=loc, ip=ip),
|
||||
value.ir_value(loc=loc, ip=ip),
|
||||
Int32(offset).ir_value(loc=loc, ip=ip),
|
||||
Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
|
||||
kind,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
shuffle_sync = partial(shuffle_sync_op, kind=nvvm.ShflKind.idx)
|
||||
shuffle_sync_up = partial(shuffle_sync_op, kind=nvvm.ShflKind.up)
|
||||
shuffle_sync_down = partial(shuffle_sync_op, kind=nvvm.ShflKind.down)
|
||||
shuffle_sync_bfly = partial(shuffle_sync_op, kind=nvvm.ShflKind.bfly)
|
||||
|
||||
|
||||
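A hedged sketch of a warp-wide sum reduction built from the butterfly variant above; `val` is a placeholder per-thread value inside a kernel, and the Python loop unrolls at trace time:

```python
for offset in [16, 8, 4, 2, 1]:
    val = val + shuffle_sync_bfly(val, offset)
# After the last exchange, every lane of the warp holds the full sum.
```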
@dsl_user_op
|
||||
def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Creates a barrier, optionally named.
|
||||
"""
|
||||
if barrier_id is not None:
|
||||
barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip)
|
||||
|
||||
if number_of_threads is not None:
|
||||
number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip)
|
||||
|
||||
nvvm.barrier(
|
||||
barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
@dsl_user_op
|
||||
def sync_threads(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Synchronizes all threads within a CTA.
|
||||
"""
|
||||
nvvm.barrier(loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def sync_warp(mask: Int = FULL_MASK, *, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Performs a warp-wide sync with an optional mask.
|
||||
"""
|
||||
nvvm.bar_warp_sync(Int32(mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def fence_acq_rel_cta(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Fence operation with acquire-release semantics.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
|
||||
"""
|
||||
nvvm.fence_acq_rel_cta(loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def fence_acq_rel_cluster(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Fence operation with acquire-release semantics.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
|
||||
"""
|
||||
nvvm.fence_acq_rel_cluster(loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def fence_acq_rel_gpu(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Fence operation with acquire-release semantics.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
|
||||
"""
|
||||
nvvm.fence_acq_rel_gpu(loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def fence_acq_rel_sys(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Fence operation with acquire-release semantics.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
|
||||
"""
|
||||
nvvm.fence_acq_rel_sys(loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cp_async_commit_group(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Commits all prior initiated but uncommitted cp.async instructions.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-commit-group>`__.
|
||||
"""
|
||||
nvvm.cp_async_commit_group(loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cp_async_wait_group(n, *, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Waits until only the specified number of cp.async groups is pending.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-wait-group-cp-async-wait-all>`__.
|
||||
"""
|
||||
nvvm.cp_async_wait_group(n, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cp_async_bulk_commit_group(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Commits all prior initiated but uncommitted cp.async.bulk instructions.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-commit-group>`__.
|
||||
"""
|
||||
nvvm.cp_async_bulk_commit_group(loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cp_async_bulk_wait_group(group, *, read=None, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Waits until only the specified number of cp.async.bulk groups is pending.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-wait-group>`__.
|
||||
"""
|
||||
nvvm.cp_async_bulk_wait_group(group, read=read, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cluster_wait(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
A cluster-wide wait operation.
|
||||
"""
|
||||
nvvm.cluster_wait(loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cluster_arrive(*, aligned=None, loc=None, ip=None) -> None:
|
||||
"""
|
||||
A cluster-wide arrive operation.
|
||||
"""
|
||||
nvvm.cluster_arrive(aligned=aligned, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cluster_arrive_relaxed(*, aligned=None, loc=None, ip=None) -> None:
|
||||
"""
|
||||
A cluster-wide arrive operation with relaxed semantics.
|
||||
"""
|
||||
nvvm.cluster_arrive_relaxed(aligned=aligned, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def fence_proxy(
|
||||
kind: ProxyKind,
|
||||
*,
|
||||
space: Optional[SharedSpace] = None,
|
||||
use_intrinsic=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
nvvm.fence_proxy(
|
||||
kind=kind, space=space, use_intrinsic=use_intrinsic, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def vote_ballot_sync(
|
||||
pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None
|
||||
) -> Int32:
|
||||
"""
|
||||
Performs a ballot operation across the warp.
|
||||
"""
|
||||
return Int32(
|
||||
nvvm.vote_ballot_sync(
|
||||
T.i32(),
|
||||
Int32(mask).ir_value(loc=loc, ip=ip),
|
||||
Boolean(pred).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def popc(value: Numeric, *, loc=None, ip=None) -> Numeric:
|
||||
"""
|
||||
Performs a population count operation.
|
||||
"""
|
||||
if not isinstance(value, Numeric):
|
||||
value = as_numeric(value)
|
||||
return type(value)(llvm.intr_ctpop(value.ir_value(), loc=loc, ip=ip))
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def fence_view_async_tmem_op(
|
||||
kind: Tcgen05WaitKind,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
"""
|
||||
Perform a fence operation on the async TMEM load or store.
|
||||
|
||||
.. note::
|
||||
This function is only available on sm_100a and above.
|
||||
The fence is required to synchronize the TMEM load/store
|
||||
and let the pipeline release or commit the buffer.
|
||||
|
||||
Take a mma2acc pipeline as an example of LOAD fence, the ACC tensor is from TMEM.
|
||||
```
|
||||
# Start to copy ACC from TMEM to register
|
||||
cute.copy(tmem_load, tACC, rACC)
|
||||
fence_view_async_tmem_load()
|
||||
# After fence, we can ensure the TMEM buffer is consumed totally.
|
||||
# Release the buffer to let the MMA know it can overwrite the buffer.
|
||||
mma2accum_pipeline.consumer_release(curr_consumer_state)
|
||||
```
|
||||
Take a TS GEMM kernel as an example of STORE fence, the A tensor is from TMEM.
|
||||
```
|
||||
# Start to copy A from register to TMEM
|
||||
cute.copy(tmem_store, rA, tA)
|
||||
fence_view_async_tmem_store()
|
||||
# After fence, we can ensure the TMEM buffer is ready.
|
||||
# Commit the buffer to let the MMA know it can start to load A.
|
||||
tmem_mma_pipeline.producer_commit(curr_producer_state)
|
||||
```
|
||||
|
||||
|
||||
:param kind: The kind of fence operation to perform including LOAD and STORE.
|
||||
:type kind: Tcgen05WaitKind
|
||||
"""
|
||||
nvvm.tcgen05_wait(kind, loc=loc, ip=ip)
|
||||
|
||||
|
||||
fence_view_async_tmem_load = partial(
|
||||
fence_view_async_tmem_op, kind=Tcgen05WaitKind.LOAD
|
||||
)
|
||||
fence_view_async_tmem_store = partial(
|
||||
fence_view_async_tmem_op, kind=Tcgen05WaitKind.STORE
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def warpgroup_reg_realloc_op(
|
||||
reg_count: int,
|
||||
kind: SetMaxRegisterAction,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
nvvm.setmaxregister(reg_count, kind, loc=loc, ip=ip)
|
||||
|
||||
|
||||
warpgroup_reg_alloc = partial(
|
||||
warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.increase
|
||||
)
|
||||
warpgroup_reg_dealloc = partial(
|
||||
warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.decrease
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def calc_packed_f32x2_op(
|
||||
src_a: Tuple[Float32, Float32],
|
||||
src_b: Tuple[Float32, Float32],
|
||||
src_c: Tuple[Float32, Float32] | None,
|
||||
calc_func: Callable,
|
||||
*,
|
||||
rnd=RoundingModeKind.RZ,
|
||||
ftz=True,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tuple[Float32, Float32]:
|
||||
vec_type = ir.VectorType.get([2], Float32.mlir_type, loc=loc)
|
||||
vec_src_a = vector.from_elements(
|
||||
vec_type, tuple(as_numeric(a).ir_value() for a in src_a), loc=loc, ip=ip
|
||||
)
|
||||
vec_src_b = vector.from_elements(
|
||||
vec_type, tuple(as_numeric(b).ir_value() for b in src_b), loc=loc, ip=ip
|
||||
)
|
||||
if src_c is not None:
|
||||
vec_src_c = vector.from_elements(
|
||||
vec_type, tuple(as_numeric(c).ir_value() for c in src_c), loc=loc, ip=ip
|
||||
)
|
||||
vec_res = calc_func(
|
||||
vec_type, vec_src_a, vec_src_b, vec_src_c, rnd=rnd, ftz=ftz, loc=loc, ip=ip
|
||||
)
|
||||
else:
|
||||
vec_res = calc_func(
|
||||
vec_type, vec_src_a, vec_src_b, rnd=rnd, ftz=ftz, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
res0 = Float32(
|
||||
vector.extract(
|
||||
vec_res, dynamic_position=[], static_position=[0], loc=loc, ip=ip
|
||||
)
|
||||
)
|
||||
res1 = Float32(
|
||||
vector.extract(
|
||||
vec_res, dynamic_position=[], static_position=[1], loc=loc, ip=ip
|
||||
)
|
||||
)
|
||||
return res0, res1
|
||||
|
||||
|
||||
fma_packed_f32x2 = partial(calc_packed_f32x2_op, calc_func=nvvm.fma_packed_f32x2)
|
||||
mul_packed_f32x2 = partial(
|
||||
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.mul_packed_f32x2
|
||||
)
|
||||
add_packed_f32x2 = partial(
|
||||
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.add_packed_f32x2
|
||||
)
|
||||
|
||||
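A short, hedged example of the packed-pair helpers above; `a0` through `c1` are placeholder `Float32` values:

```python
# d = a * b + c, computed on two f32 lanes at once.
d0, d1 = fma_packed_f32x2((a0, a1), (b0, b1), (c0, c1))
# Element-wise product of the same pairs.
p0, p1 = mul_packed_f32x2((a0, a1), (b0, b1))
```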
|
||||
@dsl_user_op
|
||||
def fmax(
|
||||
a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None
|
||||
) -> Float32:
|
||||
return Float32(
|
||||
nvvm.fmax(
|
||||
T.f32(),
|
||||
Float32(a).ir_value(loc=loc, ip=ip),
|
||||
Float32(b).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None):
|
||||
return Float32(
|
||||
nvvm.rcp_approx_ftz_f(
|
||||
T.f32(), Float32(a).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
|
||||
return Float32(
|
||||
llvm.inline_asm(
|
||||
T.f32(),
|
||||
[Float32(a).ir_value(loc=loc, ip=ip)],
|
||||
"ex2.approx.ftz.f32 $0, $1;",
|
||||
"=f,f",
|
||||
has_side_effects=True,
|
||||
is_align_stack=False,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
)
|
||||
)
|
||||
python/CuTeDSL/cutlass/cute/arch/smem.py (new file, 96 lines)
@@ -0,0 +1,96 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from cutlass.cutlass_dsl import T, dsl_user_op
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
|
||||
from ..typing import Pointer, Numeric, NumericMeta
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def alloc_smem(
|
||||
element_type: Type[Numeric],
|
||||
size_in_elems: int,
|
||||
alignment: Optional[int] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Pointer:
|
||||
"""
|
||||
Statically allocates SMEM.
|
||||
|
||||
:param element_type: The pointee type of the pointer.
|
||||
:type element_type: Type[Numeric]
|
||||
:param size_in_elems: The size of the allocation in terms of number of elements of the
|
||||
pointee type
|
||||
:type size_in_elems: int
|
||||
:param alignment: An optional pointer alignment for the allocation
|
||||
:type alignment: int
|
||||
:return: A pointer to the start of the allocation
|
||||
:rtype: Pointer
|
||||
"""
|
||||
if not isinstance(element_type, NumericMeta):
|
||||
raise TypeError(
|
||||
f"element_type must be a type of Numeric, but got {element_type}"
|
||||
)
|
||||
|
||||
if alignment is None:
|
||||
# Default alignment based on the element type's width
|
||||
alignment = element_type.width // 8
|
||||
ptr_ty = _cute_ir.PtrType.get(
|
||||
element_type.mlir_type, _cute_ir.AddressSpace.smem, alignment
|
||||
)
|
||||
return _cute_nvgpu_ir.arch_alloc_smem(
|
||||
ptr=ptr_ty,
|
||||
input=ir.IntegerAttr.get(T.i32(), size_in_elems),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def get_dyn_smem(
|
||||
element_type: Type[Numeric],
|
||||
alignment: Optional[int] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Pointer:
|
||||
"""
|
||||
Retrieves a pointer to a dynamic SMEM allocation.
|
||||
|
||||
:param element_type: The pointee type of the pointer.
|
||||
:type element_type: Type[Numeric]
|
||||
:param alignment: An optional pointer alignment, the result pointer is offset appropriately
|
||||
:type alignment: int
|
||||
:return: A pointer to the start of the dynamic SMEM allocation with the correct alignment
|
||||
:rtype: Pointer
|
||||
"""
|
||||
if not isinstance(element_type, NumericMeta):
|
||||
raise TypeError(
|
||||
f"element_type must be a type of Numeric, but got {element_type}"
|
||||
)
|
||||
|
||||
if alignment is None:
|
||||
# Default alignment based on the element type's width
|
||||
alignment = element_type.width // 8
|
||||
ptr_ty = _cute_ir.PtrType.get(
|
||||
element_type.mlir_type,
|
||||
_cute_ir.AddressSpace.smem,
|
||||
alignment,
|
||||
)
|
||||
return _cute_nvgpu_ir.arch_get_dyn_smem(ptr=ptr_ty, loc=loc, ip=ip)
|
||||
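A hedged sketch of both allocation paths above; `Float32` is assumed to come from the `typing` module, and the sizes are arbitrary:

```python
# Static allocation: 256 f32 elements, default (element-width) alignment.
smem_buf = alloc_smem(Float32, 256)

# Dynamic SMEM: get a correctly aligned pointer to the start of the region.
dyn_buf = get_dyn_smem(Float32, alignment=16)
```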
python/CuTeDSL/cutlass/cute/arch/tmem.py (new file, 142 lines)
@@ -0,0 +1,142 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import Type
|
||||
|
||||
from cutlass.cutlass_dsl import dsl_user_op
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
|
||||
from ..typing import Pointer, Int, Int32, Numeric, NumericMeta
|
||||
|
||||
|
||||
SM100_TMEM_CAPACITY_COLUMNS = 512
|
||||
SM100_TMEM_MIN_ALLOC_COLUMNS = 32
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def retrieve_tmem_ptr(
|
||||
element_type: Type[Numeric],
|
||||
alignment: int,
|
||||
ptr_to_buffer_holding_addr: Pointer,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Pointer:
|
||||
"""
|
||||
Retrieves a pointer to TMEM with the provided element type and alignment.
|
||||
|
||||
:param element_type: The pointee type of the pointer.
|
||||
:type element_type: Type[Numeric]
|
||||
:param alignment: The alignment of the result pointer
|
||||
:type alignment: int
|
||||
:param ptr_to_buffer_holding_addr: A pointer to an SMEM buffer holding the TMEM address of the start of the allocation
|
||||
:type ptr_to_buffer_holding_addr: Pointer
|
||||
:return: A pointer to TMEM
|
||||
:rtype: Pointer
|
||||
"""
|
||||
if not isinstance(element_type, NumericMeta):
|
||||
raise TypeError(
|
||||
f"element_type must be a type of Numeric, but got {element_type}"
|
||||
)
|
||||
|
||||
res_ty = _cute_ir.PtrType.get(
|
||||
element_type.mlir_type, _cute_ir.AddressSpace.tmem, alignment
|
||||
)
|
||||
return _cute_nvgpu_ir.arch_sm100_retrieve_tmem_ptr(
|
||||
res_ty, ptr_to_buffer_holding_addr.value, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def alloc_tmem(
|
||||
num_columns: Int,
|
||||
smem_ptr_to_write_address: Pointer,
|
||||
is_two_cta=None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
"""
|
||||
Allocates TMEM.
|
||||
|
||||
:param num_columns: The number of TMEM columns to allocate
|
||||
:type num_columns: Int
|
||||
:param smem_ptr_to_write_address: A pointer to a SMEM buffer where the TMEM address is written
|
||||
to
|
||||
:type smem_ptr_to_write_address: Pointer
|
||||
:param is_two_cta: Optional boolean parameter for 2-CTA MMAs
|
||||
"""
|
||||
if isinstance(num_columns, int):
|
||||
if (
|
||||
num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS
|
||||
or num_columns > SM100_TMEM_CAPACITY_COLUMNS
|
||||
or not (num_columns & (num_columns - 1) == 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"num_columns must be between 32 and 512, and must be pow of 2, but got {num_columns}"
|
||||
)
|
||||
_cute_nvgpu_ir.arch_sm100_alloc_tmem(
|
||||
Int32(num_columns).ir_value(loc=loc, ip=ip),
|
||||
smem_ptr_to_write_address.value,
|
||||
is_two_cta=is_two_cta,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def relinquish_tmem_alloc_permit(is_two_cta=None, *, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Relinquishes the right to allocate TMEM so that other CTAs potentially in a different grid can
|
||||
allocate.
|
||||
"""
|
||||
_cute_nvgpu_ir.arch_sm100_relinquish_tmem_alloc_permit(
|
||||
is_two_cta=is_two_cta, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def dealloc_tmem(
|
||||
tmem_ptr: Pointer,
|
||||
num_columns: Int,
|
||||
is_two_cta=None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
"""
|
||||
Deallocates TMEM using the provided pointer and number of columns.
|
||||
|
||||
:param tmem_ptr: A pointer to the TMEM allocation to de-allocate
|
||||
:type tmem_ptr: Pointer
|
||||
:param num_columns: The number of columns in the TMEM allocation
|
||||
:type num_columns: Int
|
||||
:param is_two_cta: Optional boolean parameter for 2-CTA MMAs
|
||||
"""
|
||||
if isinstance(num_columns, int):
|
||||
if (
|
||||
num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS
|
||||
or num_columns > SM100_TMEM_CAPACITY_COLUMNS
|
||||
or not (num_columns & (num_columns - 1) == 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"num_columns must be between 32 and 512, and must be pow of 2, but got {num_columns}"
|
||||
)
|
||||
_cute_nvgpu_ir.arch_sm100_dealloc_tmem(
|
||||
tmem_ptr.value,
|
||||
Int32(num_columns).ir_value(loc=loc, ip=ip),
|
||||
is_two_cta=is_two_cta,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
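A hedged sketch of the TMEM lifecycle using the helpers above; `smem_holding_addr` is a placeholder SMEM pointer (for example from `alloc_smem`) that receives the TMEM base address, and 128 satisfies the power-of-two column constraint:

```python
alloc_tmem(128, smem_holding_addr)
relinquish_tmem_alloc_permit()
tmem_ptr = retrieve_tmem_ptr(Float32, 16, smem_holding_addr)
# ... issue tcgen05 MMAs / TMEM copies through tmem_ptr ...
dealloc_tmem(tmem_ptr, 128)
```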
python/CuTeDSL/cutlass/cute/core.py (new file, 6417 lines)
File diff suppressed because it is too large.
python/CuTeDSL/cutlass/cute/math.py (new file, 354 lines)
@@ -0,0 +1,354 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from .core import TensorSSA
|
||||
from cutlass._mlir.dialects import math, arith
|
||||
|
||||
|
||||
def acos(a: TensorSSA) -> TensorSSA:
|
||||
"""Compute element-wise arc cosine of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:return: Tensor containing the arc cosine of each element in input tensor
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = acos(y) # Compute arc cosine
|
||||
"""
|
||||
return TensorSSA(math.acos(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
|
||||
|
||||
def asin(a: TensorSSA) -> TensorSSA:
|
||||
"""Compute element-wise arc sine of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:return: Tensor containing the arc sine of each element in input tensor
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = asin(y) # Compute arc sine
|
||||
"""
|
||||
return TensorSSA(math.asin(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
|
||||
|
||||
def atan(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise arc tangent of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the arc tangent of each element in input tensor
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = atan(y) # Compute arc tangent
|
||||
"""
|
||||
raise NotImplementedError("atan is not implemented")
|
||||
return TensorSSA(math.atan(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
|
||||
|
||||
def atan2(a: TensorSSA, b: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise arc tangent of two tensors.
|
||||
|
||||
Computes atan2(a, b) element-wise. The function atan2(a, b) is the angle in radians
|
||||
between the positive x-axis and the point given by the coordinates (b, a).
|
||||
|
||||
:param a: First input tensor (y-coordinates)
|
||||
:type a: TensorSSA
|
||||
:param b: Second input tensor (x-coordinates)
|
||||
:type b: TensorSSA
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the arc tangent of a/b element-wise
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
y = cute.make_fragment(ptr1, layout).load() # y coordinates
|
||||
x = cute.make_fragment(ptr2, layout).load() # x coordinates
|
||||
theta = atan2(y, x) # Compute angles
|
||||
"""
|
||||
return TensorSSA(
|
||||
math.atan2(a, b, fastmath=arith.FastMathFlags.none), a.shape, a.dtype
|
||||
)
|
||||
|
||||
|
||||
def cos(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise cosine of the input tensor.
|
||||
|
||||
:param a: Input tensor (in radians)
|
||||
:type a: TensorSSA
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the cosine of each element
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = cos(y) # Compute cosine
|
||||
"""
|
||||
return TensorSSA(math.cos(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
|
||||
|
||||
def erf(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise error function of the input tensor.
|
||||
|
||||
The error function is defined as:
|
||||
erf(x) = 2/√π ∫[0 to x] exp(-t²) dt
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the error function value for each element
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = erf(y) # Compute error function
|
||||
"""
|
||||
return TensorSSA(math.erf(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
|
||||
|
||||
def exp2(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise base-2 exponential of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing 2 raised to the power of each element
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = exp2(y) # Compute 2^x
|
||||
"""
|
||||
return TensorSSA(math.exp2(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
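In other words, ``exp2(x)`` computes ``2**x``; on the host this is the familiar identity (standard library only):

.. code-block:: python

    import math

    x = 1.7
    assert math.isclose(2.0 ** x, math.exp(x * math.log(2.0)))  # 2**x == exp(x * ln 2)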
|
||||
|
||||
|
||||
def log(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise natural logarithm of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the natural logarithm of each element
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = log(y) # Compute natural logarithm
|
||||
"""
|
||||
return TensorSSA(math.log(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
|
||||
|
||||
def log2(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise base-2 logarithm of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the base-2 logarithm of each element
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = log2(y) # Compute log base 2
|
||||
"""
|
||||
return TensorSSA(math.log2(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
|
||||
|
||||
def log10(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise base-10 logarithm of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the base-10 logarithm of each element
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = log10(y) # Compute log base 10
|
||||
"""
|
||||
return TensorSSA(math.log10(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
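The three logarithms provided here (``log``, ``log2``, ``log10``) differ only by a constant factor, which the change-of-base identity makes explicit (host-side check with the standard library):

.. code-block:: python

    import math

    x = 42.0
    assert math.isclose(math.log2(x), math.log(x) / math.log(2.0))
    assert math.isclose(math.log10(x), math.log(x) / math.log(10.0))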
|
||||
|
||||
|
||||
def rsqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise reciprocal square root of the input tensor.
|
||||
|
||||
Computes 1/√x element-wise.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the reciprocal square root of each element
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = rsqrt(y) # Compute 1/√x
|
||||
"""
|
||||
return TensorSSA(math.rsqrt(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
|
||||
|
||||
def sin(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise sine of the input tensor.
|
||||
|
||||
:param a: Input tensor (in radians)
|
||||
:type a: TensorSSA
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the sine of each element
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = sin(y) # Compute sine
|
||||
"""
|
||||
return TensorSSA(math.sin(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
|
||||
|
||||
def sqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise square root of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the square root of each element
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = sqrt(y) # Compute square root
|
||||
"""
|
||||
return TensorSSA(math.sqrt(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
|
||||
|
||||
def tan(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise tangent of the input tensor.
|
||||
|
||||
:param a: Input tensor (in radians)
|
||||
:type a: TensorSSA
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the tangent of each element
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = tan(y) # Compute tangent
|
||||
"""
|
||||
return TensorSSA(math.tan(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
|
||||
|
||||
def tanh(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
"""Compute element-wise hyperbolic tangent of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the hyperbolic tangent of each element
|
||||
:rtype: TensorSSA
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = tanh(y) # Compute hyperbolic tangent
|
||||
"""
|
||||
return TensorSSA(math.tanh(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"acos",
|
||||
"asin",
|
||||
"atan",
|
||||
"atan2",
|
||||
"cos",
|
||||
"erf",
|
||||
"exp2",
|
||||
"log",
|
||||
"log10",
|
||||
"log2",
|
||||
"rsqrt",
|
||||
"sin",
|
||||
"sqrt",
|
||||
"tan",
|
||||
"tanh",
|
||||
]
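Putting a few of these functions together follows the same pattern as the per-function examples above. A minimal kernel-side sketch (``layout`` is a hypothetical CuTe layout; this only mirrors the documented usage pattern and is not a complete kernel):

.. code-block:: python

    frag = cute.make_fragment(layout)  # register-backed tensor (hypothetical layout)
    x = frag.load()                    # load values as a TensorSSA
    y = sqrt(x)                        # element-wise square root
    theta = atan2(y, x)                # element-wise angle of the points (x, y)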
|
||||
26
python/CuTeDSL/cutlass/cute/nvgpu/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from . import warp
|
||||
from . import cpasync
|
||||
from . import warpgroup
|
||||
from . import tcgen05
|
||||
|
||||
from .common import *
|
||||
from .helpers import *
|
||||
|
||||
|
||||
# __all__ is required here for documentation generation
|
||||
__all__ = [
|
||||
"OpError",
|
||||
"MmaUniversalOp",
|
||||
"CopyUniversalOp",
|
||||
]
|
||||
143
python/CuTeDSL/cutlass/cute/nvgpu/common.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Type, Optional
|
||||
|
||||
from cutlass.cutlass_dsl import DSLBaseError
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
|
||||
from .. import core
|
||||
from ..typing import Float16, Float32, Float64, Numeric
|
||||
|
||||
|
||||
class OpError(DSLBaseError):
|
||||
"""
|
||||
An exception class for Op construction errors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, op: core.Op, message: str, suggestion: Optional[str] = None
|
||||
) -> None:
|
||||
if suggestion is None:
|
||||
# Default suggestion
|
||||
suggestion = "Check your Op construction code"
|
||||
super().__init__(
|
||||
message,
|
||||
error_code=f"{op.__class__.__name__} error",
|
||||
suggestion=suggestion,
|
||||
)
|
||||
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# MMA Ops and Traits
|
||||
#
|
||||
####################################################################################################
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaUniversalOp(core.MmaOp):
|
||||
"""
|
||||
The universal MMA Operation.
|
||||
|
||||
This Operation currently expects the A/B operands as well as the accumulator to share the same
|
||||
data types.
|
||||
|
||||
:param abacc_dtype: The data type for the A/B operands and the accumulator
|
||||
:type abacc_dtype: Type[Numeric]
|
||||
"""
|
||||
|
||||
abacc_dtype: Type[Numeric]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.abacc_dtype not in [Float16, Float32, Float64]:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the 'abacc_dtype' Op parameter to be one of Float16, Float32, or Float64",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
"universal MMA Operation using FMA"
|
||||
f"\n A/B/Accumulator data type = {self.abacc_dtype}"
|
||||
)
|
||||
|
||||
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaUniversalTrait":
|
||||
shape_mnk_attr = ir.Attribute.parse(f'#cute.shape<"(1,1,1)">')
|
||||
atom_ty = _cute_nvgpu_ir.UniversalFmaAtomType.get(
|
||||
shape_mnk_attr,
|
||||
self.abacc_dtype.mlir_type,
|
||||
self.abacc_dtype.mlir_type,
|
||||
self.abacc_dtype.mlir_type,
|
||||
)
|
||||
return MmaUniversalTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class MmaUniversalTrait(core.Trait):
|
||||
pass
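A minimal construction sketch for the universal MMA Op (the import paths are assumptions based on the package layout; the dtype must be one of the three types checked in ``__post_init__``):

.. code-block:: python

    from cutlass import cute
    from cutlass.cute.typing import Float32

    op = cute.nvgpu.MmaUniversalOp(Float32)  # A/B/accumulator all Float32
    print(op)  # "universal MMA Operation using FMA ..."
    # Any dtype other than Float16/Float32/Float64 raises OpError at construction.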
|
||||
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# Copy Ops and Traits
|
||||
#
|
||||
####################################################################################################
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyUniversalOp(core.CopyOp):
|
||||
"""
|
||||
The universal Copy Operation.
|
||||
|
||||
When creating a Copy Atom out of this operation, the expected usage pattern is
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
op = cute.nvgpu.CopyUniversalOp()
|
||||
atom = cute.make_copy_atom(op, tensor_dtype, num_bits_per_copy=64)
|
||||
|
||||
- ``tensor_dtype`` is the data type used to build the reference TV Layout (either the source \
|
||||
or the destination TV Layout) in unit of tensor elements and is used for partitioning by \
|
||||
``TiledCopy`` for example
|
||||
- ``num_bits_per_copy`` is a kw argument specifying the number of bits to copy per Atom \
|
||||
execution. This can be larger than the width of the above data type. When not provided, \
|
||||
the compiler will do a best effort at auto-vectorizing.
|
||||
"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "universal Copy Operation"
|
||||
|
||||
def _make_trait(
|
||||
self,
|
||||
copy_internal_type: Type[Numeric],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
**kwargs,
|
||||
) -> "CopyUniversalTrait":
|
||||
num_bits_per_copy = kwargs.get("num_bits_per_copy", 0)
|
||||
if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0):
|
||||
raise ValueError(
|
||||
"expects a 'num_bits_per_copy' kw argument of type int that is non-negative "
|
||||
f"when creating a copy Atom for {self.__class__.__name__}"
|
||||
)
|
||||
ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get(
|
||||
copy_internal_type.mlir_type, num_bits_per_copy
|
||||
)
|
||||
return CopyUniversalTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class CopyUniversalTrait(core.Trait):
|
||||
pass
|
||||
38
python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from .copy import *
|
||||
from .helpers import *
|
||||
|
||||
|
||||
# __all__ is required here for documentation generation
|
||||
__all__ = [
|
||||
#
|
||||
# copy.py
|
||||
#
|
||||
"LoadCacheMode",
|
||||
"CopyG2SOp",
|
||||
"CopyBulkTensorTileG2SOp",
|
||||
"CopyBulkTensorTileG2SMulticastOp",
|
||||
"CopyBulkTensorTileS2GOp",
|
||||
#
|
||||
# helpers.py
|
||||
#
|
||||
"make_tma_tile_atom",
|
||||
"tma_partition",
|
||||
"create_tma_multicast_mask",
|
||||
"prefetch_descriptor",
|
||||
"copy_tensormap",
|
||||
"update_tma_descriptor",
|
||||
"fence_tma_desc_acquire",
|
||||
"cp_fence_tma_desc_release",
|
||||
"fence_tma_desc_release",
|
||||
]
|
||||
366
python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Type
|
||||
|
||||
from cutlass.cutlass_dsl import CuTeDSL, t
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
|
||||
from ...core import CopyOp, Trait
|
||||
from ...typing import Int16, Pointer, Integer, Numeric
|
||||
from ..common import OpError
|
||||
from ..tcgen05.mma import CtaGroup
|
||||
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# Asynchronous copies
|
||||
#
|
||||
####################################################################################################
|
||||
|
||||
|
||||
class LoadCacheMode(enum.Enum):
|
||||
"""
|
||||
An enumeration for the possible cache modes of a non-bulk ``cp.async`` instruction.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#cache-operators>`__.
|
||||
"""
|
||||
|
||||
ALWAYS = _cute_nvgpu_ir.LoadCacheMode.always
|
||||
GLOBAL = _cute_nvgpu_ir.LoadCacheMode.global_
|
||||
STREAMING = _cute_nvgpu_ir.LoadCacheMode.streaming
|
||||
LAST_USE = _cute_nvgpu_ir.LoadCacheMode.last_use
|
||||
NONE = _cute_nvgpu_ir.LoadCacheMode.none
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
def _to_ir(self) -> _cute_nvgpu_ir.LoadCacheMode:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyG2SOp(CopyOp):
|
||||
"""
|
||||
Non-bulk asynchronous GMEM to SMEM Copy Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-non-bulk-copy>`__.
|
||||
"""
|
||||
|
||||
cache_mode: LoadCacheMode = LoadCacheMode.ALWAYS
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = "cp.async GMEM -> SMEM copy Operation"
|
||||
if self.cache_mode != LoadCacheMode.ALWAYS:
|
||||
res += f"\n with cache mode = {self.cache_mode}"
|
||||
return res
|
||||
|
||||
def _make_trait(
|
||||
self,
|
||||
copy_internal_type: Type[t.Numeric],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
**kwargs,
|
||||
) -> "CopyG2STrait":
|
||||
num_bits_per_copy = kwargs.get("num_bits_per_copy", None)
|
||||
if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy <= 0):
|
||||
raise ValueError(
|
||||
"expects a 'num_bits_per_copy' kw argument of type int that is positive "
|
||||
f"when creating a copy Atom for {self.__class__.__name__}"
|
||||
)
|
||||
# Verify that the user provided enum values
|
||||
if not isinstance(self.cache_mode, LoadCacheMode):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'cache_mode' Op parameter to be a LoadCacheMode instance",
|
||||
)
|
||||
ty = _cute_nvgpu_ir.CopyAtomSIMTAsyncCopyType.get(
|
||||
copy_internal_type.mlir_type, self.cache_mode._to_ir(), num_bits_per_copy
|
||||
)
|
||||
return CopyG2STrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class CopyG2STrait(Trait):
|
||||
pass
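The usage pattern mirrors ``CopyUniversalOp`` above: construct the Op, then build an Atom with ``cute.make_copy_atom``. A sketch (import paths are assumptions; ``Float16`` and the 128-bit width are illustrative choices, and this Op requires ``num_bits_per_copy`` to be a positive int):

.. code-block:: python

    from cutlass import cute
    from cutlass.cute.nvgpu import cpasync
    from cutlass.cute.typing import Float16

    op = cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.STREAMING)
    atom = cute.make_copy_atom(op, Float16, num_bits_per_copy=128)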
|
||||
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# Bulk tensor copies a.k.a TMA copies
|
||||
#
|
||||
####################################################################################################
|
||||
|
||||
TMA_MBAR_PTR_FIELD_NAME = "tma_bar"
|
||||
TMA_MASK_FIELD_NAME = "mcast_mask"
|
||||
TMA_DESC_PTR_FIELD_NAME = "tma_descriptor_ptr"
|
||||
|
||||
#
|
||||
# TMA GMEM -> SMEM copies
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyBulkTensorTileG2SOp(CopyOp):
|
||||
"""
|
||||
Bulk tensor asynchronous GMEM to SMEM Copy Operation using the TMA unit.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
|
||||
This Operation uses TMA in the ``.tile`` mode.
|
||||
"""
|
||||
|
||||
cta_group: CtaGroup = CtaGroup.ONE
|
||||
|
||||
admissible_archs = ["sm_90", "sm_90a", "sm_100a"]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not isinstance(self.cta_group, CtaGroup):
|
||||
raise OpError(
|
||||
self, "expects the 'cta_group' parameter to be a CtaGroup instance"
|
||||
)
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90":
|
||||
raise OpError(
|
||||
self,
|
||||
f"CTA group of 2 is tcgen05-specific and is not and is not compatible with {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = "cp.async GMEM -> SMEM bulk tensor copy Operation"
|
||||
if self.cta_group == CtaGroup.TWO:
|
||||
res += f"\n CTA group = 2"
|
||||
return res
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "CopyBulkTensorTileG2SNonExecTrait":
|
||||
raise NotImplementedError(
|
||||
"Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA"
|
||||
)
|
||||
|
||||
def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum:
|
||||
if self.cta_group == CtaGroup.ONE:
|
||||
return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90
|
||||
elif self.cta_group == CtaGroup.TWO:
|
||||
return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm
|
||||
else:
|
||||
assert False, "unrecognized self.cta_group"
|
||||
|
||||
|
||||
class CopyBulkTensorTileG2SNonExecTrait(Trait):
|
||||
# We allow kw args to be dropped so that the user can write common code for non-multicast
|
||||
# and multicast loads.
|
||||
def unpack(
|
||||
self,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
tma_bar_ptr: Optional[Pointer] = None,
|
||||
tma_desc_ptr: Optional[Pointer] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Custom implementation of unpack for non-executable TMAs.
|
||||
|
||||
The non-multicast TMA load requires a `tma_bar_ptr` keyword argument to be provided when
|
||||
using `cute.copy`. Any other kw arguments will be ignored instead of triggering an error.
|
||||
"""
|
||||
if not isinstance(tma_bar_ptr, Pointer):
|
||||
raise ValueError(
|
||||
"expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument"
|
||||
)
|
||||
exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
|
||||
attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_MBAR_PTR_FIELD_NAME}>"
|
||||
attr = ir.Attribute.parse(attr_str)
|
||||
exec_value = _cute_nvgpu_ir.atom_set_value(
|
||||
exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip
|
||||
)
|
||||
if isinstance(tma_desc_ptr, Pointer):
|
||||
attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>"
|
||||
attr = ir.Attribute.parse(attr_str)
|
||||
exec_value = _cute_nvgpu_ir.atom_set_value(
|
||||
exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
|
||||
)
|
||||
return exec_value
|
||||
|
||||
|
||||
#
|
||||
# TMA GMEM -> SMEM multicast copies
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyBulkTensorTileG2SMulticastOp(CopyOp):
|
||||
"""
|
||||
Bulk tensor asynchronous multicast GMEM to SMEM Copy Operation using the TMA unit.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
|
||||
This Operation uses TMA in the ``.tile`` mode.
|
||||
"""
|
||||
|
||||
cta_group: CtaGroup = CtaGroup.ONE
|
||||
|
||||
admissible_archs = ["sm_90", "sm_90a", "sm_100a"]
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.cta_group, CtaGroup):
|
||||
raise OpError(
|
||||
self, "expects the 'cta_group' parameter to be a CtaGroup instance"
|
||||
)
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90":
|
||||
raise OpError(
|
||||
self,
|
||||
f"CTA group of 2 is tcgen05-specific and is not and is not compatible with {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = "cp.async GMEM -> SMEM bulk tensor multicast copy Operation"
|
||||
if self.cta_group == CtaGroup.TWO:
|
||||
res += f"\n CTA group = 2"
|
||||
return res
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "CopyBulkTensorTileG2SMulticastNonExecTrait":
|
||||
raise NotImplementedError(
|
||||
"Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA"
|
||||
)
|
||||
|
||||
def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum:
|
||||
if self.cta_group == CtaGroup.ONE:
|
||||
return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90_multicast
|
||||
elif self.cta_group == CtaGroup.TWO:
|
||||
return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm_multicast
|
||||
else:
|
||||
assert False, "unrecognized self.cta_group"
|
||||
|
||||
|
||||
class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait):
|
||||
def unpack(
|
||||
self,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
tma_bar_ptr: Optional[Pointer] = None,
|
||||
mcast_mask=None,
|
||||
tma_desc_ptr=None,
|
||||
):
|
||||
"""
|
||||
Custom implementation of unpack for non-executable TMAs.
|
||||
|
||||
The multicast TMA load requires a `tma_bar_ptr` and a `mcast_mask` keyword arguments to be
|
||||
provided when using `cute.copy`.
|
||||
"""
|
||||
if not isinstance(tma_bar_ptr, Pointer):
|
||||
raise ValueError(
|
||||
"expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument"
|
||||
)
|
||||
if not isinstance(mcast_mask, Integer):
|
||||
raise ValueError(
|
||||
"expects a multicast mask to be provided via the mcast_mask kw argument"
|
||||
)
|
||||
exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
|
||||
attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<tma_bar>"
|
||||
attr = ir.Attribute.parse(attr_str)
|
||||
exec_value = _cute_nvgpu_ir.atom_set_value(
|
||||
exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip
|
||||
)
|
||||
attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<mcast_mask>"
|
||||
attr = ir.Attribute.parse(attr_str)
|
||||
exec_value = _cute_nvgpu_ir.atom_set_value(
|
||||
exec_value, attr, Int16(mcast_mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
|
||||
)
|
||||
if isinstance(tma_desc_ptr, Pointer):
|
||||
attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>"
|
||||
attr = ir.Attribute.parse(attr_str)
|
||||
exec_value = _cute_nvgpu_ir.atom_set_value(
|
||||
exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
|
||||
)
|
||||
return exec_value
|
||||
|
||||
|
||||
#
|
||||
# TMA SMEM -> GMEM copies
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyBulkTensorTileS2GOp(CopyOp):
|
||||
"""
|
||||
Bulk tensor asynchronous SMEM to GMEM Copy Operation using the TMA unit.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
|
||||
This Operation uses TMA in the ``.tile`` mode.
|
||||
"""
|
||||
|
||||
admissible_archs = ["sm_90", "sm_90a", "sm_100a"]
|
||||
|
||||
def __post_init__(self):
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "cp.async SMEM -> GMEM bulk tensor copy Operation"
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "CopyBulkTensorTileS2GTrait":
|
||||
raise NotImplementedError(
|
||||
"Use cpasync.make_tma_tile_atom to obtain a copy Atom for TMA"
|
||||
)
|
||||
|
||||
|
||||
class CopyBulkTensorTileS2GTrait(Trait):
|
||||
def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None):
|
||||
"""
|
||||
Custom implementation of unpack for non-executable TMAs.
|
||||
"""
|
||||
exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
|
||||
if isinstance(tma_desc_ptr, Pointer):
|
||||
attr_str = (
|
||||
f"#cute_nvgpu.atom_copy_field_tmastore<{TMA_DESC_PTR_FIELD_NAME}>"
|
||||
)
|
||||
attr = ir.Attribute.parse(attr_str)
|
||||
exec_value = _cute_nvgpu_ir.atom_set_value(
|
||||
exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
|
||||
)
|
||||
return exec_value
|
||||
327
python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import Optional, Tuple, Type, Union
|
||||
|
||||
from cutlass.cutlass_dsl import dsl_user_op
|
||||
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir.dialects import llvm
|
||||
|
||||
from ...typing import Coord, Layout, Tensor, Tiler, Pointer, Int16, Numeric, NumericMeta
|
||||
from ... import core
|
||||
from .copy import (
|
||||
CopyBulkTensorTileG2SOp,
|
||||
CopyBulkTensorTileG2SMulticastOp,
|
||||
CopyBulkTensorTileS2GOp,
|
||||
CopyBulkTensorTileG2SNonExecTrait,
|
||||
CopyBulkTensorTileG2SMulticastNonExecTrait,
|
||||
CopyBulkTensorTileS2GTrait,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_tma_tile_atom(
|
||||
op: Union[
|
||||
CopyBulkTensorTileG2SOp,
|
||||
CopyBulkTensorTileG2SMulticastOp,
|
||||
CopyBulkTensorTileS2GOp,
|
||||
],
|
||||
gmem_tensor: Tensor,
|
||||
smem_layout: Layout,
|
||||
cta_tiler: Tiler,
|
||||
num_multicast: int = 1,
|
||||
*,
|
||||
internal_type: Optional[Type[Numeric]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tuple[core.CopyAtom, Tensor]:
|
||||
"""
|
||||
Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from an SMEM
|
||||
buffer with the given Layout.
|
||||
|
||||
Given
|
||||
|
||||
- a GMEM tensor
|
||||
- a SMEM layout
|
||||
- a CTA-level Tiler
|
||||
|
||||
this function figures out the bulk tensor asynchronous copy instruction to use with the maximum
|
||||
"TMA vector length" to copy tiles of the GMEM tensor to/from an SMEM buffer with the provided
|
||||
layout and consistent with the provided Tiler.
|
||||
|
||||
This function returns two results:
|
||||
|
||||
1. the Copy Atom
|
||||
2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates \
|
||||
that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the \
|
||||
associated layout can output coordinates. Otherwise, TMA tensors can be partitioned \
|
||||
similarly to any other CuTe tensors using the algebra.
|
||||
|
||||
:param op: The Copy Operation to construct an Atom for
|
||||
:type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp]
|
||||
:param gmem_tensor: The GMEM tensor involved in the Copy
|
||||
:type gmem_tensor: Tensor
|
||||
:param smem_layout: The SMEM layout to construct the Copy Atom for
|
||||
:type smem_layout: Layout
|
||||
:param cta_tiler: The CTA Tiler to use
|
||||
:type cta_tiler: Tiler
|
||||
:param num_multicast: The multicast factor
|
||||
:type num_multicast: int
|
||||
:param internal_type: An optional parameter for the internal data type to use when the actual data type is not supported by the TMA unit
|
||||
:type internal_type: Type[Numeric]
|
||||
:return: A Copy Atom for this Operation and the associated TMA tensor
|
||||
:rtype: Tuple[core.CopyAtom, Tensor]
|
||||
"""
|
||||
|
||||
if internal_type is not None:
|
||||
if not isinstance(internal_type, NumericMeta):
|
||||
raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
|
||||
internal_type = internal_type.mlir_type
|
||||
|
||||
cta_v_map = core.composition(
|
||||
core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip),
|
||||
cta_tiler,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
if isinstance(op, CopyBulkTensorTileG2SOp):
|
||||
if num_multicast != 1:
|
||||
raise ValueError(
|
||||
f"expects num_multicast to be 1 for non multicast G2S copies, "
|
||||
f"but got {num_multicast}"
|
||||
)
|
||||
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
|
||||
gmem_tensor.value,
|
||||
smem_layout,
|
||||
cta_v_map,
|
||||
op._to_ir(),
|
||||
num_multicast=num_multicast,
|
||||
internal_type=internal_type,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1]
|
||||
elif isinstance(op, CopyBulkTensorTileG2SMulticastOp):
|
||||
if num_multicast < 1:
|
||||
raise ValueError(
|
||||
f"expects num_multicast to be >= 1 for multicast G2S copies, "
|
||||
f"but got {num_multicast}"
|
||||
)
|
||||
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
|
||||
gmem_tensor.value,
|
||||
smem_layout,
|
||||
cta_v_map,
|
||||
op._to_ir(),
|
||||
num_multicast=num_multicast,
|
||||
internal_type=internal_type,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return (
|
||||
core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])),
|
||||
res[1],
|
||||
)
|
||||
elif isinstance(op, CopyBulkTensorTileS2GOp):
|
||||
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_store(
|
||||
gmem_tensor.value,
|
||||
smem_layout,
|
||||
cta_v_map,
|
||||
internal_type=internal_type,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return core.CopyAtom(op, CopyBulkTensorTileS2GTrait(res[0])), res[1]
|
||||
else:
|
||||
raise ValueError(f"expects a bulk tensor (TMA) Copy Op, but got {op}")
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def tma_partition(
|
||||
atom: core.CopyAtom,
|
||||
cta_coord: Coord,
|
||||
cta_layout: Layout,
|
||||
smem_tensor: Tensor,
|
||||
gmem_tensor: Tensor,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Tiles the GMEM and SMEM tensors for the provided TMA Copy Atom.
|
||||
"""
|
||||
cta_coord_val = core._pack_coord(cta_coord, loc=loc, ip=ip)
|
||||
s, d = _cute_nvgpu_ir.atom_tma_partition(
|
||||
atom._trait.value,
|
||||
cta_coord=cta_coord_val,
|
||||
cta_layout=cta_layout,
|
||||
smem_tensor=smem_tensor.value,
|
||||
gmem_tensor=gmem_tensor.value,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return s, d
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def create_tma_multicast_mask(
|
||||
cta_layout_vmnk: Layout,
|
||||
cta_coord_vmnk: Coord,
|
||||
mcast_mode: int,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Int16:
|
||||
"""
|
||||
Computes a multicast mask for a TMA load Copy.
|
||||
|
||||
:param cta_layout_vmnk: The VMNK layout of the cluster
|
||||
:type cta_layout_vmnk: Layout
|
||||
:param cta_coord_vmnk: The VMNK coordinate of the current CTA
|
||||
:type cta_coord_vmnk: Coord
|
||||
:param mcast_mode: The tensor mode in which to multicast
|
||||
:type mcast_mode: int
|
||||
:return: The resulting mask
|
||||
:rtype: Int16
|
||||
"""
|
||||
if core.rank(cta_layout_vmnk) != 4:
|
||||
raise ValueError(
|
||||
f"cta_layout_vmnk must be rank 4, but got {core.pretty_str(cta_layout_vmnk)}"
|
||||
)
|
||||
if core.rank(cta_coord_vmnk) != 4:
|
||||
raise ValueError(
|
||||
f"cta_coord_vmnk must be rank 4, but got {core.pretty_str(cta_coord_vmnk)}"
|
||||
)
|
||||
return core.make_layout_image_mask(
|
||||
cta_layout_vmnk, cta_coord_vmnk, mcast_mode, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def prefetch_descriptor(tma_atom: core.CopyAtom, *, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Prefetches the TMA descriptor associated with the TMA Atom.
|
||||
"""
|
||||
_cute_nvgpu_ir.prefetch_tma_desc(tma_atom._trait.value, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def copy_tensormap(
|
||||
tma_atom: core.CopyAtom, tensormap_ptr: Pointer, *, loc=None, ip=None
|
||||
) -> None:
|
||||
"""
|
||||
Copies the tensormap held by a TMA Copy Atom to the memory location pointed to by the provided
|
||||
pointer.
|
||||
|
||||
:param tma_atom: The TMA Copy Atom
|
||||
:type tma_atom: CopyAtom
|
||||
:param tensormap_ptr: The pointer to the memory location to copy the tensormap to
|
||||
:type tensormap_ptr: Pointer
|
||||
"""
|
||||
_cute_nvgpu_ir.copy_tma_desc(
|
||||
tma_atom._trait.value, tensormap_ptr.value, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def update_tma_descriptor(
|
||||
tma_atom: core.CopyAtom,
|
||||
gmem_tensor: Tensor,
|
||||
tma_desc_ptr: Pointer,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
"""
|
||||
Updates the TMA descriptor in the memory location pointed to by the provided pointer using
|
||||
information from a TMA Copy Atom and the provided GMEM tensor.
|
||||
|
||||
Specifically, the following fields of the TMA descriptor will be updated:
|
||||
|
||||
1. the GMEM tensor base address
|
||||
2. the GMEM tensor shape
|
||||
3. the GMEM tensor stride
|
||||
|
||||
Other fields of the TMA descriptor are left unchanged.
|
||||
|
||||
:param tma_atom: The TMA Copy Atom
|
||||
:type tma_atom: CopyAtom
|
||||
:param gmem_tensor: The GMEM tensor
|
||||
:type gmem_tensor: Tensor
|
||||
:param tma_desc_ptr: The pointer to the memory location of the descriptor to update
:type tma_desc_ptr: Pointer
|
||||
"""
|
||||
_cute_nvgpu_ir.update_tma_desc(
|
||||
tma_atom._trait.value, gmem_tensor.value, tma_desc_ptr.value, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def fence_tma_desc_acquire(
|
||||
tma_desc_ptr: Pointer,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
"""
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
|
||||
"""
|
||||
tma_desc_ptr_i64 = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value()
|
||||
llvm.inline_asm(
|
||||
None,
|
||||
[tma_desc_ptr_i64],
|
||||
"fence.proxy.tensormap::generic.acquire.gpu [$0], 128;",
|
||||
"l",
|
||||
has_side_effects=True,
|
||||
is_align_stack=False,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cp_fence_tma_desc_release(
|
||||
tma_desc_global_ptr: Pointer,
|
||||
tma_desc_shared_ptr: Pointer,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
"""
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-tensormap-cp-fenceproxy>`__.
|
||||
"""
|
||||
tma_desc_global_ptr_i64 = tma_desc_global_ptr.toint(loc=loc, ip=ip).ir_value()
|
||||
tma_desc_shared_ptr_i32 = tma_desc_shared_ptr.toint(loc=loc, ip=ip).ir_value()
|
||||
llvm.inline_asm(
|
||||
None,
|
||||
[tma_desc_global_ptr_i64, tma_desc_shared_ptr_i32],
|
||||
"tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [$0], [$1], 128;",
|
||||
"l,r",
|
||||
has_side_effects=True,
|
||||
is_align_stack=False,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def fence_tma_desc_release(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
|
||||
"""
|
||||
llvm.inline_asm(
|
||||
None,
|
||||
[],
|
||||
"fence.proxy.tensormap::generic.release.gpu;",
|
||||
"",
|
||||
has_side_effects=True,
|
||||
is_align_stack=False,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
)
|
||||
159
python/CuTeDSL/cutlass/cute/nvgpu/helpers.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import Optional, Tuple, Type, Union
|
||||
|
||||
from cutlass.cutlass_dsl import dsl_user_op
|
||||
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
|
||||
from .. import core
|
||||
from ..typing import Shape, Layout, Tensor, Numeric, NumericMeta
|
||||
from ...impl_utils import check_type_in
|
||||
from .cpasync.copy import (
|
||||
CopyBulkTensorTileG2SOp,
|
||||
CopyBulkTensorTileG2SNonExecTrait,
|
||||
CopyBulkTensorTileG2SMulticastOp,
|
||||
CopyBulkTensorTileG2SMulticastNonExecTrait,
|
||||
)
|
||||
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# TMA creation helpers for tcgen05 MMAs
|
||||
#
|
||||
####################################################################################################
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_tma_tile_atom_A(
|
||||
op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
|
||||
gmem_tensor: Tensor,
|
||||
smem_layout: Layout,
|
||||
mma_tiler_mnk: Shape,
|
||||
tiled_mma: core.TiledMma,
|
||||
cluster_shape_vmnk: Shape,
|
||||
*,
|
||||
internal_type: Optional[Type[Numeric]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tuple[core.CopyAtom, Tensor]:
|
||||
if internal_type is not None:
|
||||
if not isinstance(internal_type, NumericMeta):
|
||||
raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
|
||||
internal_type = internal_type.mlir_type
|
||||
check_type_in(
|
||||
op,
|
||||
[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
|
||||
"op",
|
||||
"make_tma_tile_atom_A",
|
||||
)
|
||||
|
||||
ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
|
||||
mma_tiler_mk = (mma_tiler_mnk[0], *mma_tiler_mnk[2:])
|
||||
g_tile = core.composition(ident, mma_tiler_mk, loc=loc, ip=ip)
|
||||
cta_v_map = tiled_mma._thrfrg_A(g_tile)
|
||||
cta_v_map = core.get(cta_v_map, mode=[1])
|
||||
cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile)))
|
||||
|
||||
if isinstance(op, CopyBulkTensorTileG2SOp):
|
||||
num_multicast = 1
|
||||
else:
|
||||
assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
|
||||
# multicast across the N-mode since those would share the same tile of A
|
||||
num_multicast = core.size(cluster_shape_vmnk, mode=[2])
|
||||
|
||||
# res[0] = the IR Value for the non-executable atom instance
|
||||
# res[1] = the IR Value for the associated TMA tensor
|
||||
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
|
||||
gmem_tensor.value,
|
||||
smem_layout,
|
||||
cta_v_map,
|
||||
op._to_ir(),
|
||||
num_multicast=num_multicast,
|
||||
internal_type=internal_type,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
if isinstance(op, CopyBulkTensorTileG2SOp):
|
||||
return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1]
|
||||
else:
|
||||
assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
|
||||
return (
|
||||
core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])),
|
||||
res[1],
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_tma_tile_atom_B(
|
||||
op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
|
||||
gmem_tensor: Tensor,
|
||||
smem_layout: Layout,
|
||||
mma_tiler_mnk: Shape,
|
||||
tiled_mma: core.TiledMma,
|
||||
cluster_shape_vmnk: Shape,
|
||||
*,
|
||||
internal_type: Optional[Type[Numeric]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tuple[core.CopyAtom, Tensor]:
|
||||
if internal_type is not None:
|
||||
if not isinstance(internal_type, NumericMeta):
|
||||
raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
|
||||
internal_type = internal_type.mlir_type
|
||||
check_type_in(
|
||||
op,
|
||||
[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
|
||||
"op",
|
||||
"make_tma_tile_atom_B",
|
||||
)
|
||||
|
||||
ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
|
||||
mma_tiler_nk = (mma_tiler_mnk[1], *mma_tiler_mnk[2:])
|
||||
g_tile = core.composition(ident, mma_tiler_nk, loc=loc, ip=ip)
|
||||
cta_v_map = tiled_mma._thrfrg_B(g_tile)
|
||||
cta_v_map = core.get(cta_v_map, mode=[1])
|
||||
cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile)))
|
||||
|
||||
if isinstance(op, CopyBulkTensorTileG2SOp):
|
||||
num_multicast = 1
|
||||
else:
|
||||
assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
|
||||
# multicast across the M-mode since those would share the same tile of B
|
||||
num_multicast = core.size(cluster_shape_vmnk, mode=[1])
|
||||
|
||||
# res[0] = the IR Value for the non-executable atom instance
|
||||
# res[1] = the IR Value for the associated TMA tensor
|
||||
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
|
||||
gmem_tensor.value,
|
||||
smem_layout,
|
||||
cta_v_map,
|
||||
op._to_ir(),
|
||||
num_multicast=num_multicast,
|
||||
internal_type=internal_type,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
if isinstance(op, CopyBulkTensorTileG2SOp):
|
||||
return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1]
|
||||
else:
|
||||
assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
|
||||
return (
|
||||
core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])),
|
||||
res[1],
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"make_tma_tile_atom_A",
|
||||
"make_tma_tile_atom_B",
|
||||
]
|
||||
57
python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from .copy import *
|
||||
from .mma import *
|
||||
from .helpers import *
|
||||
|
||||
# __all__ is required here for documentation generation
|
||||
__all__ = [
|
||||
#
|
||||
# copy.py
|
||||
#
|
||||
"Repetition",
|
||||
"Pack",
|
||||
"Unpack",
|
||||
"Ld16x64bOp",
|
||||
"Ld16x128bOp",
|
||||
"Ld16x256bOp",
|
||||
"Ld16x32bx2Op",
|
||||
"Ld32x32bOp",
|
||||
"St16x64bOp",
|
||||
"St16x128bOp",
|
||||
"St16x256bOp",
|
||||
"St16x32bx2Op",
|
||||
"St32x32bOp",
|
||||
#
|
||||
# mma.py
|
||||
#
|
||||
"OperandMajorMode",
|
||||
"OperandSource",
|
||||
"CtaGroup",
|
||||
"Field",
|
||||
"MmaTF32Op",
|
||||
"MmaF16BF16Op",
|
||||
"MmaI8Op",
|
||||
"MmaFP8Op",
|
||||
"SmemLayoutAtomKind",
|
||||
#
|
||||
# helpers.py
|
||||
#
|
||||
"make_smem_layout_atom",
|
||||
"tile_to_mma_shape",
|
||||
"commit",
|
||||
"is_tmem_load",
|
||||
"is_tmem_store",
|
||||
"get_tmem_copy_properties",
|
||||
"find_tmem_tensor_col_offset",
|
||||
"make_tmem_copy",
|
||||
]
|
||||
465
python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py
Normal file
@@ -0,0 +1,465 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Type
|
||||
|
||||
from cutlass.cutlass_dsl import CuTeDSL
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
|
||||
from ..common import OpError
|
||||
from ...core import CopyOp, Trait
|
||||
from ...typing import Numeric
|
||||
|
||||
|
||||
class Repetition(enum.Enum):
|
||||
"""
|
||||
An enumeration for the number of repetitions of a given TMEM copy within the instruction.
|
||||
"""
|
||||
|
||||
x1 = 1
|
||||
x2 = 2
|
||||
x4 = 4
|
||||
x8 = 8
|
||||
x16 = 16
|
||||
x32 = 32
|
||||
x64 = 64
|
||||
x128 = 128
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if isinstance(value, int):
|
||||
if value == 1:
|
||||
return Repetition.x1
|
||||
elif value == 2:
return Repetition.x2
elif value == 4:
return Repetition.x4
elif value == 8:
|
||||
return Repetition.x8
|
||||
elif value == 16:
|
||||
return Repetition.x16
|
||||
elif value == 32:
|
||||
return Repetition.x32
|
||||
elif value == 64:
|
||||
return Repetition.x64
|
||||
elif value == 128:
|
||||
return Repetition.x128
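A repetition count can therefore be given either as the enum member or as its integer value (a small illustration):

.. code-block:: python

    assert Repetition(32) is Repetition.x32
    assert Repetition.x4.value == 4
    # Values that are not among the defined repetitions raise ValueError:
    # Repetition(3)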
|
||||
|
||||
|
||||
class Pack(enum.Enum):
|
||||
"""
|
||||
An enumeration for the possible packing patterns for TMEM to RMEM copies.
|
||||
"""
|
||||
|
||||
NONE = enum.auto()
|
||||
PACK_16b_IN_32b = enum.auto()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
|
||||
class Unpack(enum.Enum):
|
||||
"""
|
||||
An enumeration for the possible unpacking patterns for RMEM to TMEM copies.
|
||||
"""
|
||||
|
||||
NONE = enum.auto()
|
||||
UNPACK_32b_IN_16b = enum.auto()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _LdBase(CopyOp):
|
||||
repeat: Repetition = Repetition.x1
|
||||
pack: Pack = Pack.NONE
|
||||
|
||||
admissible_archs = ["sm_100a"]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
|
||||
if not isinstance(self.repeat, Repetition):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'repeat' Op parameter to be a tcgen05.Repetition instance",
|
||||
)
|
||||
if not isinstance(self.pack, Pack):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'pack' Op parameter to be a tcgen05.Pack instance",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = (
|
||||
f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation"
|
||||
+ f"\n number of repetitions = {self.repeat.value}"
|
||||
)
|
||||
if self.pack == Pack.PACK_16b_IN_32b:
|
||||
res += f"\n with 2x 16-bit to 32b packing"
|
||||
return res
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Ld16x64bOp(_LdBase):
|
||||
"""
|
||||
16x64b TMEM load Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
|
||||
This Operation corresponds to the ``.16x64b`` qualifier.
|
||||
"""
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "Ld16x64bTrait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
16,
|
||||
64,
|
||||
self.repeat.value,
|
||||
ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
|
||||
)
|
||||
return Ld16x64bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class Ld16x64bTrait(Trait):
|
||||
pass
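Constructing one of these TMEM load Ops just fills in the ``_LdBase`` dataclass fields (a sketch; the import path is an assumption and the arch check in ``__post_init__`` requires the DSL arch to resolve to ``sm_100a``):

.. code-block:: python

    from cutlass.cute.nvgpu import tcgen05

    op = tcgen05.Ld16x64bOp(repeat=tcgen05.Repetition.x32, pack=tcgen05.Pack.NONE)
    print(op)  # "tcgen05 Ld16x64b Copy Operation ..."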
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Ld16x128bOp(_LdBase):
|
||||
"""
|
||||
16x128b TMEM load Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
|
||||
This Operation corresponds to the ``.16x128b`` qualifier.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.repeat == Repetition.x128:
|
||||
raise OpError(
|
||||
self,
|
||||
"x128 repetition is not supported",
|
||||
suggestion="choose one of x1, x2, x4, x8, x16, x32, x64",
|
||||
)
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "Ld16x128bTrait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
16,
|
||||
128,
|
||||
self.repeat.value,
|
||||
ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
|
||||
)
|
||||
return Ld16x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class Ld16x128bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Ld16x256bOp(_LdBase):
|
||||
"""
|
||||
16x256b TMEM load Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
|
||||
This Operation corresponds to the ``.16x256b`` qualifier.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.repeat in (Repetition.x128, Repetition.x64):
|
||||
raise OpError(
|
||||
self,
|
||||
"x64 and x128 repetition is not supported",
|
||||
suggestion="choose one of x1, x2, x4, x8, x16, x32",
|
||||
)
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "Ld16x256bTrait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
16,
|
||||
256,
|
||||
self.repeat.value,
|
||||
ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
|
||||
)
|
||||
return Ld16x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class Ld16x256bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Ld16x32bx2Op(_LdBase):
|
||||
"""
|
||||
16x32bx2 TMEM load Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
|
||||
This Operation corresponds to the ``.16x32bx2`` qualifier.
|
||||
"""
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "Ld16x32bx2Trait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
16,
|
||||
32,
|
||||
self.repeat.value,
|
||||
ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
|
||||
)
|
||||
return Ld16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class Ld16x32bx2Trait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Ld32x32bOp(_LdBase):
|
||||
"""
|
||||
32x32b TMEM load Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
|
||||
This Operation corresponds to the ``.32x32b`` qualifier.
|
||||
"""
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "Ld32x32bTrait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
32,
|
||||
32,
|
||||
self.repeat.value,
|
||||
ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
|
||||
)
|
||||
return Ld32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class Ld32x32bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _StBase(CopyOp):
|
||||
repeat: Repetition
|
||||
unpack: Unpack = Unpack.NONE
|
||||
|
||||
admissible_archs = ["sm_100a"]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
|
||||
if not isinstance(self.repeat, Repetition):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'repeat' Op parameter to be a tcgen05.Repetition instance",
|
||||
)
|
||||
if not isinstance(self.unpack, Unpack):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'pack' Op parameter to be a tcgen05.Unpack instance",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = (
|
||||
f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation"
|
||||
+ f"\n number of repetitions = {self.repeat.value}"
|
||||
)
|
||||
if self.unpack == Unpack.UNPACK_32b_IN_16b:
|
||||
res += f"\n with 32-bit to 2x 16b unpacking"
|
||||
return res
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class St16x64bOp(_StBase):
|
||||
"""
|
||||
16x64b TMEM store Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
|
||||
This Operation corresponds to the ``.16x64b`` qualifier.
|
||||
"""
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "St16x64bTrait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
16,
|
||||
64,
|
||||
self.repeat.value,
|
||||
ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
|
||||
)
|
||||
return St16x64bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class St16x64bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class St16x128bOp(_StBase):
|
||||
"""
|
||||
16x128b TMEM store Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
|
||||
This Operation corresponds to the ``.16x128b`` qualifier.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.repeat == Repetition.x128:
|
||||
raise OpError(
|
||||
self,
|
||||
"x128 repetition is not supported",
|
||||
suggestion="choose one of x1, x2, x4, x8, x16, x32, x64",
|
||||
)
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "St16x128bTrait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
16,
|
||||
128,
|
||||
self.repeat.value,
|
||||
ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
|
||||
)
|
||||
return St16x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class St16x128bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class St16x256bOp(_StBase):
|
||||
"""
|
||||
16x256b TMEM store Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
|
||||
This Operation corresponds to the ``.16x256b`` qualifier.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.repeat in (Repetition.x128, Repetition.x64):
|
||||
raise OpError(
|
||||
self,
|
||||
"x64 and x128 repetition is not supported",
|
||||
suggestion="choose one of x1, x2, x4, x8, x16, x32",
|
||||
)
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "St16x256bTrait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
16,
|
||||
256,
|
||||
self.repeat.value,
|
||||
ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
|
||||
)
|
||||
return St16x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class St16x256bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class St16x32bx2Op(_StBase):
|
||||
"""
|
||||
16x32bx2 TMEM store Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
|
||||
This Operation corresponds to the ``.16x32bx2`` qualifier.
|
||||
"""
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "St16x32bx2Trait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
16,
|
||||
32,
|
||||
self.repeat.value,
|
||||
ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
|
||||
)
|
||||
return St16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class St16x32bx2Trait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class St32x32bOp(_StBase):
|
||||
"""
|
||||
32x32b TMEM store Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
|
||||
This Operation corresponds to the ``.32x32b`` qualifier.
|
||||
"""
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "St32x32bTrait":
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
32,
|
||||
32,
|
||||
self.repeat.value,
|
||||
ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
|
||||
)
|
||||
return St32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class St32x32bTrait(Trait):
|
||||
pass
|
||||
301 python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py Normal file
@@ -0,0 +1,301 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import overload, Type, Tuple, Union
|
||||
|
||||
from cutlass.cutlass_dsl import dsl_user_op
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
|
||||
from ...typing import (
|
||||
Shape,
|
||||
IntTuple,
|
||||
Layout,
|
||||
Tensor,
|
||||
Int,
|
||||
Numeric,
|
||||
NumericMeta,
|
||||
Int16,
|
||||
Int32,
|
||||
)
|
||||
from ... import core
|
||||
from .mma import SmemLayoutAtomKind, CtaGroup
|
||||
from .copy import (
|
||||
Pack,
|
||||
Unpack,
|
||||
Ld16x64bOp,
|
||||
Ld16x128bOp,
|
||||
Ld16x256bOp,
|
||||
Ld16x32bx2Op,
|
||||
Ld32x32bOp,
|
||||
St16x64bOp,
|
||||
St16x128bOp,
|
||||
St16x256bOp,
|
||||
St16x32bx2Op,
|
||||
St32x32bOp,
|
||||
)
|
||||
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# Helper functions for MMA
|
||||
#
|
||||
####################################################################################################
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_smem_layout_atom(
|
||||
kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None
|
||||
) -> core.ComposedLayout:
|
||||
"""
|
||||
Makes a SMEM layout Atom.
|
||||
|
||||
This function creates a composed layout in units of elements consistent with the requested layout
|
||||
Atom kind and element data type.
|
||||
|
||||
:param kind: The kind of layout Atom
|
||||
:type kind: SmemLayoutAtomKind
|
||||
:param element_type: The element data type to construct the layout for
|
||||
:type element_type: Type[Numeric]
|
||||
:return: The SMEM layout atom
|
||||
:rtype: core.ComposedLayout
|
||||
"""
|
||||
if not isinstance(element_type, NumericMeta):
|
||||
raise TypeError(f"element_type must be a Numeric, but got {element_type}")
|
||||
|
||||
if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER):
|
||||
num_contiguous_bits = 128
|
||||
sw = core.make_swizzle(0, 4, 3)
|
||||
elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32):
|
||||
num_contiguous_bits = 256
|
||||
sw = core.make_swizzle(1, 4, 3)
|
||||
elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64):
|
||||
num_contiguous_bits = 512
|
||||
sw = core.make_swizzle(2, 4, 3)
|
||||
elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128):
|
||||
num_contiguous_bits = 1024
|
||||
sw = core.make_swizzle(3, 4, 3)
|
||||
elif kind == SmemLayoutAtomKind.MN_SW128_32B:
|
||||
num_contiguous_bits = 1024
|
||||
sw = core.make_swizzle(2, 5, 2)
|
||||
else:
|
||||
raise ValueError("unrecognized SMEM layout atom kind")
|
||||
num_contiguous_elems = num_contiguous_bits // element_type.width
|
||||
|
||||
if kind in (
|
||||
SmemLayoutAtomKind.MN_INTER,
|
||||
SmemLayoutAtomKind.MN_SW32,
|
||||
SmemLayoutAtomKind.MN_SW64,
|
||||
SmemLayoutAtomKind.MN_SW128,
|
||||
SmemLayoutAtomKind.MN_SW128_32B,
|
||||
):
|
||||
# M/N-major layout
|
||||
return core.make_composed_layout(
|
||||
sw,
|
||||
0,
|
||||
core.make_layout(
|
||||
(num_contiguous_elems, 8), stride=(1, num_contiguous_elems)
|
||||
),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
else:
|
||||
# K-major layout
|
||||
return core.make_composed_layout(
|
||||
sw,
|
||||
0,
|
||||
core.make_layout(
|
||||
(8, num_contiguous_elems), stride=(num_contiguous_elems, 1)
|
||||
),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
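# Illustrative usage sketch (not part of the original file; the availability of
# Float16 in this scope is an assumption): a K-major, 128B-swizzled SMEM layout
# atom for 16-bit operands.
#
#   atom = make_smem_layout_atom(SmemLayoutAtomKind.K_SW128, Float16)
#
# For K_SW128 and a 16-bit element type, 1024 contiguous bits correspond to 64
# elements, so the inner layout is (8, 64):(64, 1) composed with a
# Swizzle<3, 4, 3>.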
|
||||
|
||||
|
||||
@overload
|
||||
def tile_to_mma_shape(
|
||||
atom: Layout, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None
|
||||
) -> Layout: ...
|
||||
|
||||
|
||||
@overload
|
||||
def tile_to_mma_shape(
|
||||
atom: core.ComposedLayout,
|
||||
mma_tile_shape: Shape,
|
||||
order: IntTuple = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> core.ComposedLayout: ...
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def tile_to_mma_shape(
|
||||
atom, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None
|
||||
):
|
||||
"""
|
||||
Tiles a layout to an MMA shape.
|
||||
"""
|
||||
# Default order is colexicographical
|
||||
if order is None:
|
||||
order = tuple(range(core.rank(mma_tile_shape) - 1))
|
||||
if core.rank(order) != core.rank(mma_tile_shape) - 1:
|
||||
raise ValueError(
|
||||
f"rank(order)={core.rank(order)} must be equal to "
|
||||
f"rank(mma_tile_shape)-1={core.rank(mma_tile_shape)-1}"
|
||||
)
|
||||
order_val = core._pack_int_tuple(order, loc=loc, ip=ip)
|
||||
mma_tile_shape_val = core._pack_shape(mma_tile_shape, loc=loc, ip=ip)
|
||||
|
||||
if not (
|
||||
core.is_static(atom)
|
||||
and core.is_static(mma_tile_shape_val)
|
||||
and core.is_static(order_val)
|
||||
):
|
||||
raise ValueError("tile_to_mma_shape only supports static inputs")
|
||||
|
||||
res_ty = _cute_nvgpu_ir.tile_to_mma_shape(atom, mma_tile_shape_val, order_val)
|
||||
return _cute_ir.static(res_ty, loc=loc, ip=ip)
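# Illustrative usage sketch (names and tile sizes are assumptions): tiling the
# SMEM layout atom above to the per-CTA tile of an A operand, including a
# staging mode.
#
#   a_smem_layout = tile_to_mma_shape(atom, (128, 64, 2))
#
# With order=None the remaining modes are filled colexicographically, i.e. the
# default order is (0, 1, ..., rank(mma_tile_shape) - 2).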
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def commit(
|
||||
mbar_ptr: core.Pointer,
|
||||
mask=None,
|
||||
cta_group: CtaGroup = CtaGroup.ONE,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
"""
|
||||
Performs an arrive operation on an mbarrier upon completion of previous MMA operations.
|
||||
|
||||
:param mbar_ptr: A pointer to the mbarrier in SMEM
|
||||
:type mbar_ptr: Pointer
|
||||
:param mask: An optional multicast mask for the CTAs in the cluster to signal arrival to
|
||||
:type mask: Int
|
||||
"""
|
||||
if cta_group == CtaGroup.ONE:
|
||||
group = nvvm.Tcgen05GroupKind.CTA_1
|
||||
else:
|
||||
assert cta_group == CtaGroup.TWO
|
||||
group = nvvm.Tcgen05GroupKind.CTA_2
|
||||
|
||||
mbar_ptr = mbar_ptr.llvm_ptr
|
||||
if mask is not None:
|
||||
mask = Int16(mask).ir_value(loc=loc, ip=ip)
|
||||
nvvm.tcgen05_commit_arrive(
|
||||
mbar_ptr, multicast_mask=mask, group=group, loc=loc, ip=ip
|
||||
)
|
||||
else:
|
||||
nvvm.tcgen05_commit_arrive(mbar_ptr, group=group, loc=loc, ip=ip)
|
||||
return
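# Illustrative usage sketch (the mbarrier pointer is an assumption): signal an
# SMEM mbarrier once previously issued tcgen05 MMA operations have completed,
# using a 2-CTA group and, optionally, a multicast mask.
#
#   commit(mbar_ptr, cta_group=CtaGroup.TWO)
#   commit(mbar_ptr, mask=0b11, cta_group=CtaGroup.TWO)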
|
||||
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# Helper functions for Copies
|
||||
#
|
||||
####################################################################################################
|
||||
|
||||
|
||||
def is_tmem_load(atom: core.CopyAtom) -> bool:
|
||||
"""
|
||||
Returns whether a CopyAtom instance is a TMEM load.
|
||||
"""
|
||||
return isinstance(
|
||||
atom.op,
|
||||
(
|
||||
Ld16x64bOp,
|
||||
Ld16x128bOp,
|
||||
Ld16x256bOp,
|
||||
Ld16x32bx2Op,
|
||||
Ld32x32bOp,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def is_tmem_store(atom: core.CopyAtom) -> bool:
|
||||
"""
|
||||
Returns whether a CopyAtom instance is a TMEM store.
|
||||
"""
|
||||
return isinstance(
|
||||
atom.op,
|
||||
(
|
||||
St16x64bOp,
|
||||
St16x128bOp,
|
||||
St16x256bOp,
|
||||
St16x32bx2Op,
|
||||
St32x32bOp,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_tmem_copy_properties(
|
||||
atom: core.CopyAtom,
|
||||
) -> Tuple[int, int, int, Union[Pack, Unpack]]:
|
||||
"""
|
||||
Returns the properties of a TMEM copy atom (number of data paths, bits, repetitions,
|
||||
and whether packing/unpacking is used).
|
||||
"""
|
||||
if isinstance(atom.op, (Ld16x64bOp, St16x64bOp)):
|
||||
num_dp, num_bits = 16, 64
|
||||
elif isinstance(atom.op, (Ld16x128bOp, St16x128bOp)):
|
||||
num_dp, num_bits = 16, 128
|
||||
elif isinstance(atom.op, (Ld16x256bOp, St16x256bOp)):
|
||||
num_dp, num_bits = 16, 256
|
||||
elif isinstance(atom.op, (Ld16x32bx2Op, St16x32bx2Op)):
|
||||
num_dp, num_bits = 16, 32
|
||||
elif isinstance(atom.op, (Ld32x32bOp, St32x32bOp)):
|
||||
num_dp, num_bits = 32, 32
|
||||
else:
|
||||
raise ValueError(f"expects 'atom' to be a TMEM copy, but got {atom}")
|
||||
if is_tmem_load(atom):
|
||||
return num_dp, num_bits, atom.op.repeat.value, atom.op.pack
|
||||
else:
|
||||
assert is_tmem_store(atom), "atom must be a TMEM store"
|
||||
return num_dp, num_bits, atom.op.repeat.value, atom.op.unpack
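# Illustrative usage sketch: querying a TMEM copy atom before partitioning.
# For an atom built from Ld32x32bOp(Repetition.x32) this returns 32 data paths,
# 32 bits per data path, 32 repetitions, and the atom's packing mode.
#
#   if is_tmem_load(atom):
#       num_dp, num_bits, reps, pack = get_tmem_copy_properties(atom)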
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def find_tmem_tensor_col_offset(tmem_tensor: Tensor, *, loc=None, ip=None) -> Int:
|
||||
"""
|
||||
Computes the TMEM column offset given a TMEM tensor.
|
||||
|
||||
:param tmem_tensor: The TMEM tensor to use to compute the columns offset
|
||||
:type tmem_tensor: Tensor
|
||||
:return: The columns offset
|
||||
:rtype: Int
|
||||
"""
|
||||
tmem_col_mask = 0x0000FFFF
|
||||
offset = (
|
||||
core.cosize(core.recast_tensor(tmem_tensor, Int32).layout, loc=loc, ip=ip)
|
||||
& tmem_col_mask
|
||||
)
|
||||
if isinstance(offset, int):
|
||||
return offset
|
||||
return Int32(offset, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_tmem_copy(
|
||||
atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None
|
||||
) -> core.TiledCopy:
|
||||
"""
|
||||
Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor.
|
||||
"""
|
||||
tiled_copy_val = _cute_nvgpu_ir.atom_make_tmem_copy(
|
||||
atom._trait.value, tmem_tensor.value, loc=loc, ip=ip
|
||||
)
|
||||
new_trait = type(atom._trait)(tiled_copy_val)
|
||||
return core.TiledCopy(atom.op, new_trait)
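# Illustrative usage sketch (make_copy_atom and the accumulator tensor are
# assumptions, not defined in this file): building a tiled copy that reads a
# TMEM-resident accumulator back to registers.
#
#   op = Ld32x32bOp(Repetition.x32)
#   atom = core.make_copy_atom(op, Float32)
#   tiled_copy = make_tmem_copy(atom, tmem_acc_tensor)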
|
||||
603 python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py Normal file
@@ -0,0 +1,603 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Type
|
||||
|
||||
from cutlass.cutlass_dsl import CuTeDSL, T
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
|
||||
from ..common import OpError
|
||||
from ...core import MmaOp, Trait, _pack_shape, rank, depth
|
||||
from ...typing import (
|
||||
Shape,
|
||||
Float8E5M2,
|
||||
Float8E4M3FN,
|
||||
Float16,
|
||||
BFloat16,
|
||||
Float32,
|
||||
TFloat32,
|
||||
Boolean,
|
||||
Int8,
|
||||
Uint8,
|
||||
Int32,
|
||||
Numeric,
|
||||
)
|
||||
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# MMA Ops and Traits
|
||||
#
|
||||
####################################################################################################
|
||||
|
||||
|
||||
class OperandMajorMode(enum.Enum):
|
||||
"""
|
||||
An enumeration for the majorness of the input operands of the MMA.
|
||||
"""
|
||||
|
||||
MN = _cute_ir.MajorMode.mn
|
||||
K = _cute_ir.MajorMode.k
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if isinstance(value, str):
|
||||
value = value.upper()
|
||||
if value == "MN":
|
||||
return OperandMajorMode.MN
|
||||
elif value == "K":
|
||||
return OperandMajorMode.K
|
||||
|
||||
def _to_ir(self) -> _cute_ir.MajorMode:
|
||||
return self.value
|
||||
|
||||
|
||||
class OperandSource(enum.Enum):
|
||||
"""
|
||||
An enumeration for the source memory location of the A input operand of the MMA.
|
||||
"""
|
||||
|
||||
TMEM = _cute_ir.MmaFragKind.tmem
|
||||
SMEM = _cute_ir.MmaFragKind.smem_desc
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
def _to_ir(self) -> _cute_ir.MmaFragKind:
|
||||
return self.value
|
||||
|
||||
|
||||
class CtaGroup(enum.Enum):
|
||||
"""
|
||||
An enumeration for the ``cta_group`` qualifier of the MMA.
|
||||
"""
|
||||
|
||||
ONE = 1
|
||||
TWO = 2
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
|
||||
class Field(enum.Enum):
|
||||
"""
|
||||
An enumeration for the fields of the MMA Atom that can be modified at runtime.
|
||||
"""
|
||||
|
||||
NEGATE_A = "neg_a"
|
||||
NEGATE_B = "neg_b"
|
||||
ACCUMULATE = "accum_c"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
def _to_ir_field_name(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
# Base class for all tcgen05 MMA Ops used to factor out some internal code
|
||||
@dataclass(frozen=True)
|
||||
class MmaOp(MmaOp):
|
||||
a_dtype: Type[Numeric]
|
||||
b_dtype: Type[Numeric]
|
||||
acc_dtype: Type[Numeric]
|
||||
shape_mnk: Shape
|
||||
cta_group: CtaGroup
|
||||
a_src: OperandSource
|
||||
a_major_mode: OperandMajorMode
|
||||
b_major_mode: OperandMajorMode
|
||||
|
||||
admissible_archs = ["sm_100a"]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
# Verify that the user provided enum values
|
||||
if not isinstance(self.cta_group, CtaGroup):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance",
|
||||
)
|
||||
if not isinstance(self.a_src, OperandSource):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance",
|
||||
)
|
||||
if not isinstance(self.a_major_mode, OperandMajorMode):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
|
||||
)
|
||||
if not isinstance(self.b_major_mode, OperandMajorMode):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
|
||||
)
|
||||
# Verify the instruction shape
|
||||
if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, "
|
||||
f"but got {self.shape_mnk}",
|
||||
)
|
||||
m, n = self.shape_mnk[0], self.shape_mnk[1]
|
||||
if self.cta_group == CtaGroup.ONE:
|
||||
if m not in [64, 128]:
|
||||
raise OpError(self, f"expects the M-mode to be 64 or 128, but got {m}")
|
||||
if m == 64:
|
||||
if (n < 8) or (n > 256) or (n % 8 != 0):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}",
|
||||
)
|
||||
elif m == 128:
|
||||
if (n < 16) or (n > 256) or (n % 16 != 0):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the N-mode to satisfy 8 <= N <= 256 and N % 16 == 0, but got {n}",
|
||||
)
|
||||
else:
|
||||
if m not in [128, 256]:
|
||||
raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}")
|
||||
if (n < 32) or (n > 256) or (n % 32 != 0):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the N-mode to satisfy 32 <= N <= 256 and N % 32 == 0, but got {n}",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
self.__class__.descriptive_name # type: ignore
|
||||
+ f"\n A data type = {self.a_dtype}"
|
||||
+ f"\n B data type = {self.b_dtype}"
|
||||
+ f"\n Accumulator data type = {self.acc_dtype}"
|
||||
+ f"\n CTA group = {self.cta_group}"
|
||||
+ f"\n A source location = {self.a_src}"
|
||||
+ f"\n A major mode = {self.a_major_mode}"
|
||||
+ f"\n B major mode = {self.b_major_mode}"
|
||||
+ f"\n Instruction shape MNK = {self.shape_mnk}"
|
||||
)
|
||||
|
||||
|
||||
class MmaTrait(Trait):
|
||||
admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]
|
||||
|
||||
def set(self, field, value, *, loc=None, ip=None) -> None:
|
||||
if field not in self.admissible_fields:
|
||||
raise ValueError(
|
||||
f"expects field to be one of {self.admissible_fields}, but got {field}"
|
||||
)
|
||||
field_name = f"#cute_nvgpu.atom_mma_field_sm100<{field._to_ir_field_name()}>"
|
||||
attr = ir.Attribute.parse(field_name)
|
||||
self.value = _cute_nvgpu_ir.atom_set_value(
|
||||
self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
|
||||
)
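# Illustrative usage sketch (whether the MMA Atom forwards set() to its Trait
# is an assumption): the ACCUMULATE field is typically cleared for the first
# MMA of a K-loop and set afterwards so later MMAs accumulate into TMEM.
#
#   mma_atom.set(Field.ACCUMULATE, False)  # first k-block: C  = A * B
#   mma_atom.set(Field.ACCUMULATE, True)   # later k-blocks: C += A * B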
|
||||
|
||||
|
||||
#
|
||||
# TF32 MMA
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaTF32Op(MmaOp):
|
||||
"""
|
||||
TF32 tcgen05 MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
|
||||
This Operation corresponds to the ``.kind::tf32`` qualifier.
|
||||
"""
|
||||
|
||||
descriptive_name = "tcgen05 TF32 MMA Operation"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instruction_shape: Shape,
|
||||
cta_group: CtaGroup,
|
||||
a_src: OperandSource,
|
||||
a_major_mode: OperandMajorMode,
|
||||
b_major_mode: OperandMajorMode,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
TFloat32,
|
||||
TFloat32,
|
||||
Float32,
|
||||
instruction_shape,
|
||||
cta_group,
|
||||
a_src,
|
||||
a_major_mode,
|
||||
b_major_mode,
|
||||
)
|
||||
self._verify()
|
||||
|
||||
def _verify(self) -> None:
|
||||
# Verify the instruction shape
|
||||
instruction_k = 8
|
||||
if rank(self.shape_mnk) == 2:
|
||||
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
|
||||
if self.shape_mnk[2] != instruction_k:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the instruction extent in the K-mode to be {instruction_k}, "
|
||||
f"but got {self.shape_mnk[2]}",
|
||||
)
|
||||
|
||||
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaTF32Trait":
|
||||
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
|
||||
shape_mnk.type.attribute,
|
||||
self.cta_group.value,
|
||||
self.a_major_mode._to_ir(),
|
||||
self.b_major_mode._to_ir(),
|
||||
self.a_dtype.mlir_type,
|
||||
self.b_dtype.mlir_type,
|
||||
self.acc_dtype.mlir_type,
|
||||
self.a_src._to_ir(),
|
||||
0,
|
||||
)
|
||||
return MmaTF32Trait(
|
||||
_cute_nvgpu_ir.make_sm100_mma(
|
||||
ty,
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MmaTF32Trait(MmaTrait):
|
||||
pass
|
||||
|
||||
|
||||
#
|
||||
# F16/BF16 MMA
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaF16BF16Op(MmaOp):
|
||||
"""
|
||||
F16/BF16 tcgen05 MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
|
||||
This Operation corresponds to the ``.kind::f16`` qualifier.
|
||||
"""
|
||||
|
||||
descriptive_name = "tcgen05 F16/BF16 MMA Operation"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ab_dtype: Type[Numeric],
|
||||
acc_dtype: Type[Numeric],
|
||||
instruction_shape: Shape,
|
||||
cta_group: CtaGroup,
|
||||
a_src: OperandSource,
|
||||
a_major_mode: OperandMajorMode,
|
||||
b_major_mode: OperandMajorMode,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
ab_dtype,
|
||||
ab_dtype,
|
||||
acc_dtype,
|
||||
instruction_shape,
|
||||
cta_group,
|
||||
a_src,
|
||||
a_major_mode,
|
||||
b_major_mode,
|
||||
)
|
||||
self._verify()
|
||||
|
||||
def _verify(self) -> None:
|
||||
# Input data type verification
|
||||
if self.a_dtype not in [Float16, BFloat16]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16",
|
||||
)
|
||||
assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
|
||||
# Accumulator data type verification
|
||||
if self.acc_dtype not in [Float16, Float32]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
|
||||
)
|
||||
# Instruction shape verification
|
||||
instruction_k = 16
|
||||
if rank(self.shape_mnk) == 2:
|
||||
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
|
||||
if self.shape_mnk[2] != instruction_k:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the instruction extent in the K-mode to be {instruction_k}, "
|
||||
f"but got {self.shape_mnk[2]}",
|
||||
)
|
||||
|
||||
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait":
|
||||
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
|
||||
shape_mnk.type.attribute,
|
||||
self.cta_group.value,
|
||||
self.a_major_mode._to_ir(),
|
||||
self.b_major_mode._to_ir(),
|
||||
self.a_dtype.mlir_type,
|
||||
self.b_dtype.mlir_type,
|
||||
self.acc_dtype.mlir_type,
|
||||
self.a_src._to_ir(),
|
||||
0,
|
||||
)
|
||||
return MmaF16BF16Trait(
|
||||
_cute_nvgpu_ir.make_sm100_mma(
|
||||
ty,
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MmaF16BF16Trait(MmaTrait):
|
||||
pass
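# Illustrative usage sketch (tile sizes are assumptions): a 128x256x16 F16 MMA
# Op with A sourced from SMEM, both operands K-major, and a 1-CTA group. The
# K extent of 16 is implied when a rank-2 shape is given.
#
#   op = MmaF16BF16Op(
#       Float16,            # ab_dtype
#       Float32,            # acc_dtype
#       (128, 256),         # instruction shape MN
#       CtaGroup.ONE,
#       OperandSource.SMEM,
#       OperandMajorMode.K,
#       OperandMajorMode.K,
#   )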
|
||||
|
||||
|
||||
#
|
||||
# I8 MMA
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaI8Op(MmaOp):
|
||||
"""
|
||||
I8 tcgen05 MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
|
||||
This Operation corresponds to the ``.kind::i8`` qualifier.
|
||||
"""
|
||||
|
||||
descriptive_name = "tcgen05 I8 MMA Operation"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ab_dtype: Type[Numeric],
|
||||
instruction_shape: Shape,
|
||||
cta_group: CtaGroup,
|
||||
a_src: OperandSource,
|
||||
a_major_mode: OperandMajorMode,
|
||||
b_major_mode: OperandMajorMode,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
ab_dtype,
|
||||
ab_dtype,
|
||||
Int32,
|
||||
instruction_shape,
|
||||
cta_group,
|
||||
a_src,
|
||||
a_major_mode,
|
||||
b_major_mode,
|
||||
)
|
||||
self._verify()
|
||||
|
||||
def _verify(self) -> None:
|
||||
# Input data type verification
|
||||
if self.a_dtype not in [Int8, Uint8]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'ab_dtype' Op parameter to be one of Int8 or Uint8",
|
||||
)
|
||||
assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
|
||||
# Instruction shape verification
|
||||
instruction_k = 32
|
||||
if rank(self.shape_mnk) == 2:
|
||||
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
|
||||
if self.shape_mnk[2] != instruction_k:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the instruction extent in the K-mode to be {instruction_k}, "
|
||||
f"but got {self.shape_mnk[2]}",
|
||||
)
|
||||
|
||||
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaI8Trait":
|
||||
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
|
||||
shape_mnk.type.attribute,
|
||||
self.cta_group.value,
|
||||
self.a_major_mode._to_ir(),
|
||||
self.b_major_mode._to_ir(),
|
||||
(T.si8() if self.a_dtype.signed else T.ui8()),
|
||||
(T.si8() if self.b_dtype.signed else T.ui8()),
|
||||
T.si32(),
|
||||
self.a_src._to_ir(),
|
||||
0,
|
||||
)
|
||||
return MmaI8Trait(
|
||||
_cute_nvgpu_ir.make_sm100_mma(
|
||||
ty,
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MmaI8Trait(MmaTrait):
|
||||
pass
|
||||
|
||||
|
||||
#
|
||||
# F8F6F4 MMA
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaFP8Op(MmaOp):
|
||||
"""
|
||||
F8 tcgen05 MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
|
||||
"""
|
||||
|
||||
descriptive_name = "tcgen05 F8 MMA Operation"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ab_dtype: Type[Numeric],
|
||||
acc_dtype: Type[Numeric],
|
||||
instruction_shape: Shape,
|
||||
cta_group: CtaGroup,
|
||||
a_src: OperandSource,
|
||||
a_major_mode: OperandMajorMode,
|
||||
b_major_mode: OperandMajorMode,
|
||||
) -> None:
|
||||
|
||||
super().__init__(
|
||||
ab_dtype,
|
||||
ab_dtype,
|
||||
acc_dtype,
|
||||
instruction_shape,
|
||||
cta_group,
|
||||
a_src,
|
||||
a_major_mode,
|
||||
b_major_mode,
|
||||
)
|
||||
self._verify()
|
||||
|
||||
def _verify(self) -> None:
|
||||
# Input data type verification
|
||||
if self.a_dtype not in [Float8E5M2, Float8E4M3FN]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
|
||||
)
|
||||
assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
|
||||
# Accumulator data type verification
|
||||
if self.acc_dtype not in [Float16, Float32]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
|
||||
)
|
||||
# Instruction shape verification
|
||||
instruction_k = 32
|
||||
if rank(self.shape_mnk) == 2:
|
||||
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
|
||||
if self.shape_mnk[2] != instruction_k:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the instruction extent in the K-mode to be {instruction_k}, "
|
||||
f"but got {self.shape_mnk[2]}",
|
||||
)
|
||||
|
||||
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaFP8Trait":
|
||||
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
|
||||
shape_mnk.type.attribute,
|
||||
self.cta_group.value,
|
||||
self.a_major_mode._to_ir(),
|
||||
self.b_major_mode._to_ir(),
|
||||
self.a_dtype.mlir_type,
|
||||
self.b_dtype.mlir_type,
|
||||
self.acc_dtype.mlir_type,
|
||||
self.a_src._to_ir(),
|
||||
0,
|
||||
)
|
||||
return MmaFP8Trait(
|
||||
_cute_nvgpu_ir.make_sm100_mma(
|
||||
ty,
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MmaFP8Trait(MmaTrait):
|
||||
pass
|
||||
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# SMEM layout atoms
|
||||
#
|
||||
####################################################################################################
|
||||
|
||||
|
||||
class SmemLayoutAtomKind(enum.Enum):
|
||||
"""
|
||||
Enum class for the kinds of SMEM layout atoms for SM100.
|
||||
|
||||
Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can be
|
||||
used to construct an SMEM layout using blocked product for operand A or B such that the
|
||||
resulting layout is legal for both TMA and UMMA.
|
||||
|
||||
Note that there are other ways of creating legal layouts for operand A and B.
|
||||
"""
|
||||
|
||||
MN_INTER = enum.auto()
|
||||
MN_SW32 = enum.auto()
|
||||
MN_SW64 = enum.auto()
|
||||
MN_SW128 = enum.auto()
|
||||
MN_SW128_32B = enum.auto()
|
||||
K_INTER = enum.auto()
|
||||
K_SW32 = enum.auto()
|
||||
K_SW64 = enum.auto()
|
||||
K_SW128 = enum.auto()
|
||||
25 python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py Normal file
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from .copy import *
from .mma import *


# __all__ is required here for documentation generation
__all__ = [
    # mma.py
    "MmaF16BF16Op",
    # copy.py
    "LdMatrix8x8x16bOp",
    "LdMatrix16x16x8bOp",
    "StMatrix8x8x16bOp",
    "StMatrix16x8x8bOp",
]
189 python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py Normal file
@@ -0,0 +1,189 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Type
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
|
||||
from ..common import OpError
|
||||
from ...core import CopyOp, Trait, _pack_shape
|
||||
from ...typing import Numeric
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BaseOp(CopyOp):
|
||||
transpose: bool = False
|
||||
num_matrices: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not isinstance(self.transpose, bool):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'transpose' Op parameter to be a bool instance",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = (
|
||||
f"{self.__class__.__name__[:-2]} Copy Operation"
|
||||
+ f"\n number of matrices = {self.num_matrices}"
|
||||
)
|
||||
if self.transpose:
|
||||
res += f"\n transposed"
|
||||
return res
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LdMatrix8x8x16bOp(BaseOp):
|
||||
"""
|
||||
8x8 ``ldmatrix`` Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-load-instruction-ldmatrix>`__.
|
||||
This operation corresponds to the ``.m8n8`` qualifier.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.num_matrices not in [1, 2, 4]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'num_matrices' Op parameter to be one of [1,2,4]",
|
||||
)
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "LdMatrix8x8x16bTrait":
|
||||
mode = _pack_shape((8, 8), loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.CopyAtomLdsmType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
mode.type.attribute,
|
||||
_cute_nvgpu_ir.LdsmSzPattern.u16,
|
||||
self.num_matrices,
|
||||
ir.UnitAttr.get() if self.transpose else None,
|
||||
)
|
||||
return LdMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class LdMatrix8x8x16bTrait(Trait):
|
||||
pass
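# Illustrative usage sketch: an ldmatrix Op loading four 8x8 tiles of 16-bit
# elements per instruction, non-transposed.
#
#   op = LdMatrix8x8x16bOp(transpose=False, num_matrices=4)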
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LdMatrix16x16x8bOp(BaseOp):
|
||||
"""
|
||||
16x16 8-bit ``ldmatrix`` Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-load-instruction-ldmatrix>`__.
|
||||
This operation corresponds to the ``.m16n16`` and the ``.b8`` qualifiers.
|
||||
"""
|
||||
|
||||
def __init__(self, num_matrices: int) -> None:
|
||||
super().__init__(transpose=True, num_matrices=num_matrices)
|
||||
self._verify()
|
||||
|
||||
def _verify(self):
|
||||
assert self.transpose, "transpose must be True"
|
||||
if self.num_matrices not in [1, 2]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'num_matrices' Op parameter to be one of [1,2]",
|
||||
)
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "LdMatrix16x16x8bTrait":
|
||||
mode = _pack_shape((16, 16), loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.CopyAtomLdsmType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
mode.type.attribute,
|
||||
_cute_nvgpu_ir.LdsmSzPattern.u8,
|
||||
self.num_matrices,
|
||||
ir.UnitAttr.get(),
|
||||
)
|
||||
return LdMatrix16x16x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class LdMatrix16x16x8bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StMatrix8x8x16bOp(BaseOp):
|
||||
"""
|
||||
8x8 ``stmatrix`` Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-stmatrix>`__.
|
||||
This operation corresponds to the ``.m8n8`` qualifier.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.num_matrices not in [1, 2, 4]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'num_matrices' Op parameter to be one of [1,2,4]",
|
||||
)
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "StMatrix8x8x16bTrait":
|
||||
mode = _pack_shape((8, 8), loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.CopyAtomStsmType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
mode.type.attribute,
|
||||
self.num_matrices,
|
||||
ir.UnitAttr.get() if self.transpose else None,
|
||||
)
|
||||
return StMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class StMatrix8x8x16bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StMatrix16x8x8bOp(BaseOp):
|
||||
"""
|
||||
16x8 ``stmatrix`` Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-stmatrix>`__.
|
||||
This operation corresponds to the ``.m16n8`` qualifier.
|
||||
"""
|
||||
|
||||
def __init__(self, num_matrices: int) -> None:
|
||||
super().__init__(transpose=True, num_matrices=num_matrices)
|
||||
self._verify()
|
||||
|
||||
def _verify(self):
    assert self.transpose, "transpose must be True"
    if self.num_matrices not in [1, 2, 4]:
        raise OpError(
            self,
            "expects the 'num_matrices' Op parameter to be one of [1,2,4]",
        )
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "StMatrix16x8x8bTrait":
|
||||
mode = _pack_shape((16, 8), loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.CopyAtomStsmType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
mode.type.attribute,
|
||||
self.num_matrices,
|
||||
ir.UnitAttr.get(),
|
||||
)
|
||||
return StMatrix16x8x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class StMatrix16x8x8bTrait(Trait):
|
||||
pass
|
||||
78 python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py Normal file
@@ -0,0 +1,78 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Type
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
|
||||
from ..common import OpError
|
||||
from ...core import MmaOp, Trait, _pack_shape
|
||||
from ...typing import Shape, Float16, BFloat16, Float32, Numeric
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaF16BF16Op(MmaOp):
|
||||
"""
|
||||
F16/BF16 warp-level MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma>`__.
|
||||
This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands.
|
||||
"""
|
||||
|
||||
ab_dtype: Type[Numeric]
|
||||
acc_dtype: Type[Numeric]
|
||||
shape_mnk: Shape
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.ab_dtype not in [Float16, BFloat16]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16",
|
||||
)
|
||||
if self.acc_dtype not in [Float16, Float32]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
|
||||
)
|
||||
if (self.ab_dtype == BFloat16) and (self.acc_dtype != Float32):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16",
|
||||
)
|
||||
if self.shape_mnk not in [(16, 8, 8), (16, 8, 16)]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'shape_mnk' Op parameter to be one of (16,8,8) or (16,8,16)",
|
||||
)
|
||||
|
||||
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait":
|
||||
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.MmaAtomSM80Type.get(
|
||||
shape_mnk.type.attribute,
|
||||
self.ab_dtype.mlir_type,
|
||||
self.ab_dtype.mlir_type,
|
||||
self.acc_dtype.mlir_type,
|
||||
)
|
||||
return MmaF16BF16Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
"warp-level F16/BF16 MMA Operation"
|
||||
+ f"\n A/B data type = {self.ab_dtype}"
|
||||
+ f"\n Accumulator data type = {self.acc_dtype}"
|
||||
+ f"\n Instruction shape MNK = {self.shape_mnk}"
|
||||
)
|
||||
|
||||
|
||||
class MmaF16BF16Trait(Trait):
|
||||
pass
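# Illustrative usage sketch: the SM80 16x8x16 F16 MMA Op with an F32
# accumulator.
#
#   op = MmaF16BF16Op(Float16, Float32, (16, 8, 16))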
|
||||
29 python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py Normal file
@@ -0,0 +1,29 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from .mma import *
from .helpers import *

# __all__ is required here for documentation generation
__all__ = [
    # mma.py
    "OperandMajorMode",
    "OperandSource",
    "Field",
    "MmaF16BF16Op",
    "MmaF8Op",
    "SmemLayoutAtomKind",
    # helpers.py
    "make_smem_layout_atom",
    "fence",
    "commit_group",
    "wait_group",
]
109 python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py Normal file
@@ -0,0 +1,109 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import Type
|
||||
|
||||
from cutlass.cutlass_dsl import dsl_user_op
|
||||
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
|
||||
from ...typing import Numeric, NumericMeta
|
||||
from ... import core
|
||||
from .mma import SmemLayoutAtomKind
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_smem_layout_atom(
|
||||
kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None
|
||||
) -> core.ComposedLayout:
|
||||
"""
|
||||
Makes a SMEM layout Atom.
|
||||
|
||||
This function creates a composed layout in units of elements consistent with the requested layout
|
||||
Atom kind and element data type.
|
||||
|
||||
:param kind: The kind of layout Atom
|
||||
:type kind: SmemLayoutAtomKind
|
||||
:param element_type: The element data type to construct the layout for
|
||||
:type element_type: Type[Numeric]
|
||||
:return: The SMEM layout atom
|
||||
:rtype: core.ComposedLayout
|
||||
"""
|
||||
if not isinstance(element_type, NumericMeta):
|
||||
raise TypeError(f"element_type must be a Numeric, but got {element_type}")
|
||||
|
||||
if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER):
|
||||
num_contiguous_bits = 128
|
||||
sw = core.make_swizzle(0, 4, 3)
|
||||
elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32):
|
||||
num_contiguous_bits = 256
|
||||
sw = core.make_swizzle(1, 4, 3)
|
||||
elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64):
|
||||
num_contiguous_bits = 512
|
||||
sw = core.make_swizzle(2, 4, 3)
|
||||
elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128):
|
||||
num_contiguous_bits = 1024
|
||||
sw = core.make_swizzle(3, 4, 3)
|
||||
else:
|
||||
raise ValueError("unrecognized SMEM layout atom kind")
|
||||
num_contiguous_elems = num_contiguous_bits // element_type.width
|
||||
|
||||
if kind in (
|
||||
SmemLayoutAtomKind.MN_INTER,
|
||||
SmemLayoutAtomKind.MN_SW32,
|
||||
SmemLayoutAtomKind.MN_SW64,
|
||||
SmemLayoutAtomKind.MN_SW128,
|
||||
):
|
||||
# M/N-major layout
|
||||
return core.make_composed_layout(
|
||||
sw,
|
||||
0,
|
||||
core.make_layout(
|
||||
(num_contiguous_elems, 8), stride=(1, num_contiguous_elems)
|
||||
),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
else:
|
||||
# K-major layout
|
||||
return core.make_composed_layout(
|
||||
sw,
|
||||
0,
|
||||
core.make_layout(
|
||||
(8, num_contiguous_elems), stride=(num_contiguous_elems, 1)
|
||||
),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)


@dsl_user_op
def fence(*, loc=None, ip=None) -> None:
    """
    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-fence>`__.
    """
    nvvm.wgmma_fence_aligned(loc=loc, ip=ip)


@dsl_user_op
def commit_group(*, loc=None, ip=None) -> None:
    """
    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group>`__.
    """
    nvvm.wgmma_commit_group_sync_aligned(loc=loc, ip=ip)


@dsl_user_op
def wait_group(group, *, loc=None, ip=None) -> None:
    """
    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-wait-group>`__.
    """
    nvvm.wgmma_wait_group_sync_aligned(group, loc=loc, ip=ip)
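# Illustrative usage sketch (the MMA issue step is elided): the canonical
# synchronization pattern around a batch of warpgroup MMA operations.
#
#   fence()          # make prior register/SMEM writes visible to wgmma
#   ...              # issue the warpgroup MMAs
#   commit_group()   # close the current wgmma group
#   wait_group(0)    # wait until all committed groups have completed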
380 python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py Normal file
@@ -0,0 +1,380 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Type
|
||||
|
||||
from cutlass.cutlass_dsl import CuTeDSL
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
|
||||
from ..common import OpError
|
||||
from ...core import MmaOp, Trait, _pack_shape, rank, depth
|
||||
from ...typing import (
|
||||
Shape,
|
||||
Float16,
|
||||
BFloat16,
|
||||
Float32,
|
||||
Boolean,
|
||||
Float8E5M2,
|
||||
Float8E4M3FN,
|
||||
Numeric,
|
||||
)
|
||||
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# MMA Ops and Traits
|
||||
#
|
||||
####################################################################################################
|
||||
|
||||
|
||||
class OperandMajorMode(enum.Enum):
|
||||
"""
|
||||
An enumeration for the majorness of the input operands of the MMA.
|
||||
"""
|
||||
|
||||
MN = _cute_ir.MajorMode.mn
|
||||
K = _cute_ir.MajorMode.k
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if isinstance(value, str):
|
||||
value = value.upper()
|
||||
if value == "MN":
|
||||
return OperandMajorMode.MN
|
||||
elif value == "K":
|
||||
return OperandMajorMode.K
|
||||
|
||||
def _to_ir(self) -> _cute_ir.MajorMode:
|
||||
return self.value
|
||||
|
||||
|
||||
class OperandSource(enum.Enum):
|
||||
"""
|
||||
An enumeration for the source memory location of the A input operand of the MMA.
|
||||
"""
|
||||
|
||||
RMEM = _cute_ir.MmaFragKind.rmem
|
||||
SMEM = _cute_ir.MmaFragKind.smem_desc
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
def _to_ir(self) -> _cute_ir.MmaFragKind:
|
||||
return self.value
|
||||
|
||||
|
||||
class Field(enum.Enum):
|
||||
"""
|
||||
An enumeration for the fields of the MMA Atom that can be modified at runtime.
|
||||
"""
|
||||
|
||||
ACCUMULATE = "accum_c"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
def _to_ir_field_name(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaOp(MmaOp):
|
||||
a_dtype: Type[Numeric]
|
||||
b_dtype: Type[Numeric]
|
||||
acc_dtype: Type[Numeric]
|
||||
shape_mnk: Shape
|
||||
a_src: OperandSource
|
||||
a_major_mode: OperandMajorMode
|
||||
b_major_mode: OperandMajorMode
|
||||
|
||||
admissible_archs = ["sm_90a"]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
# Verify that the user provided enum values
|
||||
if not isinstance(self.a_src, OperandSource):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'a_src' Op parameter to be a warpgroup.OperandSource instance",
|
||||
)
|
||||
if not isinstance(self.a_major_mode, OperandMajorMode):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'a_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance",
|
||||
)
|
||||
if not isinstance(self.b_major_mode, OperandMajorMode):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'b_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance",
|
||||
)
|
||||
# Verify instruction shape
|
||||
if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, "
|
||||
f"but got {self.shape_mnk}",
|
||||
)
|
||||
m, n = self.shape_mnk[0], self.shape_mnk[1]
|
||||
if m != 64:
|
||||
raise OpError(self, f"expects the M-mode to be 64, but got {m}")
|
||||
if (n < 8) or (n > 256) or (n % 8 != 0):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0. but got {n}",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
self.__class__.descriptive_name # type: ignore
|
||||
+ f"\n A data type = {self.a_dtype}"
|
||||
+ f"\n B data type = {self.b_dtype}"
|
||||
+ f"\n Accumulator data type = {self.acc_dtype}"
|
||||
+ f"\n A source location = {self.a_src}"
|
||||
+ f"\n A major mode = {self.a_major_mode}"
|
||||
+ f"\n B major mode = {self.b_major_mode}"
|
||||
+ f"\n Instruction shape MNK = {self.shape_mnk}"
|
||||
)
|
||||
|
||||
|
||||
class MmaTrait(Trait):
|
||||
admissible_fields = [Field.ACCUMULATE]
|
||||
|
||||
def set(self, field, value, *, loc=None, ip=None) -> None:
|
||||
if field not in self.admissible_fields:
|
||||
raise ValueError(
|
||||
f"invalid field, must be {Field.ACCUMULATE}, but got {field}"
|
||||
)
|
||||
field_name = f"#cute_nvgpu.atom_mma_field_sm90<{field._to_ir_field_name()}>"
|
||||
attr = ir.Attribute.parse(field_name)
|
||||
self.value = _cute_nvgpu_ir.atom_set_value(
|
||||
self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaF16BF16Op(MmaOp):
|
||||
"""
|
||||
F16/BF16 warpgroup MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-mma-async>`__.
|
||||
This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands.
|
||||
"""
|
||||
|
||||
descriptive_name = "warpgroup F16/BF16 MMA Operation"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ab_dtype: Type[Numeric],
|
||||
acc_dtype: Type[Numeric],
|
||||
instruction_shape: Shape,
|
||||
a_src: OperandSource,
|
||||
a_major_mode: OperandMajorMode,
|
||||
b_major_mode: OperandMajorMode,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
ab_dtype,
|
||||
ab_dtype,
|
||||
acc_dtype,
|
||||
instruction_shape,
|
||||
a_src,
|
||||
a_major_mode,
|
||||
b_major_mode,
|
||||
)
|
||||
self._verify()
|
||||
|
||||
def _verify(self) -> None:
|
||||
# Input data type verification
|
||||
if self.a_dtype not in [Float16, BFloat16]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16",
|
||||
)
|
||||
assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
|
||||
# Accumulator data type verification
|
||||
if self.acc_dtype not in [Float16, Float32]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
|
||||
)
|
||||
if (self.a_dtype == BFloat16) and (self.acc_dtype != Float32):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16",
|
||||
)
|
||||
# Verify the instruction shape
|
||||
instruction_k = 16
|
||||
if rank(self.shape_mnk) == 2:
|
||||
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
|
||||
if self.shape_mnk[2] != instruction_k:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the instruction extent in the K-mode to be {instruction_k}, "
|
||||
f"but got {self.shape_mnk[2]}",
|
||||
)
|
||||
|
||||
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait":
|
||||
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.MmaAtomSM90Type.get(
|
||||
shape_mnk.type.attribute,
|
||||
self.a_major_mode._to_ir(),
|
||||
self.b_major_mode._to_ir(),
|
||||
self.a_dtype.mlir_type,
|
||||
self.b_dtype.mlir_type,
|
||||
self.acc_dtype.mlir_type,
|
||||
self.a_src._to_ir(),
|
||||
)
|
||||
return MmaF16BF16Trait(
|
||||
_cute_nvgpu_ir.make_sm90_mma(
|
||||
ty,
|
||||
Boolean(False).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MmaF16BF16Trait(MmaTrait):
|
||||
pass
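
# Hedged usage sketch (not part of the original file): constructing an F16 MMA Op under the
# constraints checked in MmaOp.__post_init__ and MmaF16BF16Op._verify (M == 64,
# 8 <= N <= 256 with N % 8 == 0, K fixed to 16). The enum members OperandSource.SMEM and
# OperandMajorMode.K are assumptions about this module's API.
def _example_make_f16_mma_op() -> MmaF16BF16Op:
    return MmaF16BF16Op(
        Float16,              # ab_dtype: A/B operand data type
        Float32,              # acc_dtype: accumulator data type
        (64, 128),            # instruction shape MN; the K-mode defaults to 16 in _verify()
        OperandSource.SMEM,   # A operand sourced from shared memory (assumed member)
        OperandMajorMode.K,   # A is K-major (assumed member)
        OperandMajorMode.K,   # B is K-major (assumed member)
    )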
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MmaF8Op(MmaOp):
|
||||
"""
|
||||
F8 warpgroup MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-mma-async>`__.
|
||||
This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands.
|
||||
"""
|
||||
|
||||
descriptive_name = "warpgroup F8 MMA Operation"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
a_dtype: Type[Numeric],
|
||||
b_dtype: Type[Numeric],
|
||||
acc_dtype: Type[Numeric],
|
||||
instruction_shape: Shape,
|
||||
a_src: OperandSource,
|
||||
a_major_mode: OperandMajorMode,
|
||||
b_major_mode: OperandMajorMode,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
a_dtype,
|
||||
b_dtype,
|
||||
acc_dtype,
|
||||
instruction_shape,
|
||||
a_src,
|
||||
a_major_mode,
|
||||
b_major_mode,
|
||||
)
|
||||
self._verify()
|
||||
|
||||
def _verify(self):
|
||||
# Input data type verification
|
||||
if self.a_dtype not in [Float8E5M2, Float8E4M3FN]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'a_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
|
||||
)
|
||||
if self.b_dtype not in [Float8E5M2, Float8E4M3FN]:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'b_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
|
||||
)
|
||||
# Accumulator data type verification
|
||||
if self.acc_dtype != Float32:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'acc_dtype' Op parameter to be Float32",
|
||||
)
|
||||
# Verify the instruction shape
|
||||
instruction_k = 32
|
||||
if rank(self.shape_mnk) == 2:
|
||||
object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
|
||||
if self.shape_mnk[2] != instruction_k:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the instruction extent in the K-mode to be {instruction_k}, "
|
||||
f"but got {self.shape_mnk[2]}",
|
||||
)
|
||||
|
||||
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF8Trait":
|
||||
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
||||
ty = _cute_nvgpu_ir.MmaAtomSM90Type.get(
|
||||
shape_mnk.type.attribute,
|
||||
self.a_major_mode._to_ir(),
|
||||
self.b_major_mode._to_ir(),
|
||||
self.a_dtype.mlir_type,
|
||||
self.b_dtype.mlir_type,
|
||||
self.acc_dtype.mlir_type,
|
||||
self.a_src._to_ir(),
|
||||
)
|
||||
return MmaF8Trait(
|
||||
_cute_nvgpu_ir.make_sm90_mma(
|
||||
ty, Boolean(False).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MmaF8Trait(MmaTrait):
|
||||
pass
|
||||
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# SMEM layout atoms
|
||||
#
|
||||
####################################################################################################
|
||||
|
||||
|
||||
class SmemLayoutAtomKind(enum.Enum):
|
||||
"""
|
||||
Enum class for the kinds of SMEM layout atoms for SM90.
|
||||
|
||||
Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can
|
||||
be used to construct an SMEM layout using blocked product for operand A or B such that the
|
||||
resulting layout is legal for both TMA and UMMA.
|
||||
|
||||
Note that there are other ways of creating legal layouts for operand A and B.
|
||||
"""
|
||||
|
||||
MN_INTER = enum.auto()
|
||||
MN_SW32 = enum.auto()
|
||||
MN_SW64 = enum.auto()
|
||||
MN_SW128 = enum.auto()
|
||||
K_INTER = enum.auto()
|
||||
K_SW32 = enum.auto()
|
||||
K_SW64 = enum.auto()
|
||||
K_SW128 = enum.auto()
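
# Hedged sketch (not from the original file): one common heuristic is to pick the widest
# K-major swizzle whose byte width divides the contiguous extent of the operand tile in
# bytes. This is an illustration only, not the library's selection rule.
def _example_pick_k_major_atom_kind(contiguous_extent_bytes: int) -> SmemLayoutAtomKind:
    if contiguous_extent_bytes % 128 == 0:
        return SmemLayoutAtomKind.K_SW128
    if contiguous_extent_bytes % 64 == 0:
        return SmemLayoutAtomKind.K_SW64
    if contiguous_extent_bytes % 32 == 0:
        return SmemLayoutAtomKind.K_SW32
    return SmemLayoutAtomKind.K_INTER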
|
||||
515
python/CuTeDSL/cutlass/cute/runtime.py
Normal file
@@ -0,0 +1,515 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
import ctypes
|
||||
from functools import lru_cache
|
||||
import itertools
|
||||
import operator
|
||||
from time import time
|
||||
from typing import Union
|
||||
|
||||
# MLIR modules imports
|
||||
from cutlass._mlir import ir
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
|
||||
from cutlass.cutlass_dsl import TensorFormat, JitArgAdapterRegistry
|
||||
|
||||
# Local modules imports
|
||||
from .typing import (
|
||||
AddressSpace,
|
||||
Tensor,
|
||||
Type,
|
||||
Pointer,
|
||||
Boolean,
|
||||
Numeric,
|
||||
Float4E2M1FN,
|
||||
Int64,
|
||||
Int32,
|
||||
Int16,
|
||||
Int8,
|
||||
Uint64,
|
||||
Uint32,
|
||||
Uint16,
|
||||
Uint8,
|
||||
Float64,
|
||||
Float32,
|
||||
Float16,
|
||||
BFloat16,
|
||||
Float8E5M2,
|
||||
)
|
||||
from .core import find, _Tensor as CoreTensor
|
||||
|
||||
|
||||
class _Pointer(Pointer):
|
||||
"""Runtime representation of a pointer that can inter-operate with various data structures,
|
||||
including numpy arrays and device memory.
|
||||
|
||||
:param pointer: The pointer to the data
|
||||
:type pointer: int or pointer-like object
|
||||
:param dtype: Data type of the elements pointed to
|
||||
:type dtype: Type
|
||||
:param mem_space: Memory space where the pointer resides, defaults to generic
|
||||
:type mem_space: _cute_ir.AddressSpace, optional
|
||||
:param assumed_align: Assumed alignment of input pointer in bytes, defaults to None
|
||||
:type assumed_align: int, optional
|
||||
|
||||
:ivar _pointer: The underlying pointer
|
||||
:ivar _dtype: Data type of the elements
|
||||
:ivar _addr_space: Memory space of the pointer
|
||||
:ivar _assumed_align: Alignment of the pointer in bytes
|
||||
:ivar _desc: C-type descriptor for the pointer
|
||||
:ivar _c_pointer: C-compatible pointer representation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pointer,
|
||||
dtype,
|
||||
mem_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.generic,
|
||||
assumed_align=None,
|
||||
):
|
||||
self._pointer = pointer
|
||||
self._dtype = dtype
|
||||
self._addr_space = mem_space
|
||||
|
||||
is_in_device = mem_space == _cute_ir.AddressSpace.gmem
|
||||
if assumed_align is None:
|
||||
if is_in_device:
|
||||
self._assumed_align = 32
|
||||
else:
|
||||
self._assumed_align = dtype.width // 8
|
||||
else:
|
||||
self._assumed_align = assumed_align
|
||||
|
||||
class PtrDescriptor(ctypes.Structure):
|
||||
"""A ctype descriptor for CuTe memref ptr"""
|
||||
|
||||
_fields_ = [("ptr", ctypes.c_void_p)]
|
||||
|
||||
def __str__(self):
|
||||
return f"0x{self.ptr:016x}"
|
||||
|
||||
self._desc = PtrDescriptor(int(self._pointer))
|
||||
self._c_pointer = ctypes.cast(ctypes.pointer(self._desc), ctypes.c_void_p)
|
||||
assert (
|
||||
self._desc.ptr % self._assumed_align == 0
|
||||
), f"pointer must be {self._assumed_align} bytes aligned"
|
||||
|
||||
def size_in_bytes(self) -> int:
|
||||
return ctypes.sizeof(self._desc)
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
return [self.mlir_type]
|
||||
|
||||
def __c_pointers__(self):
|
||||
return [self._c_pointer]
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
assert len(values) == 1
|
||||
return values[0]
|
||||
|
||||
# Move mlir Type out of __init__ to decouple with mlir Context
|
||||
@property
|
||||
def mlir_type(self) -> ir.Type:
|
||||
return _cute_ir.PtrType.get(
|
||||
self._dtype.mlir_type, self._addr_space, self._assumed_align
|
||||
)
|
||||
|
||||
@property
|
||||
def element_type(self) -> Type[Numeric]:
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def memspace(self):
|
||||
return self._addr_space
|
||||
|
||||
def verify(self, expected_py_type):
|
||||
if expected_py_type is Pointer:
|
||||
return True
|
||||
elif isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Ptr<0x{self._desc.ptr:016x}@{self._addr_space}>"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class _Tensor(Tensor):
|
||||
def __init__(
|
||||
self,
|
||||
tensor,
|
||||
assumed_align=None,
|
||||
):
|
||||
# If tensor is already a DLPack object, use it directly
|
||||
if hasattr(tensor, "__dlpack_device__") and not hasattr(tensor, "__dlpack__"):
|
||||
self._dlpack_data = tensor
|
||||
else:
|
||||
self._dlpack_data = tensor.__dlpack__()
|
||||
self._dltensor_wrapper = None
|
||||
self._assumed_align = assumed_align
|
||||
self._is_dynamic = False
|
||||
self._memref_desc = None
|
||||
self._dtype = None
|
||||
|
||||
@property
|
||||
def __class__(self) -> Type[Tensor]:
|
||||
# Cheat so that `type(_Tensor())` returns cute.Tensor
|
||||
return Tensor
|
||||
|
||||
@staticmethod
|
||||
def lazily_load_dltensor(func):
|
||||
"""Decorator to lazily load the DLTensorWrapper.
|
||||
|
||||
This decorator loads the DLTensorWrapper when needed,
|
||||
avoiding overhead in the critical path of calling JIT functions.
|
||||
"""
|
||||
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self._dltensor_wrapper is None:
|
||||
self._dltensor_wrapper = _cute_ir.DLTensorWrapper(self._dlpack_data)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@lazily_load_dltensor
|
||||
def mark_layout_dynamic(self, leading_dim: int | None = None):
|
||||
"""Marks the tensor layout as dynamic based on the leading dimension.
|
||||
|
||||
:param leading_dim: The leading dimension of the layout, defaults to None
|
||||
:type leading_dim: int, optional
|
||||
|
||||
When ``leading_dim`` is None, automatically deduces the leading dimension from the tensor layout.
|
||||
The layout can be deduced only when exactly one dimension has a stride of 1. Raises an error
|
||||
if the layout cannot be automatically deduced.
|
||||
|
||||
When ``leading_dim`` is explicitly specified, marks the layout as dynamic while setting the
|
||||
stride at ``leading_dim`` to 1. Also validates that the specified ``leading_dim`` is consistent
|
||||
with the existing layout by checking that the corresponding stride of that dimension is 1.
|
||||
|
||||
Limitation: only flat layouts are supported for now; nested layouts will be supported in the future.
|
||||
|
||||
:return: The tensor with dynamic layout
|
||||
:rtype: _Tensor
|
||||
"""
|
||||
self._dltensor_wrapper.mark_layout_dynamic(leading_dim)
|
||||
return self
|
||||
|
||||
@lazily_load_dltensor
|
||||
def mark_compact_shape_dynamic(
|
||||
self,
|
||||
mode: int,
|
||||
stride_order: tuple[int, ...] | None = None,
|
||||
divisibility: int = 1,
|
||||
):
|
||||
"""Marks the tensor shape as dynamic and propagates dynamic and divisibility information to the corresponding strides.
|
||||
|
||||
:param mode: The mode of the compact shape, defaults to 0
|
||||
:type mode: int
|
||||
:param stride_order: Consistent with `torch.Tensor.dim_order`. Defaults to None.
|
||||
Indicates the order of the modes (dimensions) if the current layout were converted to row-major order.
|
||||
It starts from the outermost to the innermost dimension.
|
||||
:type stride_order: tuple[int, ...], optional
|
||||
:param divisibility: The divisibility constraint for the compact shape, defaults to 1
|
||||
:type divisibility: int, optional
|
||||
:return: The tensor with dynamic compact shape
|
||||
:rtype: _Tensor
|
||||
|
||||
If ``stride_order`` is not provided, the stride ordering will be automatically deduced from the layout.
|
||||
Automatic deduction is only possible when exactly one dimension has a stride of 1 (compact layout).
|
||||
An error is raised if automatic deduction fails.
|
||||
|
||||
If ``stride_order`` is explicitly specified, a consistency check against the layout is performed.
|
||||
|
||||
For example:
|
||||
- Layout (4,2):(1,4) has stride_order (1,0): the innermost dimension is 0 (`4:1`) and the outermost dimension is 1 (`2:4`).
- Layout (5,3,2,4):(3,1,15,30) has stride_order (3,2,0,1): the innermost dimension is 1 (`3:1`) and the outermost dimension is 3 (`4:30`).
|
||||
|
||||
Use `torch.Tensor.dim_order()` to get the stride order of a torch tensor.
|
||||
.. code-block:: python
|
||||
a = torch.empty(3, 4)
|
||||
t = cute.runtime.from_dlpack(a)
|
||||
t = t.mark_compact_shape_dynamic(mode=0, stride_order=a.dim_order())
|
||||
"""
|
||||
self._dltensor_wrapper.mark_compact_shape_dynamic(
|
||||
mode, stride_order, divisibility
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
@lazily_load_dltensor
|
||||
def element_type(self) -> Type[Numeric]:
|
||||
if self._dtype is None:
|
||||
self._dtype = self._dltensor_wrapper.dtype
|
||||
return self._dtype
|
||||
|
||||
@element_type.setter
|
||||
def element_type(self, new_type):
|
||||
"""Set the element type of the tensor.
|
||||
|
||||
:warning: This API is added for narrow precision before we have a clean `recast_tensor` story.
|
||||
|
||||
:note: It is only used when a framework does not natively support narrow precision and we
    receive the tensor from the framework with a storage type such as uint8.
|
||||
|
||||
**Example**:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Create a tensor from a numpy array
|
||||
import numpy as np
|
||||
from cutlass.cute import from_dlpack
|
||||
|
||||
# Create a tensor with uint8 storage
|
||||
a = np.zeros(shape, dtype=np.uint8)
|
||||
tensor = from_dlpack(a)
|
||||
|
||||
# Change the element type to Float4E2M1FN even though the storage type is uint8
|
||||
tensor.element_type = cutlass.Float4E2M1FN
|
||||
|
||||
src = from_dlpack(... data tensor ...)
|
||||
# convert and initialize narrow precision tensor
|
||||
cute.testing.convert(src, tensor)
|
||||
"""
|
||||
self._dtype = new_type
|
||||
|
||||
@property
|
||||
@lazily_load_dltensor
|
||||
def memspace(self):
|
||||
return self._dltensor_wrapper.address_space
|
||||
|
||||
@property
|
||||
@lazily_load_dltensor
|
||||
def size_in_bytes(self) -> int:
|
||||
return self._dltensor_wrapper.size_in_bytes()
|
||||
|
||||
@property
|
||||
@lazily_load_dltensor
|
||||
def mlir_type(self) -> ir.Type:
|
||||
return self._dltensor_wrapper.get_type(
|
||||
self.element_type.mlir_type, self._assumed_align
|
||||
)
|
||||
|
||||
@lazily_load_dltensor
|
||||
def __str__(self) -> str:
|
||||
return f"Tensor<0x{self._dltensor_wrapper.str}>"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def __setitem__(self, crd, value):
|
||||
raise TypeError(f"runtime._Tensor is not indexable")
|
||||
|
||||
def __getitem__(self, crd):
|
||||
raise TypeError(f"runtime._Tensor is not indexable")
|
||||
|
||||
@property
|
||||
@lazily_load_dltensor
|
||||
def iterator(self):
|
||||
return _Pointer(
|
||||
self._dltensor_wrapper.data_ptr,
|
||||
self.element_type,
|
||||
self.memspace,
|
||||
self._assumed_align,
|
||||
)
|
||||
|
||||
@property
|
||||
def layout(self):
|
||||
raise NotImplementedError(
|
||||
f"layout property is not supported in runtime, support in future"
|
||||
)
|
||||
|
||||
@property
|
||||
@lazily_load_dltensor
|
||||
def shape(self):
|
||||
return self._dltensor_wrapper.shape
|
||||
|
||||
@property
|
||||
@lazily_load_dltensor
|
||||
def stride(self):
|
||||
strides = self._dltensor_wrapper.stride
|
||||
if strides is None:
|
||||
strides = itertools.accumulate(
|
||||
reversed(self.shape), func=operator.mul, initial=1
|
||||
)
|
||||
strides = tuple(reversed(list(strides)[:-1]))
|
||||
|
||||
return strides
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=128, typed=True)
|
||||
def leading_dim(self):
|
||||
"""Get the leading dimension of this Tensor.
|
||||
|
||||
:return: The leading dimension index or indices
|
||||
:rtype: int or tuple or None
|
||||
|
||||
The return value depends on the tensor's stride pattern:
|
||||
|
||||
* If a single leading dimension is found, returns an integer index
|
||||
* If nested leading dimensions are found, returns a tuple of indices
|
||||
* If no leading dimension is found, returns None
|
||||
"""
|
||||
return find(1, self.stride, exclude_when=(1, self.shape))
|
||||
|
||||
def fill(self, value: Numeric):
|
||||
raise TypeError(f"fill function is not supported in runtime")
|
||||
|
||||
@property
|
||||
@lazily_load_dltensor
|
||||
def data_ptr(self):
|
||||
return self._dltensor_wrapper.data_ptr
|
||||
|
||||
@lazily_load_dltensor
|
||||
def __c_pointers__(self):
|
||||
self._memref_desc = self._dltensor_wrapper.build_memref_desc(
|
||||
self._assumed_align
|
||||
)
|
||||
return [_cute_ir.pycapsule_get_pointer(self._memref_desc)]
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
return [self.mlir_type]
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
assert len(values) == 1
|
||||
assert isinstance(values[0], CoreTensor)
|
||||
return CoreTensor(values[0].value, self._dtype)
|
||||
|
||||
|
||||
def from_dlpack(
|
||||
tensor_dlpack,
|
||||
assumed_align=None,
|
||||
) -> Tensor:
|
||||
"""Convert from tensor object supporting __dlpack__() to a CuTe Tensor.
|
||||
|
||||
:param tensor_dlpack: Tensor object that supports the DLPack protocol
|
||||
:type tensor_dlpack: object
|
||||
:param assumed_align: Assumed alignment of the tensor (bytes), defaults to None,
|
||||
if None, will use the element size bytes as the assumed alignment.
|
||||
:type assumed_align: int, optional
|
||||
:return: A CuTe Tensor object
|
||||
:rtype: Tensor
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
x = torch.randn(100, 100)
|
||||
y = from_dlpack(x)
|
||||
y.shape
|
||||
# (100, 100)
|
||||
type(y)
|
||||
# <class 'cutlass.cute.Tensor'>
|
||||
"""
|
||||
return _Tensor(
|
||||
tensor_dlpack,
|
||||
assumed_align=assumed_align,
|
||||
)
|
||||
|
||||
|
||||
def make_ptr(
|
||||
dtype: Type[Numeric],
|
||||
value: Union[int, ctypes._Pointer],
|
||||
mem_space: AddressSpace = AddressSpace.generic,
|
||||
assumed_align=None,
|
||||
) -> Pointer:
|
||||
"""Create a pointer from a memory address
|
||||
|
||||
:param dtype: Data type of the pointer elements
|
||||
:type dtype: Type[Numeric]
|
||||
:param value: Memory address as integer or ctypes pointer
|
||||
:type value: Union[int, ctypes._Pointer]
|
||||
:param mem_space: Memory address space, defaults to AddressSpace.generic
|
||||
:type mem_space: AddressSpace, optional
|
||||
:param align_bytes: Alignment in bytes, defaults to None
|
||||
:type align_bytes: int, optional
|
||||
:return: A pointer object
|
||||
:rtype: Pointer
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import numpy as np
|
||||
import ctypes
|
||||
|
||||
from cutlass import Float32
|
||||
from cutlass.cute.runtime import make_ptr
|
||||
|
||||
# Create a numpy array
|
||||
a = np.random.randn(16, 32).astype(np.float32)
|
||||
|
||||
# Get a ctypes pointer to the array data
|
||||
ptr_address = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
||||
|
||||
# Create pointer from address
|
||||
y = make_ptr(Float32, ptr_address)
|
||||
|
||||
# Check properties
|
||||
print(y.element_type)
|
||||
print(type(y)) # <class 'cutlass.cute.Pointer'>
|
||||
"""
|
||||
# check if value is int or ctypes.POINTER
|
||||
if isinstance(value, int):
|
||||
address_value = value
|
||||
elif isinstance(value, ctypes._Pointer):
|
||||
# get address value
|
||||
address_value = ctypes.cast(value, ctypes.c_void_p).value
|
||||
assert address_value is not None, "Pointer address is None"
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expect int or ctypes.POINTER for value but got {type(value)=}"
|
||||
)
|
||||
|
||||
return _Pointer(address_value, dtype, mem_space, assumed_align=assumed_align)
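
# Hedged usage sketch (not part of the original file): wrapping a raw device address,
# e.g. obtained from torch's Tensor.data_ptr(), as a 16-byte aligned global-memory
# pointer of Float32 elements. The address is assumed to satisfy that alignment.
def _example_gmem_ptr(device_address: int) -> Pointer:
    return make_ptr(
        Float32,
        device_address,
        mem_space=AddressSpace.gmem,
        assumed_align=16,
    )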
|
||||
|
||||
|
||||
class TensorAdapter:
|
||||
"""
|
||||
Convert a DLPack protocol supported tensor/array to a cute tensor.
|
||||
"""
|
||||
|
||||
# Need reference these capsules to avoid being garbage collected
|
||||
tensor_capsules = []
|
||||
|
||||
def __init__(self, arg):
|
||||
self._arg = from_dlpack(arg).mark_layout_dynamic()
|
||||
self.tensor_capsules.append(self._arg)
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
return self._arg.__new_from_mlir_values__(values)
|
||||
|
||||
def __c_pointers__(self):
|
||||
return self._arg.__c_pointers__()
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
return self._arg.__get_mlir_types__()
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Try to register_jit_arg_adapter for TensorAdapter
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
try: # Register for numpy.ndarray
|
||||
import numpy
|
||||
|
||||
JitArgAdapterRegistry.register_jit_arg_adapter(numpy.ndarray)(TensorAdapter)
|
||||
except ImportError:
|
||||
pass # silent attempt, suppress error
|
||||
|
||||
try: # Register for torch.Tensor
|
||||
import torch
|
||||
|
||||
JitArgAdapterRegistry.register_jit_arg_adapter(torch.Tensor)(TensorAdapter)
|
||||
except ImportError:
|
||||
pass # silent attempt, suppress error
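
# Hedged usage sketch (not part of the original file, assumes torch is installed): because
# TensorAdapter is registered for torch.Tensor above, a raw torch tensor handed to a JIT
# function is adapted on the fly, which is equivalent to doing the wrapping by hand:
def _example_adapt_torch_tensor():
    import torch

    x = torch.zeros(8, 16, dtype=torch.float32)
    return TensorAdapter(x)  # wraps from_dlpack(x).mark_layout_dynamic()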
|
||||
285
python/CuTeDSL/cutlass/cute/testing.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
import functools
|
||||
import hashlib
|
||||
|
||||
from cutlass.cutlass_dsl import (
|
||||
const,
|
||||
T,
|
||||
CuTeDSL,
|
||||
BaseDSL,
|
||||
t,
|
||||
Constexpr,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.ir as ir
|
||||
from cutlass._mlir.dialects import nvvm, cf, vector, builtin
|
||||
|
||||
from cutlass.cute import core
|
||||
from cutlass.cute import nvgpu
|
||||
from typing import Type
|
||||
from inspect import isclass
|
||||
|
||||
|
||||
def assert_(cond, msg=None):
|
||||
if isinstance(cond, ir.Value):
|
||||
if ir.VectorType.isinstance(cond.type):
|
||||
assert (
|
||||
cond.type.element_type == T.bool()
|
||||
), f"only expects vector type with boolean elements, but got {cond.type}"
|
||||
cond_val = vector.multi_reduction(
|
||||
vector.CombiningKind.AND, cond, const(True), range(cond.type.rank)
|
||||
)
|
||||
else:
|
||||
cond_val = cond
|
||||
else:
|
||||
cond_val = const(cond, t.Boolean)
|
||||
|
||||
cf.assert_(cond_val, msg if msg else "")
|
||||
|
||||
|
||||
def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout):
|
||||
if src.element_type.width == 4:
|
||||
tv_layout = core.recast_layout(8, 4, tv_layout)
|
||||
src = core.recast_tensor(src, dtype=t.Int8)
|
||||
return src, tv_layout
|
||||
|
||||
|
||||
def _maybe_recast_to_f4(input: core.TensorSSA, dtype: Type[core.Numeric]):
|
||||
"""Conditionally recasts the tensor to 4-bit type if the destination type is 4-bit.
|
||||
|
||||
:param input: The input tensor to recast.
|
||||
:param dtype: The target numeric type to potentially recast to.
|
||||
:raises TypeError: If dtype is not a subclass of Numeric.
|
||||
:return: A new tensor recast to 4-bit if dtype is 4-bit, otherwise returns self unchanged.
|
||||
"""
|
||||
if not isclass(dtype) or not issubclass(dtype, core.Numeric):
|
||||
raise TypeError(f"dst_ty must be a type of Numeric, but got {dtype}")
|
||||
|
||||
if dtype.width == 4:
|
||||
recast_shape = core.recast_layout(4, 8, core.make_layout(input.shape)).shape
|
||||
i4_vec = vector.bitcast(
|
||||
T.vector(input.type.shape[0] * 2, T.i(4)), input.maybe_downcast()
|
||||
)
|
||||
res_vect = builtin.unrealized_conversion_cast(
|
||||
[T.vector(i4_vec.type.shape[0], dtype.mlir_type)], [i4_vec]
|
||||
)
|
||||
return core.TensorSSA(res_vect, recast_shape, dtype)
|
||||
return input
|
||||
|
||||
|
||||
def _maybe_recast_from_f4(input: core.TensorSSA, src_dtype: Type[core.Numeric]):
|
||||
"""Conditionally recasts the tensor from 4-bit type if the source type is 4-bit.
|
||||
|
||||
:param input: The input tensor to recast.
|
||||
:param src_dtype: The source numeric type to potentially recast from.
|
||||
:raises TypeError: If src_dtype is not a subclass of Numeric.
|
||||
:return: A new tensor recast from 4-bit if src_dtype is 4-bit, otherwise returns self unchanged.
|
||||
"""
|
||||
if not isclass(src_dtype) or not issubclass(src_dtype, core.Numeric):
|
||||
raise TypeError(f"src_ty must be a type of Numeric, but got {src_dtype}")
|
||||
|
||||
if src_dtype.width == 4:
|
||||
recast_shape = core.recast_layout(8, 4, core.make_layout(input.shape)).shape
|
||||
i4_vec = builtin.unrealized_conversion_cast(
|
||||
[T.vector(input.type.shape[0], T.i(4))], [input.maybe_downcast()]
|
||||
)
|
||||
res_vect = vector.bitcast(T.vector(i4_vec.type.shape[0] // 2, T.i8()), i4_vec)
|
||||
return core.TensorSSA(res_vect, recast_shape, core.Int8)
|
||||
return input
|
||||
|
||||
|
||||
@CuTeDSL.kernel
|
||||
def _convert_kernel(
|
||||
gSrc: core.Tensor,
|
||||
gDst: core.Tensor,
|
||||
cSrc: core.Tensor,
|
||||
src_tv_layout: core.Layout,
|
||||
dst_tv_layout: core.Layout,
|
||||
src_shape: core.Shape,
|
||||
src_ty,
|
||||
dst_ty,
|
||||
):
|
||||
tidx = nvvm.read_ptx_sreg_tid_x(T.i32())
|
||||
bidx = nvvm.read_ptx_sreg_ctaid_x(T.i32())
|
||||
|
||||
cta_coord = (None, bidx)
|
||||
# logical idx -> address
|
||||
ctaSrc = gSrc[cta_coord] # (...,TileV,...)
|
||||
ctaDst = gDst[cta_coord] # (...,TileV,...)
|
||||
ctaCSrc = cSrc[cta_coord] # (...,TileV,...)
|
||||
# print(f"ctaSrc = {ctaSrc.type}")
|
||||
|
||||
# compose with CTA TV layout
|
||||
# tid, vid -> address
|
||||
tidfrgSrc = core.composition(ctaSrc, src_tv_layout) # (T,V)
|
||||
tidfrgDst = core.composition(ctaDst, dst_tv_layout) # (T,V)
|
||||
tidfrgCSrc = core.composition(ctaCSrc, src_tv_layout) # (T,V)
|
||||
# print(f"tidfrgSrc = {tidfrgSrc.type}")
|
||||
|
||||
# slice for threads
|
||||
thr_coord = (tidx, None)
|
||||
thrSrc = tidfrgSrc[thr_coord] # (V)
|
||||
thrDst = tidfrgDst[thr_coord] # (V)
|
||||
thrCSrc = tidfrgCSrc[thr_coord] # (V)
|
||||
# print(f"thrSrc = {thrSrc.type}")
|
||||
|
||||
# predicate
|
||||
if core.elem_less(thrCSrc[0], src_shape):
|
||||
# allocate fragments for gmem->rmem
|
||||
frgSrc = core.make_fragment(
|
||||
core.get(src_tv_layout, mode=[1]), gSrc.element_type
|
||||
) # (V)
|
||||
frgDst = core.make_fragment(
|
||||
core.get(dst_tv_layout, mode=[1]), gDst.element_type
|
||||
) # (V)
|
||||
# print(f"frgSrc = {frgSrc.type}")
|
||||
|
||||
# Move data to reg address space
|
||||
copy_atom_load = core.make_copy_atom(nvgpu.CopyUniversalOp(), gSrc.element_type)
|
||||
core.copy(copy_atom_load, thrSrc, frgSrc)
|
||||
|
||||
vec_src = frgSrc.load()
|
||||
vec_src = _maybe_recast_to_f4(vec_src, src_ty)
|
||||
vec_dst = vec_src.to(dst_ty)
|
||||
vec_dst = _maybe_recast_from_f4(vec_dst, dst_ty)
|
||||
frgDst.store(vec_dst)
|
||||
|
||||
# Copy the results back to c
|
||||
copy_atom_stg = core.make_copy_atom(nvgpu.CopyUniversalOp(), gDst.element_type)
|
||||
core.copy(copy_atom_stg, frgDst, thrDst)
|
||||
|
||||
|
||||
@CuTeDSL.jit(preprocess=False)
|
||||
def _convert(
|
||||
src: core.Tensor,
|
||||
dst: core.Tensor,
|
||||
leading_mode: Constexpr,
|
||||
elem_per_copy: Constexpr,
|
||||
):
|
||||
|
||||
# Step 1. figure proper tv_layout
|
||||
src_ty = src.element_type
|
||||
dst_ty = dst.element_type
|
||||
|
||||
tv_layout = core.make_layout((128, elem_per_copy), stride=(elem_per_copy, 1))
|
||||
|
||||
# Step 2. maybe recast from f4 tensor
|
||||
src, src_tv_layout = _maybe_recast_tensor_from_f4(src, tv_layout)
|
||||
dst, dst_tv_layout = _maybe_recast_tensor_from_f4(dst, tv_layout)
|
||||
src_shape = src.shape
|
||||
# predicate tensor
|
||||
idA = core.make_identity_tensor(src.shape)
|
||||
|
||||
# Step 3. select a proper tiling pattern as (...,TileV, ...)
|
||||
src_cta_tiler = [
|
||||
1,
|
||||
] * core.rank(src.layout)
|
||||
src_cta_tiler[leading_mode] = core.size(src_tv_layout) # (...,TileV,...)
|
||||
dst_cta_tiler = [
|
||||
1,
|
||||
] * core.rank(dst.layout)
|
||||
dst_cta_tiler[leading_mode] = core.size(dst_tv_layout) # (...,TileV,...)
|
||||
|
||||
# Step 4. partition input and output tensor by cta tiler.
|
||||
gS = core.zipped_divide(
|
||||
src, tuple(src_cta_tiler)
|
||||
) # ((...,TileV,...),(...,RestV,...))
|
||||
cS = core.zipped_divide(
|
||||
idA, tuple(src_cta_tiler)
|
||||
) # ((...,TileV,...),(...,RestV,...))
|
||||
gD = core.zipped_divide(
|
||||
dst, tuple(dst_cta_tiler)
|
||||
) # ((...,TileV,...),(...,RestV,...))
|
||||
# print(f"{gS.type=}")
|
||||
|
||||
_convert_kernel(
|
||||
gS,
|
||||
gD,
|
||||
cS,
|
||||
src_tv_layout,
|
||||
dst_tv_layout,
|
||||
src_shape,
|
||||
src_ty,
|
||||
dst_ty,
|
||||
).launch(
|
||||
grid=[core.size(gS, mode=[1]), 1, 1],
|
||||
block=[core.size(src_tv_layout, mode=[0]), 1, 1],
|
||||
)
|
||||
|
||||
|
||||
# Converts the src tensor to the dst tensor; their logical shapes are required to be the same.
# When the src or dst dtype is a narrow precision type (Float4E2M1FN/Float8E8M0FNU/Float8E4M3FN),
# the extent of the leading dimension must be a multiple of 4 (fp8) or 8 (fp4) elements,
# because nvgpu.cvt_fptrunc/cvt_fpext need 32-bit aligned input/output.
|
||||
def convert(src: core.Tensor, dst: core.Tensor):
|
||||
assert len(src.shape) == len(
|
||||
dst.shape
|
||||
), "Shape of src and dst tensors should be the same rank."
|
||||
# find leading mode
|
||||
leading_mode = np.argmin([np.min(s) for s in src.stride])
|
||||
|
||||
elem_per_copy = 2
|
||||
|
||||
if src.element_type.width == 4 or dst.element_type.width == 4:
|
||||
elem_per_copy = 8
|
||||
elif src.element_type.width == 8 or dst.element_type.width == 8:
|
||||
elem_per_copy = 4
|
||||
assert (
|
||||
src.shape[leading_mode] % elem_per_copy == 0
|
||||
and dst.shape[leading_mode] % elem_per_copy == 0
|
||||
)
|
||||
_convert(src, dst, leading_mode, elem_per_copy)
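
# Hedged usage sketch (not part of the original file, assumes torch with a CUDA device):
# converting an f32 tensor into an e4m3 tensor. The leading-dimension extent (32) is a
# multiple of 4 as required for 8-bit destinations; see the comment above convert().
def _example_convert_to_fp8():
    import torch
    import cutlass
    from cutlass.cute.runtime import from_dlpack

    src = torch.randn(16, 32, device="cuda", dtype=torch.float32)
    # fp8 values are carried in uint8 storage and reinterpreted, mirroring the
    # element_type setter example in cute/runtime.py.
    dst_storage = torch.empty(16, 32, device="cuda", dtype=torch.uint8)
    dst = from_dlpack(dst_storage)
    dst.element_type = cutlass.Float8E4M3FN
    convert(from_dlpack(src), dst)
    return dst_storage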
|
||||
|
||||
|
||||
#########################################
|
||||
# Testing utilities
|
||||
#########################################
|
||||
|
||||
|
||||
def sample_pytest(rand_cfg=None):
|
||||
"""
|
||||
Decorator to randomly sample pytest parametrized tests.
|
||||
rand_cfg: Tuple[int, float] - (random_seed, sample_ratio)
|
||||
Sampling is disabled when:
|
||||
- A specific test is selected (via -k or direct test path)
|
||||
- Not running under pytest
|
||||
"""
|
||||
import functools
|
||||
import os
|
||||
import random
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
# Guard against the default rand_cfg=None so decorating without a config still works.
seed, sample_ratio = rand_cfg if rand_cfg is not None else (0, 1.0)
random.seed(seed)
|
||||
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if rand_cfg is not None and "PYTEST_CURRENT_TEST" in os.environ:
|
||||
# Check if test was explicitly selected like ::test_name[param1-param2-...]
|
||||
if "-k" in sys.argv or any(".py::" in arg for arg in sys.argv):
|
||||
# Test was explicitly selected, don't skip
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if random.uniform(0.0, 1.0) > sample_ratio:
|
||||
pytest.skip(f"Randomly skipped (sampling ratio: {sample_ratio})")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
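
# Hedged usage sketch (not part of the original file): sampling roughly 25% of a large
# parametrized test with a fixed seed. The test body and parameter names are illustrative.
def _example_sampled_test():
    import pytest

    @pytest.mark.parametrize("n", list(range(100)))
    @sample_pytest(rand_cfg=(2025, 0.25))
    def test_dummy(n):
        assert n >= 0

    return test_dummy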
|
||||
193
python/CuTeDSL/cutlass/cute/typing.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import ForwardRef, Tuple, Union, Any, Type, List
|
||||
|
||||
from cutlass.base_dsl.typing import *
|
||||
|
||||
from cutlass._mlir import ir
|
||||
import cutlass._mlir.extras.types as T
|
||||
from cutlass._mlir.dialects.cute import AddressSpace
|
||||
|
||||
|
||||
Int = Union[int, Integer]
|
||||
|
||||
|
||||
ScaledBasis = ForwardRef("ScaledBasis")
|
||||
|
||||
|
||||
IntTuple = Union[Int, Tuple["IntTuple", ...]]
|
||||
Shape = Union[Int, Tuple["Shape", ...]]
|
||||
Stride = Union[Int, ScaledBasis, Tuple["Stride", ...]]
|
||||
Coord = Union[Int, None, Tuple["Coord", ...]]
|
||||
|
||||
|
||||
class Layout(ir.Value):
|
||||
def __init__(self, op_result):
|
||||
super().__init__(op_result)
|
||||
|
||||
def __str__(self): ...
|
||||
|
||||
def get_hier_coord(self, idx) -> Coord:
|
||||
"""Return the (hierarchical) ND logical coordinate corresponding to the linear index"""
|
||||
...
|
||||
|
||||
@property
|
||||
def shape(self, *, loc=None, ip=None) -> Shape: ...
|
||||
|
||||
@property
|
||||
def stride(self, *, loc=None, ip=None) -> Stride: ...
|
||||
|
||||
|
||||
Tile = Union[Int, None, Layout, Tuple["Tile", ...]]
|
||||
|
||||
# XTuple is super set of above types
|
||||
XTuple = Union[IntTuple, Shape, Stride, Coord, Tile]
|
||||
|
||||
Tiler = Union[Shape, Layout, Tile]
|
||||
|
||||
|
||||
class Pointer:
|
||||
"""
|
||||
Abstract base class for CuTe jit function and runtime _Pointer
|
||||
"""
|
||||
|
||||
def __extract_mlir_values__(self):
|
||||
# Doesn't matter just return a value
|
||||
return [self]
|
||||
|
||||
|
||||
class Tensor(ABC):
|
||||
"""
|
||||
Abstract base class for CuTe jit function and runtime _Tensor
|
||||
|
||||
A CuTe Tensor is an iterator paired with a layout
|
||||
|
||||
:Examples:
|
||||
|
||||
Create tensor from torch.tensor with Host Runtime:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>>> import torch
|
||||
>>> from cutlass.cute.runtime import from_dlpack
|
||||
>>> mA = from_dlpack(torch.tensor([1, 3, 5], dtype=torch.int32))
|
||||
>>> mA.shape
|
||||
(3,)
|
||||
>>> mA.stride
|
||||
(1,)
|
||||
>>> mA.layout
|
||||
(3,):(1,)
|
||||
|
||||
Define JIT function:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@cute.jit
|
||||
def add(a: Tensor, b: Tensor, res: Tensor): ...
|
||||
|
||||
Call JIT function from python:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>>> import torch
|
||||
>>> a = torch.tensor([1, 3, 5], dtype=torch.int32)
|
||||
>>> b = torch.tensor([2, 4, 6], dtype=torch.int32)
|
||||
>>> c = torch.zeros([3], dtype=torch.int32)
|
||||
>>> mA = from_dlpack(a)
|
||||
>>> mB = from_dlpack(b)
|
||||
>>> mC = from_dlpack(c)
|
||||
>>> add(mA, mB, mC)
|
||||
>>> c
|
||||
tensor([3, 7, 11], dtype=torch.int32)
|
||||
"""
|
||||
|
||||
def __str__(self): ...
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, idx) -> Union["Tensor", ir.Value, IntTuple]: ...
|
||||
|
||||
@abstractmethod
|
||||
def __setitem__(self, idx, value): ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: ...
|
||||
|
||||
@element_type.setter
|
||||
def element_type(self, new_type): ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def memspace(self) -> AddressSpace: ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def iterator(self): ...
|
||||
|
||||
@property
|
||||
def layout(self) -> Union[Layout, "ComposedLayout"]: ...
|
||||
|
||||
@property
|
||||
def shape(self) -> Shape: ...
|
||||
|
||||
def load(self, *, loc=None, ip=None) -> "TensorSSA": ...
|
||||
|
||||
def store(self, data: "TensorSSA", *, loc=None, ip=None): ...
|
||||
|
||||
def mark_layout_dynamic(self, leading_dim: int|None = None) -> "Tensor": ...
|
||||
|
||||
def mark_compact_shape_dynamic(
|
||||
self, mode: int, stride_order: tuple[int, ...]|None = None, divisibility: int = 1
|
||||
) -> "Tensor": ...
|
||||
|
||||
@abstractmethod
|
||||
def fill(self, value: Numeric) -> None: ...
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Coord",
|
||||
"Numeric",
|
||||
"Integer",
|
||||
"Boolean",
|
||||
"Int8",
|
||||
"Int16",
|
||||
"Int32",
|
||||
"Int64",
|
||||
"Uint8",
|
||||
"Uint16",
|
||||
"Uint32",
|
||||
"Uint64",
|
||||
"Float",
|
||||
"Float16",
|
||||
"BFloat16",
|
||||
"TFloat32",
|
||||
"Float32",
|
||||
"Float64",
|
||||
"Float8E5M2",
|
||||
"Float8E4M3FN",
|
||||
"Float8E4M3B11FNUZ",
|
||||
"Float8E4M3",
|
||||
"Float8E8M0FNU",
|
||||
"Float4E2M1FN",
|
||||
"Float6E2M3FN",
|
||||
"Float6E3M2FN",
|
||||
"IntTuple",
|
||||
"Layout",
|
||||
"Pointer",
|
||||
"Shape",
|
||||
"Stride",
|
||||
"Tensor",
|
||||
"Tile",
|
||||
"Tiler",
|
||||
"XTuple",
|
||||
]
|
||||
32
python/CuTeDSL/cutlass/impl_utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
|
||||
def check_value_in(
|
||||
value, possible_values: list, value_description: str, prefix=""
|
||||
) -> None:
|
||||
if value not in possible_values:
|
||||
err_msg = prefix
|
||||
if err_msg != "":
|
||||
err_msg += ": "
|
||||
err_msg += f"invalid {value_description}, got {value}, must be one of {possible_values}"
|
||||
raise ValueError(err_msg)
|
||||
|
||||
|
||||
def check_type_in(ty, possible_types: list, type_description: str, prefix="") -> None:
|
||||
if not isinstance(ty, type):
|
||||
ty = type(ty)
|
||||
if ty not in possible_types:
|
||||
err_msg = prefix
|
||||
if err_msg != "":
|
||||
err_msg += ": "
|
||||
err_msg += f"invalid type for {type_description}, got {ty}, must be one of {possible_types}"
|
||||
raise TypeError(err_msg)
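
# Hedged usage sketch (not part of the original file) showing the two helpers above:
# each raises with a prefixed message when the value or type is not admissible.
def _example_checks() -> None:
    check_value_in("sm90", ["sm80", "sm90"], "target architecture", prefix="example")
    check_type_in(3, [int, float], "tile extent", prefix="example")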
|
||||
169
python/CuTeDSL/cutlass/torch.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
from cutlass.cute.typing import (
|
||||
Numeric,
|
||||
Boolean,
|
||||
Float,
|
||||
Integer,
|
||||
TFloat32,
|
||||
Float8E4M3B11FNUZ,
|
||||
Float8E4M3FN,
|
||||
Float8E5M2,
|
||||
Float8E8M0FNU,
|
||||
Float4E2M1FN,
|
||||
Tensor,
|
||||
)
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import cutlass.cute as cute
|
||||
import torch
|
||||
|
||||
|
||||
def dtype(ty: Type[Numeric]):
|
||||
"""
|
||||
Return the corresponding torch.dtype per the given DSL type
|
||||
"""
|
||||
torch_dtype = getattr(torch, ty.__name__.lower(), None)
|
||||
|
||||
torch_type_map = {
|
||||
Boolean: torch.bool,
|
||||
# TFloat32 is just alias of float32
|
||||
TFloat32: torch.float32,
|
||||
Float8E5M2: torch.float8_e5m2,
|
||||
Float8E4M3FN: torch.float8_e4m3fn,
|
||||
Float8E4M3B11FNUZ: torch.float8_e4m3fnuz,
|
||||
}
|
||||
if torch_dtype is None:
|
||||
torch_dtype = torch_type_map.get(ty)
|
||||
|
||||
if torch_dtype is None:
|
||||
raise TypeError(f"{ty} is not supported by torch")
|
||||
return torch_dtype
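
# Hedged sketch (not part of the original file): DSL types whose lowercase name does not
# match a torch attribute fall back to the explicit table above.
def _example_dtype_mapping() -> None:
    assert dtype(TFloat32) is torch.float32   # alias of float32
    assert dtype(Boolean) is torch.bool       # via the fallback table
    assert dtype(Float8E5M2) is torch.float8_e5m2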
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScalarInitConfig:
|
||||
"""Configuration for scalar initialization"""
|
||||
|
||||
value: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RandomInitConfig:
|
||||
"""Configuration for random initialization"""
|
||||
|
||||
min_val: int = -2
|
||||
max_val: int = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class GaussianInitConfig:
|
||||
"""Configuration for Gaussian initialization"""
|
||||
|
||||
mean: float = 0.0
|
||||
std: float = 1.0
|
||||
scale: float = 1.0
|
||||
|
||||
|
||||
class TensorInitType(Enum):
|
||||
"""Enumeration of tensor initialization types"""
|
||||
|
||||
SKIP = "skip"
|
||||
SCALAR = "scalar"
|
||||
RANDOM = "random"
|
||||
GAUSSIAN = "gaussian"
|
||||
|
||||
|
||||
def create_and_permute_torch_tensor(
|
||||
shape,
|
||||
dtype: "torch.dtype",
|
||||
permute_order=None,
|
||||
init_type: TensorInitType = TensorInitType.RANDOM,
|
||||
init_config: Optional[
|
||||
Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig]
|
||||
] = None,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Create a torch tensor with specified shape and dtype. Optionally permute it and initialize it with specified init type and config
|
||||
"""
|
||||
init_dtype = torch.int32 if init_type == TensorInitType.RANDOM else torch.float32
|
||||
init_torch_tensor = torch.empty(*shape, dtype=init_dtype)
|
||||
if init_type == TensorInitType.SKIP:
|
||||
assert init_config is None
|
||||
f32_torch_tensor = init_torch_tensor
|
||||
elif init_type == TensorInitType.SCALAR:
|
||||
if init_config is None:
|
||||
init_config = ScalarInitConfig()
|
||||
else:
|
||||
if not isinstance(init_config, ScalarInitConfig):
|
||||
raise ValueError("init_config must be ScalarInitConfig()")
|
||||
f32_torch_tensor = init_torch_tensor.fill_(init_config.value)
|
||||
elif init_type == TensorInitType.RANDOM:
|
||||
if init_config is None:
|
||||
init_config = RandomInitConfig()
|
||||
else:
|
||||
if not isinstance(init_config, RandomInitConfig):
|
||||
raise ValueError("init_config must be RandomInitConfig()")
|
||||
f32_torch_tensor = init_torch_tensor.random_(
|
||||
init_config.min_val, init_config.max_val
|
||||
).to(dtype=torch.float32)
|
||||
elif init_type == TensorInitType.GAUSSIAN:
|
||||
if init_config is None:
|
||||
init_config = GaussianInitConfig()
|
||||
else:
|
||||
if not isinstance(init_config, GaussianInitConfig):
|
||||
raise ValueError("init_config must be GaussianInitConfig()")
|
||||
f32_torch_tensor = init_torch_tensor.normal_(init_config.mean, init_config.std)
|
||||
f32_torch_tensor = f32_torch_tensor * (1 << init_config.scale)
|
||||
else:
|
||||
raise ValueError(f"Invalid init type: {init_type}")
|
||||
|
||||
if permute_order is not None:
|
||||
f32_torch_tensor = f32_torch_tensor.permute(permute_order)
|
||||
|
||||
dtype_torch_tensor = f32_torch_tensor.to(dtype=dtype)
|
||||
|
||||
return dtype_torch_tensor
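
# Hedged usage sketch (not part of the original file): a 64x128 bf16 tensor stored
# column-major via permutation, initialized with small random integers before conversion.
def _example_create_bf16_tensor() -> "torch.Tensor":
    return create_and_permute_torch_tensor(
        (128, 64),
        torch.bfloat16,
        permute_order=(1, 0),
        init_type=TensorInitType.RANDOM,
        init_config=RandomInitConfig(min_val=-2, max_val=2),
    )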
|
||||
|
||||
|
||||
def convert_cute_tensor(
|
||||
f32_torch_tensor: "torch.Tensor",
|
||||
cute_tensor: Tensor,
|
||||
dtype: Type[Numeric],
|
||||
is_dynamic_layout: bool = True,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Fill the CuTe tensor with values converted from an fp32 torch tensor.
Currently used for creating fp8-type tensors.
|
||||
"""
|
||||
# if torch_tensor is on cpu, create a gpu copy
|
||||
if f32_torch_tensor.device.type == "cpu":
|
||||
f32_torch_tensor = f32_torch_tensor.cuda()
|
||||
|
||||
# Fp8 type need explicit type conversion
|
||||
if dtype in {
|
||||
Float8E5M2,
|
||||
Float8E4M3FN,
|
||||
Float8E8M0FNU,
|
||||
Float4E2M1FN,
|
||||
}:
|
||||
fp32_cute_tensor = from_dlpack(f32_torch_tensor)
|
||||
if is_dynamic_layout:
|
||||
fp32_cute_tensor = fp32_cute_tensor.mark_layout_dynamic(
|
||||
f32_torch_tensor.dim_order()[-1]
|
||||
)
|
||||
# Copy and convert from f32 cute tensor to dtype cute tensor
|
||||
cute.testing.convert(fp32_cute_tensor, cute_tensor)
|
||||
return cute_tensor
|
||||
9
python/CuTeDSL/cutlass/utils/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# Utilities
|
||||
|
||||
This folder contains various utilities for kernel authoring. Specifically, the implementations of the
following can be considered experimental and subject to breaking changes:

- static persistent tile scheduler defined in [`static_persistent_tile_scheduler.py`](./static_persistent_tile_scheduler.py)
- pipeline abstractions defined in [`pipeline.py`](./pipeline.py)
- grouped GEMM utilities defined in [`grouped_gemm_tile_scheduler_helper.py`](./grouped_gemm_tile_scheduler_helper.py)
  and [`tensormap_manager.py`](./tensormap_manager.py)
|
||||
78
python/CuTeDSL/cutlass/utils/__init__.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from .static_persistent_tile_scheduler import (
|
||||
WorkTileInfo,
|
||||
PersistentTileSchedulerParams,
|
||||
StaticPersistentTileScheduler,
|
||||
)
|
||||
|
||||
from .pipeline import (
|
||||
Agent,
|
||||
CooperativeGroup,
|
||||
PipelineUserType,
|
||||
PipelineState,
|
||||
make_pipeline_state,
|
||||
PipelineAsync,
|
||||
PipelineTmaAsync,
|
||||
PipelineTmaUmma,
|
||||
PipelineUmmaAsync,
|
||||
PipelineTmaStore,
|
||||
pipeline_init_wait,
|
||||
)
|
||||
|
||||
from .hardware_info import (
|
||||
HardwareInfo,
|
||||
)
|
||||
|
||||
from .blackwell_helpers import (
|
||||
compute_epilogue_tile_shape,
|
||||
get_smem_store_op,
|
||||
get_tmem_load_op,
|
||||
get_num_tmem_alloc_cols,
|
||||
make_smem_layout_a,
|
||||
make_smem_layout_b,
|
||||
make_smem_layout_epi,
|
||||
make_trivial_tiled_mma,
|
||||
)
|
||||
|
||||
from .hopper_helpers import (
|
||||
sm90_get_smem_store_op,
|
||||
)
|
||||
|
||||
from .grouped_gemm_tile_scheduler_helper import (
|
||||
GroupSearchResult,
|
||||
GroupedGemmGroupSearchState,
|
||||
GroupedGemmTileSchedulerHelper,
|
||||
create_initial_search_state,
|
||||
)
|
||||
|
||||
from .tensormap_manager import (
|
||||
TensorMapUpdateMode,
|
||||
TensorMapManager,
|
||||
)
|
||||
|
||||
from .smem_allocator import SmemAllocator
|
||||
|
||||
from .layout import LayoutEnum
|
||||
|
||||
__all__ = [
|
||||
"WorkTileInfo",
|
||||
"PersistentTileSchedulerParams",
|
||||
"StaticPersistentTileScheduler",
|
||||
"TensorMapUpdateMode",
|
||||
"TensorMapManager",
|
||||
"GroupSearchResult",
|
||||
"GroupedGemmGroupSearchState",
|
||||
"create_initial_search_state",
|
||||
"GroupedGemmTileSchedulerHelper",
|
||||
"HardwareInfo",
|
||||
]
|
||||
26
python/CuTeDSL/cutlass/utils/ampere_helpers.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SmemCapacity(Enum):
|
||||
SM80_SMEM_CAPACITY_BYTES = (164 - 1) * 1024
|
||||
SM86_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
|
||||
SM89_SMEM_CAPACITY_BYTES = (100 - 1) * 1024
|
||||
|
||||
|
||||
# Dictionary to map compute capability to SMEM capacity
|
||||
SMEM_CAPACITY = {
|
||||
"sm80": SmemCapacity.SM80_SMEM_CAPACITY_BYTES.value,
|
||||
"sm86": SmemCapacity.SM86_SMEM_CAPACITY_BYTES.value,
|
||||
"sm89": SmemCapacity.SM89_SMEM_CAPACITY_BYTES.value,
|
||||
}
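
# Hedged sketch (not part of the original file): looking up the usable per-CTA SMEM budget
# for a given Ampere-class architecture string.
def _example_smem_budget(arch: str = "sm80") -> int:
    return SMEM_CAPACITY[arch]  # e.g. (164 - 1) * 1024 bytes for sm80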
|
||||
910
python/CuTeDSL/cutlass/utils/blackwell_helpers.py
Normal file
@@ -0,0 +1,910 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from enum import Enum
|
||||
from math import log2, ceil
|
||||
from typing import List, Type, Union, Tuple
|
||||
|
||||
from cutlass.cutlass_dsl import (
|
||||
Float16,
|
||||
BFloat16,
|
||||
TFloat32,
|
||||
Float32,
|
||||
Uint8,
|
||||
Int8,
|
||||
Float8E4M3FN,
|
||||
Float8E5M2,
|
||||
Numeric,
|
||||
NumericMeta,
|
||||
dsl_user_op,
|
||||
)
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu.common import CopyUniversalOp
|
||||
from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp, StMatrix16x8x8bOp
|
||||
from cutlass.cute.nvgpu.tcgen05 import (
|
||||
MmaF16BF16Op,
|
||||
MmaTF32Op,
|
||||
MmaI8Op,
|
||||
MmaFP8Op,
|
||||
OperandSource,
|
||||
OperandMajorMode,
|
||||
CtaGroup,
|
||||
Ld16x64bOp,
|
||||
Ld16x128bOp,
|
||||
Ld16x256bOp,
|
||||
Ld16x32bx2Op,
|
||||
Ld32x32bOp,
|
||||
Repetition,
|
||||
Pack,
|
||||
find_tmem_tensor_col_offset,
|
||||
SmemLayoutAtomKind,
|
||||
make_smem_layout_atom,
|
||||
tile_to_mma_shape,
|
||||
is_tmem_load,
|
||||
get_tmem_copy_properties,
|
||||
)
|
||||
from cutlass.utils.layout import LayoutEnum
|
||||
|
||||
@dsl_user_op
|
||||
def compute_epilogue_tile_shape(
|
||||
cta_tile_shape: cute.Shape,
|
||||
use_2cta_instrs: bool,
|
||||
layout_d: LayoutEnum,
|
||||
elem_ty_d: Type[Numeric],
|
||||
*,
|
||||
layout_c: LayoutEnum = None,
|
||||
elem_ty_c: Union[Type[Numeric], None] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> cute.Tile:
|
||||
"""Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one.
|
||||
|
||||
:param cta_tile_shape: A tuple or list representing the dimensions of the CTA tile, where
|
||||
cta_tile_shape[0] corresponds to the height (M) and cta_tile_shape[1]
|
||||
corresponds to the width (N) of the tile.
|
||||
:type cta_tile_shape: cute.Shape
|
||||
:param use_2cta_instrs: A flag indicating whether the configuration is for a 2SM setup.
|
||||
:type use_2cta_instrs: bool
|
||||
:param layout_d: The layout enum of the output tensor D.
|
||||
:type layout_d: LayoutEnum
|
||||
:param elem_ty_d: The element type of output tensor D.
|
||||
:type elem_ty_d: Type[Numeric]
|
||||
:param layout_c: The layout enum of the input tensor C. Defaults to None.
|
||||
:type layout_c: LayoutEnum, optional
|
||||
:param elem_ty_c: The element type for input tensor C. Defaults to None.
|
||||
:type elem_ty_c: Union[Type[Numeric], None], optional
|
||||
|
||||
:return: Returns epilog tiler, which is used in subsequent epilog partitions.
|
||||
:rtype: cute.Tile
|
||||
|
||||
:raises ValueError: If the computed tile cute.size does not meet minimum requirements based on CTA dimensions.
|
||||
"""
|
||||
|
||||
def validate_type(ty, ty_name):
|
||||
if not isinstance(ty, NumericMeta):
|
||||
raise TypeError(f"{ty_name} must be Numeric, but got {ty}")
|
||||
|
||||
validate_type(elem_ty_d, "elem_ty_d")
|
||||
if elem_ty_c is not None:
|
||||
validate_type(elem_ty_c, "elem_ty_c")
|
||||
|
||||
cta_m, cta_n = cta_tile_shape[:2]
|
||||
(warp_m, warp_n) = (2, 2) if (cta_m == 64 and use_2cta_instrs) else (4, 1)
|
||||
disable_source = elem_ty_c is None
|
||||
max_bits = (
|
||||
elem_ty_d.width if disable_source else max(elem_ty_c.width, elem_ty_d.width)
|
||||
)
|
||||
|
||||
dp_full = 32
|
||||
tile_m = min(cta_m, dp_full * warp_m)
|
||||
n_perf = 0
|
||||
if disable_source:
|
||||
if max_bits == 4:
|
||||
compute_elts = 8192
|
||||
else:
|
||||
compute_elts = 4096
|
||||
n_perf = compute_elts // tile_m
|
||||
else:
|
||||
if max_bits == 32:
|
||||
n_perf = 16 if (cta_m > 64 and cta_n <= 128) else 32
|
||||
elif max_bits == 16:
|
||||
n_perf = 32 if cta_n <= 128 else 64
|
||||
else:
|
||||
n_perf = 64
|
||||
|
||||
d_is_m_major = layout_d.is_m_major_c()
|
||||
c_is_m_major = True if layout_c is None else layout_c.is_m_major_c()
|
||||
|
||||
n_min_d = (
|
||||
8 * warp_n
|
||||
if d_is_m_major
|
||||
else (128 * warp_n if elem_ty_d.width == 6 else 128 // elem_ty_d.width * warp_n)
|
||||
)
|
||||
n_min_c = (
|
||||
8 * warp_n
|
||||
if (c_is_m_major or disable_source)
|
||||
else (128 * warp_n if elem_ty_c.width == 6 else 128 // elem_ty_c.width * warp_n)
|
||||
)
|
||||
tile_n = min(cta_n, max(n_perf, n_min_c, n_min_d))
|
||||
|
||||
if cta_n < n_min_c or cta_n < n_min_d:
|
||||
raise ValueError(f"CTA tile too small: {cta_tile_shape=}")
|
||||
|
||||
# stride by tmem warp layout and return a by-mode tiler
|
||||
tile_m_layout = cute.make_layout(tile_m, loc=loc, ip=ip)
|
||||
tile_n_layout = cute.make_layout(
|
||||
(tile_n // warp_n, warp_n), stride=(1, cta_n // warp_n), loc=loc, ip=ip
|
||||
)
|
||||
return (tile_m_layout, cute.coalesce(tile_n_layout, loc=loc, ip=ip))
|
||||
|
||||
|
||||
@dsl_user_op
def get_smem_store_op(
    layout_d: LayoutEnum,
    elem_ty_d: Type[Numeric],
    elem_ty_acc: Type[Numeric],
    tiled_tmem_load: cute.TiledCopy,
    *,
    loc=None,
    ip=None,
) -> cute.CopyAtom:
    """Selects the largest vectorized smem store atom available, subject to the
    constraints of the gmem layout and the chosen TMEM_LOAD's thread-value ownership.

    :param layout_d: The layout enum of the output tensor D.
    :type layout_d: LayoutEnum
    :param elem_ty_d: The element type for output tensor D.
    :type elem_ty_d: Type[Numeric]
    :param elem_ty_acc: The element type for the accumulator.
    :type elem_ty_acc: Type[Numeric]
    :param tiled_tmem_load: An instance of TiledCopy that represents the tmem load operation.
    :type tiled_tmem_load: cute.TiledCopy

    :return: Either an st.matrix store atom or a universal SIMT copy atom, depending on the input parameters.
    :rtype: cute.CopyAtom
    """

    def validate_type(ty, ty_name):
        if not isinstance(ty, NumericMeta):
            raise TypeError(f"{ty_name} must be a Numeric, but got {ty}")

    validate_type(elem_ty_d, "elem_ty_d")
    validate_type(elem_ty_acc, "elem_ty_acc")

    is_m_major = layout_d.is_m_major_c()
    is_n_major = layout_d.is_n_major_c()

    if not is_tmem_load(tiled_tmem_load):
        return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip)

    num_dp, num_bits, num_rep, pack = get_tmem_copy_properties(tiled_tmem_load)

    use_stmatrix_m8n8_4x = (
        all(
            [
                elem_ty_acc.width == 32,
                elem_ty_d.width == 32,
                is_n_major,
                num_dp == 16,
                num_bits == 128,
                num_rep in (2, 4, 8, 16, 32, 64),
                pack == Pack.NONE,
            ]
        )
        or all(
            [
                elem_ty_acc.width == 32,
                elem_ty_d.width == 16,
                num_dp == 16,
                num_bits == 256,
                num_rep in (2, 4, 8, 16, 32),
                pack == Pack.NONE,
            ]
        )
        or all(
            [
                elem_ty_acc.width == 16,
                elem_ty_d.width == 16,
                num_dp == 16,
                num_bits == 128,
                num_rep in (2, 4, 8, 16, 32, 64),
                pack == Pack.PACK_16b_IN_32b,
            ]
        )
    )
    use_stmatrix_m16n8_4x = all(
        [
            elem_ty_acc.width == 32,
            elem_ty_d.width == 8,
            is_m_major,
            num_dp == 16,
            num_bits == 256,
            num_rep in (4, 8, 16, 32),
            pack == Pack.NONE,
        ]
    )
    use_stmatrix_m8n8_2x = (
        all(
            [
                elem_ty_acc.width == 32,
                elem_ty_d.width == 32,
                is_n_major,
                num_dp == 16,
                num_bits == 128,
                num_rep == 1,
                pack == Pack.NONE,
            ]
        )
        or all(
            [
                elem_ty_acc.width == 32,
                elem_ty_d.width == 16,
                num_dp == 16,
                num_bits == 256,
                num_rep == 1,
                pack == Pack.NONE,
            ]
        )
        or all(
            [
                elem_ty_acc.width == 16,
                elem_ty_d.width == 16,
                num_dp == 16,
                num_bits == 128,
                num_rep == 1,
                pack == Pack.PACK_16b_IN_32b,
            ]
        )
    )
    use_stmatrix_m16n8_2x = all(
        [
            elem_ty_acc.width == 32,
            elem_ty_d.width == 8,
            is_m_major,
            num_dp == 16,
            num_bits == 256,
            num_rep == 2,
            pack == Pack.NONE,
        ]
    )
    use_stmatrix_m16n8_1x = all(
        [
            elem_ty_acc.width == 32,
            elem_ty_d.width == 8,
            is_m_major,
            num_dp == 16,
            num_bits == 256,
            num_rep == 1,
            pack == Pack.NONE,
        ]
    )

    if use_stmatrix_m8n8_4x:
        op = StMatrix8x8x16bOp(is_m_major, 4)
        return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip)
    elif use_stmatrix_m8n8_2x:
        op = StMatrix8x8x16bOp(is_m_major, 2)
        return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip)
    elif use_stmatrix_m16n8_4x:
        op = StMatrix16x8x8bOp(4)
        return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip)
    elif use_stmatrix_m16n8_2x:
        op = StMatrix16x8x8bOp(2)
        return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip)
    elif use_stmatrix_m16n8_1x:
        op = StMatrix16x8x8bOp(1)
        return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip)
    else:
        op = CopyUniversalOp()
        return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip)


@dsl_user_op
def get_tmem_load_op(
    cta_tile_shape: cute.Shape,
    layout_d: LayoutEnum,
    elem_ty_d: Type[Numeric],
    elem_ty_acc: Type[Numeric],
    epi_tile: cute.Tile,
    use_2cta_instrs: bool,
    *,
    loc=None,
    ip=None,
) -> cute.CopyAtom:
    """Finds a performant TMEM_LOAD copy op for the selected epilogue
    tile (epi_tile), element types, and tcgen05.mma instruction used.

    :param cta_tile_shape: A tuple or list representing the dimensions of the CTA tile.
    :type cta_tile_shape: cute.Shape
    :param layout_d: The layout enum of the output tensor D.
    :type layout_d: LayoutEnum
    :param elem_ty_d: The element type for output tensor D.
    :type elem_ty_d: Type[Numeric]
    :param elem_ty_acc: The element type for accumulation.
    :type elem_ty_acc: Type[Numeric]
    :param epi_tile: The epilogue tile configuration.
    :type epi_tile: cute.Tile
    :param use_2cta_instrs: A flag indicating whether the configuration is for 2 SMs.
    :type use_2cta_instrs: bool

    :return: A TMEM_LOAD copy atom with the computed configuration.
    :rtype: cute.CopyAtom

    :raises ValueError: If the function cannot handle the given combination of accumulator
        and output element types, or if it cannot determine the appropriate configuration
        based on the input parameters.
    """
    is_m_major = layout_d.is_m_major_c()

    acc_bits = elem_ty_acc.width
    d_bits = elem_ty_d.width

    tmem_warp_shape_mn = (
        (2, 2) if (cta_tile_shape[0] == 64 and use_2cta_instrs) else (4, 1)
    )
    epilog_tile_shape_mn = cute.product_each(
        cute.shape(epi_tile, loc=loc, ip=ip), loc=loc, ip=ip
    )
    epilog_warp_tile_shape_mn = cute.shape_div(
        epilog_tile_shape_mn, tmem_warp_shape_mn, loc=loc, ip=ip
    )

    num_dp = cute.size(epilog_warp_tile_shape_mn[0], loc=loc, ip=ip)
    if num_dp not in {16, 32}:
        raise ValueError("CTA tile and 2-SM config do not produce a valid number of TMEM data paths.")

    num_col_bits = cute.size(epilog_warp_tile_shape_mn[1], loc=loc, ip=ip) * acc_bits

    tmem_dp = 0
    tmem_bit = 0
    tmem_rep = 0
    tmem_pack16b = False
    if acc_bits == 32 and d_bits == 32:
        if num_dp == 16:
            if is_m_major:
                tmem_dp = 16
                tmem_bit = 256
            else:
                tmem_dp = 16
                tmem_bit = 128
        else:
            tmem_dp = 32
            tmem_bit = 32
    elif acc_bits == 32 and d_bits == 16:
        if num_dp == 16:
            if is_m_major:
                tmem_dp = 16
                tmem_bit = 256
            else:
                tmem_dp = 16
                tmem_bit = 256
        else:
            if is_m_major:
                tmem_dp = 16
                tmem_bit = 256
            else:
                tmem_dp = 32
                tmem_bit = 32
    elif acc_bits == 32 and d_bits == 8:
        if num_dp == 16:
            if is_m_major:
                tmem_dp = 16
                tmem_bit = 256
            else:
                tmem_dp = 16
                tmem_bit = 32
        else:
            if is_m_major:
                tmem_dp = 16
                tmem_bit = 256
            else:
                tmem_dp = 32
                tmem_bit = 32
    elif acc_bits == 16 and d_bits == 16:
        tmem_pack16b = True
        if num_dp == 16:
            if is_m_major:
                tmem_dp = 16
                tmem_bit = 128
            else:
                tmem_dp = 16
                tmem_bit = 128
        else:
            if is_m_major:
                tmem_dp = 16
                tmem_bit = 128
            else:
                tmem_dp = 32
                tmem_bit = 32
    elif acc_bits == 32 and d_bits == 6:
        if num_dp != 32:
            raise ValueError("Num dp must be 32.")
        tmem_dp = 32
        tmem_bit = 32
    elif acc_bits == 32 and d_bits == 4:
        if num_dp != 32:
            raise ValueError("Num dp must be 32.")
        tmem_dp = 32
        tmem_bit = 32
    else:
        raise ValueError(
            f"Cannot handle acc/d type combination: {elem_ty_acc=}, {elem_ty_d=}"
        )

    num_bit_div = tmem_bit
    if tmem_dp == 16 and tmem_bit == 32:
        num_bit_div = 64

    if (num_col_bits % (num_bit_div * 128) == 0) and (
        (tmem_dp == 16 and tmem_bit == 64)
        or (tmem_dp == 16 and tmem_bit == 32)
        or (tmem_dp == 32 and tmem_bit == 32)
    ):
        tmem_rep = 128
    elif (num_col_bits % (num_bit_div * 64) == 0) and (
        (tmem_dp == 16 and tmem_bit == 128)
        or (tmem_dp == 16 and tmem_bit == 64)
        or (tmem_dp == 16 and tmem_bit == 32)
        or (tmem_dp == 32 and tmem_bit == 32)
    ):
        tmem_rep = 64
    elif num_col_bits % (num_bit_div * 32) == 0:
        tmem_rep = 32
    elif num_col_bits % (num_bit_div * 16) == 0:
        tmem_rep = 16
    elif num_col_bits % (num_bit_div * 8) == 0:
        tmem_rep = 8
    elif num_col_bits % (num_bit_div * 4) == 0:
        tmem_rep = 4
    elif num_col_bits % (num_bit_div * 2) == 0:
        tmem_rep = 2
    elif num_col_bits % (num_bit_div * 1) == 0:
        tmem_rep = 1
    else:
        raise ValueError("Cannot pick tmem_rep based on the CTA tile shape and tmem atom.")

    if tmem_dp == 16 and tmem_bit == 64:
        op = Ld16x64bOp(
            Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE
        )
        return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip)
    elif tmem_dp == 16 and tmem_bit == 128:
        op = Ld16x128bOp(
            Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE
        )
        return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip)
    elif tmem_dp == 16 and tmem_bit == 256:
        op = Ld16x256bOp(
            Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE
        )
        return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip)
    elif tmem_dp == 16 and tmem_bit == 32:
        op = Ld16x32bx2Op(
            Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE
        )
        return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip)
    elif tmem_dp == 32 and tmem_bit == 32:
        op = Ld32x32bOp(
            Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE
        )
        return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip)
    else:
        raise ValueError(f"Unsupported TMEM_LOAD shape: {tmem_dp=}, {tmem_bit=}")


def get_num_tmem_alloc_cols(
    tmem_tensors: Union[cute.Tensor, List[cute.Tensor]], rounding=True
) -> int:
    """Get the total number of TMEM allocation columns for the given TMEM tensors.

    :param tmem_tensors: The TMEM tensors to get the number of allocation columns for.
    :type tmem_tensors: Union[cute.Tensor, List[cute.Tensor]]
    :param rounding: Whether to round up the number of allocation columns to the nearest power of 2.
    :type rounding: bool

    :return: The total number of TMEM allocation columns.
    :rtype: int

    :raises ValueError: If the number of TMEM allocation columns exceeds the maximum capacity of 512 or is less than 32.
    """
    # Turn tmem_tensors into a list
    if isinstance(tmem_tensors, cute.Tensor):
        tmem_tensors = [tmem_tensors]

    # For each tensor in tmem_tensors, find the tmem_tensor_col_offset
    num_tmem_alloc_cols_per_tensor = [
        find_tmem_tensor_col_offset(t) for t in tmem_tensors
    ]

    # Sum up the num_tmem_alloc_cols_per_tensor
    num_tmem_alloc_cols = sum(num_tmem_alloc_cols_per_tensor)

    # Round up num_tmem_cols_total to the nearest power of 2
    if rounding:
        num_tmem_alloc_cols = 1 << ceil(log2(num_tmem_alloc_cols))

    # Validate the number of TMEM allocation columns
    SM100_TMEM_CAPACITY_COLUMNS = 512
    SM100_TMEM_MIN_ALLOC_COLUMNS = 32
    if (
        num_tmem_alloc_cols > SM100_TMEM_CAPACITY_COLUMNS
        or num_tmem_alloc_cols < SM100_TMEM_MIN_ALLOC_COLUMNS
    ):
        raise ValueError(
            f"TMEM allocation columns {num_tmem_alloc_cols} exceeds the maximum capacity of {SM100_TMEM_CAPACITY_COLUMNS} or is less than the minimum of {SM100_TMEM_MIN_ALLOC_COLUMNS}"
        )
    return num_tmem_alloc_cols


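# Illustrative example of the rounding step above (a sketch, not part of the module
# API; the column counts are made-up values):
#
#   >>> from math import ceil, log2
#   >>> cols = 96 + 40          # summed per-tensor TMEM column offsets
#   >>> 1 << ceil(log2(cols))   # rounded up to the next power of two
#   256
#
# The rounded value must still fall within [32, 512] columns or a ValueError is raised.

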
def get_smem_layout_atom_ab(
    major_mode: OperandMajorMode,
    element_type: Type[Numeric],
    smem_shape_mn_k: Tuple[int, int],
    *,
    loc=None,
    ip=None,
) -> SmemLayoutAtomKind:
    """Simple heuristics to select the optimal SMEM layout atom based on the
    majorness, the data type, and the major mode size.

    :param major_mode: The major mode of the SMEM tensor.
    :type major_mode: OperandMajorMode
    :param element_type: The element type for the SMEM tensor.
    :type element_type: Type[Numeric]
    :param smem_shape_mn_k: The shape of the SMEM tensor.
    :type smem_shape_mn_k: Tuple[int, int]

    :return: The SMEM layout atom kind
    :rtype: SmemLayoutAtomKind
    """
    is_k_major = major_mode == OperandMajorMode.K
    major_mode_size = smem_shape_mn_k[1] if is_k_major else smem_shape_mn_k[0]

    assert major_mode_size % 8 == 0
    sw128_num_contiguous_bits = 1024
    sw64_num_contiguous_bits = 512
    sw32_num_contiguous_bits = 256
    inter_num_contiguous_bits = 128
    major_mode_size_bits = major_mode_size * element_type.width
    assert major_mode_size_bits % inter_num_contiguous_bits == 0

    if not is_k_major:
        if (element_type.width == 32) and (
            major_mode_size_bits % sw128_num_contiguous_bits == 0
        ):
            return SmemLayoutAtomKind.MN_SW128_32B
        if major_mode_size_bits % sw128_num_contiguous_bits == 0:
            return SmemLayoutAtomKind.MN_SW128
        if major_mode_size_bits % sw64_num_contiguous_bits == 0:
            return SmemLayoutAtomKind.MN_SW64
        if major_mode_size_bits % sw32_num_contiguous_bits == 0:
            return SmemLayoutAtomKind.MN_SW32
        return SmemLayoutAtomKind.MN_INTER
    if major_mode_size_bits % sw128_num_contiguous_bits == 0:
        return SmemLayoutAtomKind.K_SW128
    if major_mode_size_bits % sw64_num_contiguous_bits == 0:
        return SmemLayoutAtomKind.K_SW64
    if major_mode_size_bits % sw32_num_contiguous_bits == 0:
        return SmemLayoutAtomKind.K_SW32
    return SmemLayoutAtomKind.K_INTER


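# Worked example of the selection above (a sketch; the operand shape is an assumption):
# a K-major operand with K = 64 and a 16-bit element type has a major-mode extent of
# 64 * 16 = 1024 bits, which is divisible by 1024, so K_SW128 is chosen. With K = 32
# the extent is 512 bits, which only divides 512, so K_SW64 is chosen instead.

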
@dsl_user_op
def make_smem_layout_a(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    a_dtype: Type[Numeric],
    num_stages: int,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """This function helps with:
    1. Get the partitioned shape of the A tensor based on the tiled_mma & MMA tiler.
    2. Select the heuristic SMEM layout atom based on the A tensor's majorness, the data type, and the major mode size.
    3. Tile the SMEM layout atom to the MMA tile shape.
    4. Stage the SMEM layout based on the number of stages.

    :param tiled_mma: The tiled MMA used to partition tensor A
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The MMA tile shape
    :type mma_tiler_mnk: cute.Tile
    :param a_dtype: The element type for tensor A
    :type a_dtype: Type[Numeric]
    :param num_stages: The number of pipeline stages for tensor A
    :type num_stages: int

    :return: SMEM layout for tensor A
    :rtype: Union[cute.Layout, cute.ComposedLayout]
    """

    is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K
    a_smem_shape = tiled_mma.partition_shape_A(
        cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip)
    )
    a_smem_shape_mn_k = (
        cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1],
        cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2],
    )
    a_smem_layout_atom = make_smem_layout_atom(
        get_smem_layout_atom_ab(
            tiled_mma.op.a_major_mode,
            a_dtype,
            a_smem_shape_mn_k,
            loc=loc,
            ip=ip,
        ),
        a_dtype,
        loc=loc,
        ip=ip,
    )
    a_smem_layout_staged = tile_to_mma_shape(
        a_smem_layout_atom,
        cute.append(a_smem_shape, num_stages, loc=loc, ip=ip),
        order=((1, 0, 2) if not is_k_major else (0, 1, 2)),
        loc=loc,
        ip=ip,
    )
    return a_smem_layout_staged


@dsl_user_op
def make_smem_layout_b(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    b_dtype: Type[Numeric],
    num_stages: int,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """This function helps:
    1. Get the partitioned shape of the B tensor based on the tiled_mma & MMA tiler.
    2. Select the heuristic SMEM layout atom based on the B tensor's majorness, the data type, and the major mode size.
    3. Tile the SMEM layout atom to the MMA tile shape.
    4. Stage the SMEM layout based on the number of stages.

    :param tiled_mma: The tiled MMA which is used to partition the B tensor.
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The MMA tile shape.
    :type mma_tiler_mnk: cute.Tile
    :param b_dtype: The element type for the B tensor.
    :type b_dtype: Type[Numeric]
    :param num_stages: The number of pipeline stages for the B tensor.
    :type num_stages: int

    :return: SMEM layout for the B tensor.
    :rtype: Union[cute.Layout, cute.ComposedLayout]
    """

    is_k_major = tiled_mma.op.b_major_mode == OperandMajorMode.K
    b_smem_shape = tiled_mma.partition_shape_B(
        cute.dice(mma_tiler_mnk, (None, 1, 1), loc=loc, ip=ip)
    )
    b_smem_shape_nk = (
        cute.size(b_smem_shape[0][0], loc=loc, ip=ip) * b_smem_shape[1],
        cute.size(b_smem_shape[0][1], loc=loc, ip=ip) * b_smem_shape[2],
    )
    b_smem_layout_atom = make_smem_layout_atom(
        get_smem_layout_atom_ab(
            tiled_mma.op.b_major_mode,
            b_dtype,
            b_smem_shape_nk,
            loc=loc,
            ip=ip,
        ),
        b_dtype,
        loc=loc,
        ip=ip,
    )
    b_smem_layout_staged = tile_to_mma_shape(
        b_smem_layout_atom,
        cute.append(b_smem_shape, num_stages, loc=loc, ip=ip),
        order=((1, 0, 2) if not is_k_major else (0, 1, 2)),
        loc=loc,
        ip=ip,
    )

    return b_smem_layout_staged


@dsl_user_op
def get_smem_layout_atom_epi(
    layout: LayoutEnum,
    element_type: Type[Numeric],
    epi_tile: cute.Tile,
    *,
    loc=None,
    ip=None,
) -> SmemLayoutAtomKind:
    """Simple heuristics to select the optimal SMEM layout atom for epilog tensors.

    :param layout: The layout enum for the SMEM tensor.
    :type layout: LayoutEnum
    :param element_type: The element type for the SMEM tensor.
    :type element_type: Type[Numeric]
    :param epi_tile: The epilogue tile shape.
    :type epi_tile: cute.Tile

    :return: The SMEM layout atom kind
    :rtype: SmemLayoutAtomKind
    """
    # Get the max contiguous tile usable by TMA
    tma_shape = tuple(
        (
            # assumes get<0>(epi_tile) is coalesced and unit stride
            cute.coalesce(cute.right_inverse(x, loc=loc, ip=ip), loc=loc, ip=ip).shape
            if isinstance(x, cute.Layout)
            else x
        )
        for x in epi_tile
    )

    if layout.is_m_major_c():
        # ColMajor C/D (M-major)
        return get_smem_layout_atom_ab(
            OperandMajorMode.MN, element_type, tma_shape, loc=loc, ip=ip
        )
    else:
        # RowMajor C/D (N-major)
        return get_smem_layout_atom_ab(
            OperandMajorMode.K, element_type, tma_shape, loc=loc, ip=ip
        )


@dsl_user_op
def make_smem_layout_epi(
    epi_dtype: Type[Numeric],
    epi_layout: LayoutEnum,
    epi_tile: cute.Tile,
    epi_stage: int,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """This function helps:
    1. Select the heuristic SMEM layout atom based on the epilog tile shape,
       the epilog tensor's majorness, and the element type.
    2. Tile the SMEM layout atom to the epilog tile shape.
    3. Stage the SMEM layout based on the number of stages.

    :param epi_dtype: The element type for the epilog tensor.
    :type epi_dtype: Type[Numeric]
    :param epi_layout: The layout enum for the epilog tensor.
    :type epi_layout: LayoutEnum
    :param epi_tile: The epilogue tile shape.
    :type epi_tile: cute.Tile
    :param epi_stage: The number of pipeline stages for the epilog tensor.
    :type epi_stage: int

    :return: SMEM layout for epilog tensors (usually C & D which are processed in the epilog)
    :rtype: Union[cute.Layout, cute.ComposedLayout]
    """

    epilog_shape = cute.product_each(
        cute.shape(epi_tile, loc=loc, ip=ip), loc=loc, ip=ip
    )

    c_smem_layout_atom = make_smem_layout_atom(
        get_smem_layout_atom_epi(
            epi_layout,
            epi_dtype,
            epi_tile,
            loc=loc,
            ip=ip,
        ),
        epi_dtype,
        loc=loc,
        ip=ip,
    )
    epi_smem_layout_staged = cute.tile_to_shape(
        c_smem_layout_atom,
        cute.append(epilog_shape, epi_stage, loc=loc, ip=ip),
        order=((1, 0, 2) if not epi_layout.is_n_major_c() else (0, 1, 2)),
        loc=loc,
        ip=ip,
    )

    return epi_smem_layout_staged


class SmemCapacity(Enum):
    SM100_SMEM_CAPACITY_BYTES = (228 - 1) * 1024
    SM120_SMEM_CAPACITY_BYTES = (100 - 1) * 1024


# Dictionary to map compute capability to SMEM capacity
SMEM_CAPACITY = {
    "sm100": SmemCapacity.SM100_SMEM_CAPACITY_BYTES.value,
    "sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value,
}


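# Example of how SMEM_CAPACITY is typically consumed (a sketch under assumed tile
# sizes, not a prescription): with a 128x64 Float16 A tile (16 KiB) and a 256x64
# Float16 B tile (32 KiB), one pipeline stage needs roughly 48 KiB, so
# SMEM_CAPACITY["sm100"] // (48 * 1024) == 4 stages fit, ignoring epilogue and
# mbarrier storage.

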
@dsl_user_op
def make_trivial_tiled_mma(
    ab_dtype: Type[Numeric],
    a_leading_mode: OperandMajorMode,
    b_leading_mode: OperandMajorMode,
    acc_dtype: Type[Numeric],
    cta_group: CtaGroup,
    mma_tiler_mn: Tuple[int, int],
    a_source: OperandSource = OperandSource.SMEM,
    *,
    loc=None,
    ip=None,
) -> cute.TiledMma:
    """Make a tiled MMA atom with the given data type, leading dimensions, CTA group and MMA tile shape.
    By default, the MMA atom is created with SMEM operand source for A.

    :param ab_dtype: Data type of operands A and B.
    :type ab_dtype: type[Numeric]
    :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N).
    :type a_leading_mode: tcgen05.OperandMajorMode
    :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N).
    :type b_leading_mode: tcgen05.OperandMajorMode
    :param acc_dtype: Data type of the accumulator.
    :type acc_dtype: type[Numeric]
    :param cta_group: The CTA group to use.
    :type cta_group: tcgen05.CtaGroup
    :param mma_tiler_mn: The shape (M, N) of the MMA tiler.
    :type mma_tiler_mn: Tuple[int, int]
    :param a_source: The source of operand A (SMEM by default or TMEM).
    :type a_source: OperandSource

    :return: A tiled MMA atom.
    :rtype: cute.TiledMma

    :raises TypeError: If the data type is not supported.
    """

    if ab_dtype in {Float16, BFloat16}:
        mma_op = MmaF16BF16Op(
            ab_dtype,
            acc_dtype,
            (*mma_tiler_mn, 16),
            cta_group,
            a_source,
            a_leading_mode,
            b_leading_mode,
        )
    elif ab_dtype in {TFloat32, Float32}:
        mma_op = MmaTF32Op(
            (*mma_tiler_mn, 8),
            cta_group,
            a_source,
            a_leading_mode,
            b_leading_mode,
        )
    elif ab_dtype in {
        Uint8,
        Int8,
    }:
        mma_op = MmaI8Op(
            ab_dtype,
            (*mma_tiler_mn, 32),
            cta_group,
            a_source,
            a_leading_mode,
            b_leading_mode,
        )
    elif ab_dtype in {Float8E4M3FN, Float8E5M2}:
        mma_op = MmaFP8Op(
            ab_dtype,
            acc_dtype,
            (*mma_tiler_mn, 32),
            cta_group,
            a_source,
            a_leading_mode,
            b_leading_mode,
        )
    else:
        raise TypeError(f"unsupported ab_dtype, got {ab_dtype}")

    return cute.make_tiled_mma(cute.make_mma_atom(mma_op))
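

# Illustrative usage of make_trivial_tiled_mma (a sketch; the argument values below
# are assumptions chosen for the example, not defaults of this module):
#
#   tiled_mma = make_trivial_tiled_mma(
#       Float16,                   # ab_dtype
#       OperandMajorMode.K,        # A is K-major
#       OperandMajorMode.K,        # B is K-major
#       Float32,                   # accumulate in FP32
#       CtaGroup.ONE,              # 1-CTA MMA instruction
#       (128, 128),                # MMA tiler (M, N); the K extent is implied by the dtype
#   )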
@@ -0,0 +1,466 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from typing import List, Tuple

import cutlass.cute as cute
from cutlass.cutlass_dsl import Int32, extract_mlir_values, new_from_mlir_values
from cutlass._mlir import ir

from cutlass.utils.static_persistent_tile_scheduler import PersistentTileSchedulerParams


class GroupSearchResult:
    """
    The result of the group search for grouped gemm.

    :param group_idx: The result group index
    :type group_idx: Int32
    :param cta_tile_idx_m: CTA tile index along M dimension after rasterization
    :type cta_tile_idx_m: Int32
    :param cta_tile_idx_n: CTA tile index along N dimension after rasterization
    :type cta_tile_idx_n: Int32
    :param problem_shape_m: The M dimension of the gemm problem
    :type problem_shape_m: Int32
    :param problem_shape_n: The N dimension of the gemm problem
    :type problem_shape_n: Int32
    :param problem_shape_k: The K dimension of the gemm problem
    :type problem_shape_k: Int32
    :param cta_tile_count_k: Number of tiles along K dimension
    :type cta_tile_count_k: Int32
    """

    def __init__(
        self,
        group_idx: Int32,
        cta_tile_idx_m: Int32,
        cta_tile_idx_n: Int32,
        problem_shape_m: Int32,
        problem_shape_n: Int32,
        problem_shape_k: Int32,
        cta_tile_count_k: Int32,
    ) -> None:
        self.group_idx = group_idx
        self.cta_tile_idx_m = cta_tile_idx_m
        self.cta_tile_idx_n = cta_tile_idx_n
        self.problem_shape_m = problem_shape_m
        self.problem_shape_n = problem_shape_n
        self.problem_shape_k = problem_shape_k
        self.cta_tile_count_k = cta_tile_count_k

    def __extract_mlir_values__(self) -> List[ir.Value]:
        values = extract_mlir_values(self.group_idx)
        values.extend(extract_mlir_values(self.cta_tile_idx_m))
        values.extend(extract_mlir_values(self.cta_tile_idx_n))
        values.extend(extract_mlir_values(self.problem_shape_m))
        values.extend(extract_mlir_values(self.problem_shape_n))
        values.extend(extract_mlir_values(self.problem_shape_k))
        values.extend(extract_mlir_values(self.cta_tile_count_k))
        return values

    def __new_from_mlir_values__(self, values: List[ir.Value]) -> "GroupSearchResult":
        assert len(values) == 7
        return GroupSearchResult(*tuple(values))


class GroupedGemmGroupSearchState:
    """
    The state of group index search for grouped gemm.

    The state will be initialized once and updated in every round of group index search.

    :param start_group_idx: The group idx to start the search with
    :type start_group_idx: Int32
    :param tile_count_prev_group: Number of tiles before the matched group
    :type tile_count_prev_group: Int32
    :param tile_count_searched: Number of tiles we have searched. When the matched group is found,
        it records the number of tiles including the matched group
    :type tile_count_searched: Int32
    """

    def __init__(
        self,
        start_group_idx: Int32,
        tile_count_prev_group: Int32,
        tile_count_searched: Int32,
    ) -> None:
        self.start_group_idx = start_group_idx
        self.tile_count_prev_group = tile_count_prev_group
        self.tile_count_searched = tile_count_searched

    def __extract_mlir_values__(self) -> List[ir.Value]:
        values = extract_mlir_values(self.start_group_idx)
        values.extend(extract_mlir_values(self.tile_count_prev_group))
        values.extend(extract_mlir_values(self.tile_count_searched))
        return values

    def __new_from_mlir_values__(
        self, values: List[ir.Value]
    ) -> "GroupedGemmGroupSearchState":
        start_group_idx = new_from_mlir_values(self.start_group_idx, [values[0]])
        tile_count_prev_group = new_from_mlir_values(
            self.tile_count_prev_group, [values[1]]
        )
        tile_count_searched = new_from_mlir_values(
            self.tile_count_searched, [values[2]]
        )
        return GroupedGemmGroupSearchState(
            start_group_idx, tile_count_prev_group, tile_count_searched
        )


def create_initial_search_state() -> GroupedGemmGroupSearchState:
    """
    Create an initial search state for grouped gemm.

    :return: A new search state with initial values
    :rtype: GroupedGemmGroupSearchState
    """
    return GroupedGemmGroupSearchState(
        start_group_idx=Int32(0),
        tile_count_prev_group=Int32(0),
        tile_count_searched=Int32(0),
    )


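# Typical flow inside a persistent kernel (a sketch; `group_count`, `tile_sched_params`,
# `cluster_tile_shape_mnk`, `cta_tile_coord` and `problem_shape_mnkl` are assumed to be
# provided by the surrounding kernel, they are not defined here):
#
#   helper = GroupedGemmTileSchedulerHelper(
#       group_count, tile_sched_params, cluster_tile_shape_mnk,
#       create_initial_search_state(),
#   )
#   result = helper.delinearize_z(cta_tile_coord, problem_shape_mnkl)
#   # result.group_idx and result.cta_tile_idx_m / _n then drive the per-group mainloop.

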
class GroupedGemmTileSchedulerHelper:
    """
    A helper to translate the raw block index (x, y, z) from tile scheduler to real CTA tile index for grouped gemm.

    :param group_count: Number of groups in current grouped gemm problem
    :type group_count: int
    :param tile_sched_params: Parameter used to create the tile scheduler this helper works with
    :type tile_sched_params: PersistentTileSchedulerParams
    :param cluster_tile_shape_mnk: The shape of cluster tile as (m, n, k)
    :type cluster_tile_shape_mnk: tuple[int, int, int]
    :param search_state: The initial search state
    :type search_state: GroupedGemmGroupSearchState
    """

    def __init__(
        self,
        group_count: int,
        tile_sched_params: PersistentTileSchedulerParams,
        cluster_tile_shape_mnk: tuple[int, int, int],
        search_state: GroupedGemmGroupSearchState,
    ) -> None:
        self.tile_sched_params = tile_sched_params
        self.group_count = group_count
        self.lane_idx = cute.arch.lane_idx()
        self.cluster_tile_shape_mnk = cluster_tile_shape_mnk
        self.search_state = search_state

    def __extract_mlir_values__(self) -> List[ir.Value]:
        values = extract_mlir_values(self.tile_sched_params)
        values.extend(extract_mlir_values(self.search_state))
        return values

    def __new_from_mlir_values__(
        self, values: List[ir.Value]
    ) -> "GroupedGemmTileSchedulerHelper":
        tile_sched_params = new_from_mlir_values(self.tile_sched_params, values)
        search_state = new_from_mlir_values(self.search_state, values[1:])
        return GroupedGemmTileSchedulerHelper(
            self.group_count,
            tile_sched_params,
            self.cluster_tile_shape_mnk,
            search_state,
        )

    def delinearize_z(
        self,
        cta_tile_coord: tuple,
        problem_shape_mnkl: cute.Tensor,
    ) -> GroupSearchResult:
        """
        Delinearize the linear z index and return a GroupSearchResult.

        This function should be used by warps that need to know the CTA tile index on M and N dimensions.

        :param cta_tile_coord: The raw CTA coordinate from the tile scheduler
        :type cta_tile_coord: tuple of Int32
        :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for each group
        :type problem_shape_mnkl: cute.Tensor
        :return: The search result containing group index and tile coordinates
        :rtype: GroupSearchResult
        """
        # delinearize the z coord
        linear_idx = cta_tile_coord[2]
        group_idx, problem_mnkl = self._group_search_and_load_problem_shape(
            linear_idx,
            problem_shape_mnkl,
            self.search_state.start_group_idx,
            self.search_state.tile_count_prev_group,
        )
        # linear index local to current group
        cluster_tile_idx_in_current_group = (
            linear_idx - self.search_state.tile_count_prev_group
        )
        cluster_count_m, cluster_count_n, cluster_count_k = cute.ceil_div(
            (problem_mnkl[0], problem_mnkl[1], problem_mnkl[2]),
            (
                self.cluster_tile_shape_mnk[0],
                self.cluster_tile_shape_mnk[1],
                self.cluster_tile_shape_mnk[2],
            ),
        )
        # decompose to get indices on M and N
        cta_tile_idx_m, cta_tile_idx_n = self._compute_cta_tile_coord(
            cluster_tile_idx_in_current_group,
            cta_tile_coord,
            cluster_count_m,
            cluster_count_n,
        )
        return GroupSearchResult(
            group_idx,
            cta_tile_idx_m,
            cta_tile_idx_n,
            problem_mnkl[0],
            problem_mnkl[1],
            problem_mnkl[2],
            cluster_count_k,
        )

    def search_cluster_tile_count_k(
        self,
        cta_tile_coord: tuple,
        problem_shape_mnkl: cute.Tensor,
    ) -> Tuple[Int32, Int32]:
        """
        Search the matched group for the given linear index and compute the number of tiles along the K dimension for the matched group.

        This function should be used by warps that are only interested in the number of tiles along the K dimension.

        :param cta_tile_coord: The raw CTA coordinate from the tile scheduler
        :type cta_tile_coord: tuple of Int32
        :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups
        :type problem_shape_mnkl: cute.Tensor
        :return: A tuple containing the cluster count along the K dimension and the group index
        :rtype: Tuple[Int32, Int32]
        """
        group_idx, problem_mnk = self._group_search_and_load_problem_shape(
            cta_tile_coord[2],
            problem_shape_mnkl,
            self.search_state.start_group_idx,
            self.search_state.tile_count_prev_group,
        )
        cluster_count_k = (
            problem_mnk[2] + self.cluster_tile_shape_mnk[2] - 1
        ) // self.cluster_tile_shape_mnk[2]
        return cluster_count_k, group_idx

    @cute.jit
    def _prefix_sum(self, value_per_thread: Int32) -> Int32:
        """
        Perform prefix sum within a full warp.

        :param value_per_thread: The value for this thread to contribute to the prefix sum
        :type value_per_thread: Int32
        :return: The prefix sum result for this thread
        :rtype: Int32
        """
        clamp_value = 0
        idx = 1
        sum_per_thread = value_per_thread
        while idx < cute.arch.WARP_SIZE:
            value = cute.arch.shuffle_sync_up(
                sum_per_thread, idx, mask_and_clamp=clamp_value
            )
            if self.lane_idx >= idx:
                sum_per_thread += value
            idx = idx << 1
        return sum_per_thread

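    # Reference behavior of the warp scan above (a plain-Python sketch; lane i holds
    # vals[i] and receives out[i]):
    #
    #   vals = [3, 1, 2, 4, ...]               # one value per lane, 32 lanes
    #   out[i] == sum(vals[0 : i + 1])         # inclusive prefix sum
    #
    # Each shuffle_sync_up round pulls the partial sum from `idx` lanes below and the
    # stride doubles every iteration, so the scan finishes in log2(32) = 5 rounds.
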
    def _get_problem_for_group(
        self, problem_shape_mnkl: cute.Tensor, group_idx: Int32
    ) -> cute.Tensor:
        """
        Load the gemm problem (m, n, k, l) for the specified group from global memory to registers.

        :param problem_shape_mnkl: Tensor in global memory with layout (group_count, 4):(4, 1)
        :type problem_shape_mnkl: cute.Tensor
        :param group_idx: The index of the group to load
        :type group_idx: Int32
        :return: The problem shape tensor for the specified group
        :rtype: cute.Tensor
        """
        cur_problem_mnkl = cute.make_fragment(
            cute.make_layout(4), problem_shape_mnkl.element_type
        )
        cute.autovec_copy(problem_shape_mnkl[(group_idx, None)], cur_problem_mnkl)
        return cur_problem_mnkl

    def _get_cluster_tile_count_mn(self, problem_shape: cute.Tensor) -> Int32:
        """
        Compute the total cluster tile count over the M and N dimensions.

        :param problem_shape: Tensor containing problem shape (m, n, k, l)
        :type problem_shape: cute.Tensor
        :return: The total cluster tile count for M and N dimensions
        :rtype: Int32
        """
        cur_ntile_m = (
            problem_shape[0] + self.cluster_tile_shape_mnk[0] - 1
        ) // self.cluster_tile_shape_mnk[0]
        cur_ntile_n = (
            problem_shape[1] + self.cluster_tile_shape_mnk[1] - 1
        ) // self.cluster_tile_shape_mnk[1]
        cur_ntile_mn = cur_ntile_m * cur_ntile_n
        return cur_ntile_mn

    def _compute_cta_tile_coord(
        self,
        cluster_tile_idx: Int32,
        cta_tile_coord_in_cluster: tuple,
        cluster_tile_count_m: Int32,
        cluster_tile_count_n: Int32,
    ) -> tuple:
        """
        Compute CTA tile indices along the M and N dimensions based on the linear index within a group.

        It uses the AlongM mode to decompose the linear index onto the M and N dimensions.

        :param cluster_tile_idx: The linear index within a group
        :type cluster_tile_idx: Int32
        :param cta_tile_coord_in_cluster: CTA indices along M and N dimensions within a cluster
        :type cta_tile_coord_in_cluster: tuple of Int32
        :param cluster_tile_count_m: The number of clusters along the M dimension of the matched group
        :type cluster_tile_count_m: Int32
        :param cluster_tile_count_n: The number of clusters along the N dimension of the matched group
        :type cluster_tile_count_n: Int32
        :return: A tuple containing CTA tile indices along the M and N dimensions
        :rtype: tuple of (Int32, Int32)
        """
        cluster_layout_mn = cute.make_layout(
            (cluster_tile_count_m, cluster_tile_count_n)
        )
        (mi, ni) = cluster_layout_mn.get_hier_coord(cluster_tile_idx)
        cta_tile_idx_m = (
            mi * self.tile_sched_params.cluster_shape_mn[0]
            + cta_tile_coord_in_cluster[0]
        )
        cta_tile_idx_n = (
            ni * self.tile_sched_params.cluster_shape_mn[1]
            + cta_tile_coord_in_cluster[1]
        )
        return (cta_tile_idx_m, cta_tile_idx_n)

    @cute.jit
    def _group_search(
        self,
        linear_idx: Int32,
        problem_shape_mnkl: cute.Tensor,
        init_group_idx: Int32,
        init_tile_count_searched: Int32,
    ) -> GroupedGemmGroupSearchState:
        """
        Search which group the linear index belongs to.

        :param linear_idx: The linear index to be decomposed
        :type linear_idx: Int32
        :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups
        :type problem_shape_mnkl: cute.Tensor
        :param init_group_idx: The group idx to start the search with
        :type init_group_idx: Int32
        :param init_tile_count_searched: The number of tiles we have searched
        :type init_tile_count_searched: Int32
        :return: The updated search state
        :rtype: GroupedGemmGroupSearchState
        """
        c_0 = Int32(0).ir_value()
        last_lane_idx = cute.arch.WARP_SIZE - 1

        tile_count_searched = init_tile_count_searched
        start_group_idx = init_group_idx
        not_found = linear_idx >= tile_count_searched
        tile_count_prev_group = self.search_state.tile_count_prev_group
        while not_found:
            # get the group to search for the current lane
            cur_group_idx = start_group_idx + self.lane_idx
            # check if the group to be checked is out of range
            inside_group_bound = cur_group_idx < self.group_count
            cur_ntile_mn = c_0
            if inside_group_bound:
                # get the problem size of the current group
                cur_problem_mnkl = self._get_problem_for_group(
                    problem_shape_mnkl, cur_group_idx
                )
                cur_ntile_mn = self._get_cluster_tile_count_mn(cur_problem_mnkl)
            # compute the tile count from the beginning to the current group (included)
            total_cluster_tile_count_ps_per_thread = self._prefix_sum(cur_ntile_mn)
            cluster_tile_count_end_per_thread = (
                total_cluster_tile_count_ps_per_thread + tile_count_searched
            )

            group_not_in_window = linear_idx >= cluster_tile_count_end_per_thread
            hitted_group_idx_in_search_window = cute.arch.popc(
                cute.arch.vote_ballot_sync(group_not_in_window)
            )
            not_found = hitted_group_idx_in_search_window == cute.arch.WARP_SIZE
            start_group_idx = hitted_group_idx_in_search_window + start_group_idx
            hit_the_1st_problem_in_search_window = (
                hitted_group_idx_in_search_window == c_0
            )
            tile_count_prev_group = tile_count_searched
            if hit_the_1st_problem_in_search_window == False:
                tile_count_prev_group = cute.arch.shuffle_sync(
                    cluster_tile_count_end_per_thread,
                    hitted_group_idx_in_search_window - 1,
                )

            # If no matched group, then get new_cluster_tile_count_end from the last lane
            # Otherwise, get new_cluster_tile_count_end from the hit group
            lane_idx_for_cluster_tile_count_end = hitted_group_idx_in_search_window
            if not_found:
                lane_idx_for_cluster_tile_count_end = last_lane_idx
            tile_count_searched = cute.arch.shuffle_sync(
                cluster_tile_count_end_per_thread,
                lane_idx_for_cluster_tile_count_end,
            )

        return GroupedGemmGroupSearchState(
            start_group_idx,
            tile_count_prev_group,
            tile_count_searched,
        )

    def _group_search_and_load_problem_shape(
        self,
        linear_idx: Int32,
        problem_shape_mnkl: cute.Tensor,
        start_group_idx: Int32,
        tile_count_searched: Int32,
    ) -> Tuple[Int32, cute.Tensor]:
        """
        Perform group search and load problem shape for the matched group.

        :param linear_idx: The linear index to be decomposed
        :type linear_idx: Int32
        :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups
        :type problem_shape_mnkl: cute.Tensor
        :param start_group_idx: The group idx to start the search with
        :type start_group_idx: Int32
        :param tile_count_searched: The number of tiles we have searched
        :type tile_count_searched: Int32
        :return: A tuple containing the final group index and the problem shape tensor
        :rtype: Tuple[Int32, cute.Tensor]
        """
        self.search_state = self._group_search(
            linear_idx,
            problem_shape_mnkl,
            start_group_idx,
            tile_count_searched,
        )
        # get final group search state
        final_group_idx = self.search_state.start_group_idx
        # let's revisit if it's better to broadcast problem_shape_mnk in group_search
        problem_mnkl = self._get_problem_for_group(problem_shape_mnkl, final_group_idx)
        return final_group_idx, problem_mnkl
174
python/CuTeDSL/cutlass/utils/hardware_info.py
Normal file
@@ -0,0 +1,174 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from cuda.bindings import driver, nvrtc

import cutlass.cute as cute

"""
This class is used to get the hardware info of a given GPU device.
It provides methods to get the max active clusters for a given cluster size.

Prerequisite:
- CUDA driver is initialized via `driver.cuInit` or other CUDA APIs.
- CUDA context is created via `driver.cuCtxCreate` or other CUDA APIs.
"""


class HardwareInfo:
    """
    :param device_id: CUDA device ID to get the hardware info for.
    """

    def __init__(self, device_id: int = 0):
        count = self._checkCudaErrors(driver.cuDeviceGetCount())
        if device_id >= count:
            raise ValueError(
                f"Device ID {device_id} is out of range for device count {count}"
            )
        self.device_id = device_id
        self.device = self._checkCudaErrors(driver.cuDeviceGet(device_id))
        self.context = self._checkCudaErrors(driver.cuCtxGetCurrent())
        self.driver_version = self._checkCudaErrors(driver.cuDriverGetVersion())

    # Getting the max active clusters for a given cluster size
    def get_max_active_clusters(self, cluster_size: int) -> int:
        self._get_device_function()
        if self._cuda_driver_version_lt(11, 8):
            raise RuntimeError(
                "CUDA driver version < 11.8; cannot query max active clusters"
            )
        if cluster_size <= 0 or cluster_size > 32:
            raise ValueError(
                f"Cluster size must be between 1 and 32, {cluster_size} is not supported"
            )

        max_shared_memory_per_block = self._checkCudaErrors(
            driver.cuDeviceGetAttribute(
                driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
                self.device,
            )
        )
        self._checkCudaErrors(
            driver.cuFuncSetAttribute(
                self.kernel,
                driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                max_shared_memory_per_block,
            )
        )
        max_dynamic_shared_memory = self._checkCudaErrors(
            driver.cuOccupancyAvailableDynamicSMemPerBlock(
                self.kernel, 1, 1  # numBlocks  # blockSize
            )
        )
        max_active_blocks = self._checkCudaErrors(
            driver.cuOccupancyMaxActiveBlocksPerMultiprocessor(
                self.kernel, 1, max_dynamic_shared_memory  # blockSize,
            )
        )
        # Allow non-portable cluster sizes so occupancy can be queried for them
        self._checkCudaErrors(
            driver.cuFuncSetAttribute(
                self.kernel,
                driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED,
                1,
            )
        )
        # prepare launch configuration
        launch_config = driver.CUlaunchConfig()
        launch_config.blockDimX = 128
        launch_config.blockDimY = 1
        launch_config.blockDimZ = 1
        launch_config.sharedMemBytes = max_dynamic_shared_memory
        launch_config.numAttrs = 1
        # max possible cluster size is 32
        cluster_dims_attr = driver.CUlaunchAttribute()
        cluster_dims_attr.id = (
            driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
        )
        value = driver.CUlaunchAttributeValue()
        value.clusterDim.x = cluster_size
        value.clusterDim.y = 1
        value.clusterDim.z = 1
        cluster_dims_attr.value = value
        launch_config.attrs = [cluster_dims_attr]
        launch_config.gridDimX = cluster_size
        launch_config.gridDimY = max_active_blocks
        launch_config.gridDimZ = 1

        num_clusters = self._checkCudaErrors(
            driver.cuOccupancyMaxActiveClusters(self.kernel, launch_config)
        )
        return num_clusters

    def get_l2_cache_size_in_bytes(self) -> int:
        return self._checkCudaErrors(
            driver.cuDeviceGetAttribute(
                driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE,
                self.device,
            )
        )

    def get_device_multiprocessor_count(self) -> int:
        return self._checkCudaErrors(
            driver.cuDeviceGetAttribute(
                driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
                self.device,
            )
        )

    def _checkCudaErrors(self, result):
        if result[0].value:
            raise RuntimeError(
                "CUDA error code={}({})".format(
                    result[0].value, self._cudaGetErrorEnum(result[0])
                )
            )
        # CUDA APIs always return the status as the first element of the result tuple
        if len(result) == 1:
            return None
        elif len(result) == 2:
            return result[1]
        else:
            return result[1:]

    def _cudaGetErrorEnum(self, error) -> str:
        if isinstance(error, driver.CUresult):
            err, name = driver.cuGetErrorName(error)
            return name if err == driver.CUresult.CUDA_SUCCESS else "<unknown>"
        elif isinstance(error, nvrtc.nvrtcResult):
            return nvrtc.nvrtcGetErrorString(error)[1]
        else:
            raise RuntimeError("Unknown error type: {}".format(error))

    def _cuda_driver_version_ge(self, major: int, minor: int) -> bool:
        return self.driver_version >= (major * 1000 + 10 * minor)

    def _cuda_driver_version_lt(self, major: int, minor: int) -> bool:
        return not self._cuda_driver_version_ge(major, minor)

    @cute.kernel
    def _empty_kernel(self):
        return

    @cute.jit
    def _host_function(self):
        self._empty_kernel().launch(
            grid=[1, 1, 1],
            block=[1, 1, 1],
        )

    # get an empty kernel to compute occupancy
    def _get_device_function(self) -> None:
        self.compiled_kernel = cute.compile(self._host_function)
        self.module = next(iter(self.compiled_kernel.cuda_modules.modules)).cuda_module
        self.kernel = next(iter(self.compiled_kernel.cuda_modules.modules)).kernel_ptr
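

# Illustrative usage (a sketch; assumes the CUDA driver is initialized and a context
# is current, as stated in the module docstring above):
#
#   hw = HardwareInfo(device_id=0)
#   sm_count = hw.get_device_multiprocessor_count()
#   l2_bytes = hw.get_l2_cache_size_in_bytes()
#   max_clusters = hw.get_max_active_clusters(cluster_size=2)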
195
python/CuTeDSL/cutlass/utils/hopper_helpers.py
Normal file
@@ -0,0 +1,195 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from typing import Type, Tuple
from enum import Enum

from cutlass.utils.layout import LayoutEnum
from cutlass.cutlass_dsl import (
    Float16,
    BFloat16,
    Float8E5M2,
    Float8E4M3FN,
    Numeric,
    NumericMeta,
    dsl_user_op,
)

import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu.common import CopyUniversalOp
from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp
from cutlass.cute.nvgpu.warpgroup import (
    MmaF16BF16Op,
    MmaF8Op,
    OperandMajorMode,
    OperandSource,
)


@dsl_user_op
def sm90_get_smem_store_op(
    layout_d: LayoutEnum,
    elem_ty_d: Type[Numeric],
    elem_ty_acc: Type[Numeric],
    *,
    loc=None,
    ip=None,
) -> cute.CopyAtom:
    """
    Selects the largest vectorized smem store atom available subject to constraint of gmem layout.

    Parameters:
    -----------
    layout_d : LayoutEnum
        The layout enum of the output tensor D.

    elem_ty_d : Type[Numeric]
        The element type for output tensor D.

    elem_ty_acc : Type[Numeric]
        The element type for accumulator.

    Returns:
    --------
    Either SmemStoreMatrix or SimtSyncCopy, based on the input parameters.
    """

    def validate_type(ty, ty_name):
        if not isinstance(ty, NumericMeta):
            raise TypeError(f"{ty_name} must be a Numeric, but got {ty}")

    validate_type(elem_ty_d, "elem_ty_d")
    validate_type(elem_ty_acc, "elem_ty_acc")

    is_m_major = layout_d.is_m_major_c()

    if elem_ty_d.width == 16:
        return cute.make_copy_atom(
            StMatrix8x8x16bOp(is_m_major, 4), elem_ty_d, loc=loc, ip=ip
        )
    else:
        return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip)


class SmemCapacity(Enum):
|
||||
SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024
|
||||
|
||||
|
||||
# Dictionary to map compute capability to SMEM capacity
|
||||
SMEM_CAPACITY = {
|
||||
"sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value,
|
||||
}
|
||||
|
||||
def make_trivial_tiled_mma(
|
||||
a_dtype: Type[Numeric],
|
||||
b_dtype: Type[Numeric],
|
||||
a_leading_mode: OperandMajorMode,
|
||||
b_leading_mode: OperandMajorMode,
|
||||
acc_dtype: Type[Numeric],
|
||||
atom_layout_mnk: Tuple[int, int, int],
|
||||
tiler_mn: Tuple[int, int],
|
||||
) -> cute.TiledMma:
|
||||
"""Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape.
|
||||
By default, the MMA atom is created with SMEM operand source for A.
|
||||
|
||||
:param a_dtype: Data type of operand A.
|
||||
:type a_dtype: type[Numeric]
|
||||
:param b_dtype: Data type of operand B.
|
||||
:type b_dtype: type[Numeric]
|
||||
:param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N).
|
||||
:type a_leading_mode: warpgroup.OperandMajorMode
|
||||
:param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N).
|
||||
:type b_leading_mode: warpgroup.OperandMajorMode
|
||||
:param acc_dtype: Data type of the accumulator.
|
||||
:type acc_dtype: type[Numeric]
|
||||
:param atom_layout_mnk: A integer tuple describing the tiling of Atom across threads.
|
||||
:type atom_layout_mnk: Tuple[int, int, int]
|
||||
:param tiler_mn: The shape (M, N) of the cta tiler.
|
||||
:type tiler_mn: Tuple[int, int]
|
||||
|
||||
:return: A tiled MMA atom.
|
||||
:rtype: cute.TiledMma
|
||||
|
||||
:raises TypeError: If the data type is not supported.
|
||||
"""
|
||||
|
||||
if a_dtype in {Float16, BFloat16}:
|
||||
if cutlass.const_expr(a_dtype != b_dtype):
|
||||
raise TypeError(f"Type mismatch: {a_dtype} != {b_dtype}")
|
||||
if cutlass.const_expr(a_dtype.width != b_dtype.width):
|
||||
raise TypeError(f"Type width mismatch: {a_dtype.width} != {b_dtype.width}")
|
||||
|
||||
mma_op = MmaF16BF16Op(
|
||||
a_dtype,
|
||||
acc_dtype,
|
||||
(*tiler_mn, 16),
|
||||
OperandSource.SMEM,
|
||||
a_leading_mode,
|
||||
b_leading_mode,
|
||||
)
|
||||
elif a_dtype in {Float8E4M3FN, Float8E5M2} and b_dtype in {
|
||||
Float8E4M3FN,
|
||||
Float8E5M2,
|
||||
}:
|
||||
mma_op = MmaF8Op(
|
||||
a_dtype,
|
||||
b_dtype,
|
||||
acc_dtype,
|
||||
(*tiler_mn, 32),
|
||||
OperandSource.SMEM,
|
||||
a_leading_mode,
|
||||
b_leading_mode,
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"unsupported a_dtype and b_dtype, got {a_dtype} and {b_dtype}")
|
||||
|
||||
return cute.make_tiled_mma(cute.make_mma_atom(mma_op), atom_layout_mnk)
|
||||
|
||||
def get_smem_layout_atom(
|
||||
layout: LayoutEnum,
|
||||
element_type: Type[Numeric],
|
||||
major_mode_size: int,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Select the optimal shared memory layout atom based on parameters.
|
||||
|
||||
:param layout: Layout enum of the tensor
|
||||
:type layout: LayoutEnum
|
||||
:param element_type: Data type of the elements
|
||||
:type element_type: type[cutlass.Numeric]
|
||||
:param major_mode_size: Size of the major mode dimension
|
||||
:type major_mode_size: int
|
||||
|
||||
:return: Selected shared memory layout atom kind
|
||||
:rtype: cute.nvgpu.warpgroup.SmemLayoutAtomKind
|
||||
"""
|
||||
assert major_mode_size % 8 == 0
|
||||
sw128_num_contiguous_bits = 1024
|
||||
sw64_num_contiguous_bits = 512
|
||||
sw32_num_contiguous_bits = 256
|
||||
major_mode_size_bits = major_mode_size * element_type.width
|
||||
if layout.sm90_mma_major_mode() == OperandMajorMode.MN:
|
||||
if major_mode_size_bits % sw128_num_contiguous_bits == 0:
|
||||
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW128
|
||||
if major_mode_size_bits % sw64_num_contiguous_bits == 0:
|
||||
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW64
|
||||
if major_mode_size_bits % sw32_num_contiguous_bits == 0:
|
||||
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW32
|
||||
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_INTER
|
||||
if major_mode_size_bits % sw128_num_contiguous_bits == 0:
|
||||
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128
|
||||
if major_mode_size_bits % sw64_num_contiguous_bits == 0:
|
||||
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW64
|
||||
if major_mode_size_bits % sw32_num_contiguous_bits == 0:
|
||||
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW32
|
||||
return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER
|
||||
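A brief usage sketch of get_smem_layout_atom (editorial addition; the import paths are
assumptions and may differ in the actual package layout):

    import cutlass
    import cutlass.cute as cute
    from cutlass.utils.layout import LayoutEnum  # defined in the file below

    # A K-major (row-major) Float16 operand with 64 elements along the major mode spans
    # 64 * 16 = 1024 contiguous bits, so the 128-byte swizzle atom is selected.
    atom = get_smem_layout_atom(LayoutEnum.ROW_MAJOR, cutlass.Float16, 64)
    assert atom == cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128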
68
python/CuTeDSL/cutlass/utils/layout.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import warpgroup
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
|
||||
|
||||
class LayoutEnum(Enum):
|
||||
ROW_MAJOR = "row_major"
|
||||
COL_MAJOR = "col_major"
|
||||
|
||||
def mma_major_mode(self):
|
||||
return (
|
||||
tcgen05.OperandMajorMode.K
|
||||
if self == LayoutEnum.ROW_MAJOR
|
||||
else tcgen05.OperandMajorMode.MN
|
||||
)
|
||||
|
||||
def sm90_mma_major_mode(self):
|
||||
return (
|
||||
warpgroup.OperandMajorMode.K
|
||||
if self == LayoutEnum.ROW_MAJOR
|
||||
else warpgroup.OperandMajorMode.MN
|
||||
)
|
||||
|
||||
def is_k_major_a(self):
|
||||
return self == LayoutEnum.ROW_MAJOR
|
||||
|
||||
def is_m_major_a(self):
|
||||
return self == LayoutEnum.COL_MAJOR
|
||||
|
||||
def is_k_major_b(self):
|
||||
return self == LayoutEnum.COL_MAJOR
|
||||
|
||||
def is_n_major_b(self):
|
||||
return self == LayoutEnum.ROW_MAJOR
|
||||
|
||||
def is_n_major_c(self):
|
||||
return self == LayoutEnum.ROW_MAJOR
|
||||
|
||||
def is_m_major_c(self):
|
||||
return self == LayoutEnum.COL_MAJOR
|
||||
|
||||
@staticmethod
|
||||
def from_tensor(tensor: cute.Tensor) -> "LayoutEnum":
|
||||
ret = None
|
||||
if tensor.leading_dim == 1:
|
||||
ret = LayoutEnum.ROW_MAJOR
|
||||
elif tensor.leading_dim == 0:
|
||||
ret = LayoutEnum.COL_MAJOR
|
||||
else:
|
||||
raise ValueError(f"Invalid leading dimension: {tensor.leading_dim}")
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
__all__ = ["LayoutEnum"]
|
||||
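A minimal sketch of how LayoutEnum maps storage order to MMA operand major modes
(editorial addition; the import path is assumed from the file location above):

    from cutlass.cute.nvgpu import tcgen05, warpgroup
    from cutlass.utils.layout import LayoutEnum

    a_layout = LayoutEnum.ROW_MAJOR                 # K-contiguous A operand
    assert a_layout.mma_major_mode() == tcgen05.OperandMajorMode.K
    assert a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
    assert a_layout.is_k_major_a() and not a_layout.is_m_major_a()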
984
python/CuTeDSL/cutlass/utils/pipeline.py
Normal file
@@ -0,0 +1,984 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from cutlass.cutlass_dsl import Boolean, Int32, Int64, T, if_generate, and_, or_
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
|
||||
import cutlass.cute as cute
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Agent class
|
||||
##############################################################################
|
||||
|
||||
|
||||
class Agent(enum.Enum):
|
||||
"""
|
||||
Agent indicates what is participating in the pipeline synchronization.
|
||||
"""
|
||||
# Arbitrary grouping of N threads
|
||||
Thread = enum.auto()
|
||||
# Same as Thread, but includes all threads in the thread block
|
||||
ThreadBlock = enum.auto()
|
||||
# Same as Thread, but includes all threads in the cluster
|
||||
ThreadBlockCluster = enum.auto()
|
||||
|
||||
|
||||
class CooperativeGroup:
|
||||
"""
|
||||
CooperativeGroup contains size and alignment restrictions for an Agent.
|
||||
"""
|
||||
def __init__(self, agent: Agent, size: int = 1, alignment: int = 1):
|
||||
if agent is Agent.Thread:
|
||||
assert size > 0
|
||||
if size == 32:
|
||||
assert (
|
||||
size == alignment
|
||||
), "Error: Alignment does not match number of threads in a warp."
|
||||
elif size == 128:
|
||||
assert (
|
||||
size == alignment
|
||||
), "Error: Alignment does not match number of threads in a warpgroup."
|
||||
elif agent is Agent.ThreadBlock:
|
||||
assert False, "Error: Not yet supported."
|
||||
elif agent is Agent.ThreadBlockCluster:
|
||||
assert False, "Error: Not yet supported."
|
||||
else:
|
||||
# Should never reach this state
|
||||
size = 0
|
||||
|
||||
if size <= 0:
|
||||
raise ValueError(
|
||||
"Error: The number of threads in a CooperativeGroup must be more than 0."
|
||||
)
|
||||
|
||||
# Size indicates how many threads are participating in this CooperativeGroup
|
||||
self.size = size
|
||||
# Agent indicates the type of thread group
|
||||
self.agent = agent
|
||||
|
||||
|
||||
class _PipelineOp(enum.Enum):
|
||||
"""
|
||||
PipelineOp assigns an operation to an agent corresponding to a specific hardware feature.
|
||||
"""
|
||||
# async-threads
|
||||
AsyncThread = enum.auto()
|
||||
# Blackwell (SM100a) MMA instruction
|
||||
TCGen05Mma = enum.auto()
|
||||
# Tensor Memory Accelerator load
|
||||
TmaLoad = enum.auto()
|
||||
# TMA Store consuming smem produced by AsyncThread
|
||||
TmaStore = enum.auto()
|
||||
|
||||
|
||||
def _get_pipeline_op(type_str):
|
||||
return _PipelineOp(type_str)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# SyncObjectArray class
|
||||
##############################################################################
|
||||
|
||||
|
||||
class SyncObjectArray(ABC):
|
||||
"""
|
||||
SyncObjectArray is an abstract base class for different types of hardware synchronizations (e.g. smem barriers, named barriers, fences)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def wait(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def arrive(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_barrier(self):
|
||||
pass
|
||||
|
||||
|
||||
class MbarrierArray(SyncObjectArray):
|
||||
"""
|
||||
MbarrierArray implements an abstraction for an array of smem barriers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
barrier_storage: cute.Pointer,
|
||||
num_stages: int,
|
||||
agent: tuple[_PipelineOp, CooperativeGroup],
|
||||
tx_count: int = 0,
|
||||
):
|
||||
self.barrier_storage = barrier_storage
|
||||
self.tx_count = tx_count
|
||||
self.num_stages = num_stages
|
||||
self.op_type, self.cg = agent
|
||||
self.arrive_count = self.cg.size
|
||||
|
||||
if self.num_stages <= 0:
|
||||
raise ValueError("Error: Mbarrier stage count must be greater than 0.")
|
||||
if self.arrive_count <= 0:
|
||||
raise ValueError("Error: Mbarrier arrive count must be greater than 0.")
|
||||
if self.op_type is _PipelineOp.TmaLoad and self.tx_count <= 0:
|
||||
raise ValueError(
|
||||
"Error: Mbarrier tx count must be greater than 0 for TMA ops."
|
||||
)
|
||||
|
||||
# Using a tensor to store mbarrier i64 ptrs
|
||||
self.mbarrier_array = cute.make_fragment(cute.make_layout(num_stages), Int64)
|
||||
for i in range(num_stages):
|
||||
self.mbarrier_array[i] = _cute_ir.ptrtoint(
|
||||
T.i64(), (self.barrier_storage + i).value
|
||||
)
|
||||
|
||||
# Mbarrier initialization in constructor
|
||||
self.mbarrier_init()
|
||||
|
||||
# Mbarrier initialization
|
||||
def mbarrier_init(self):
|
||||
"""
|
||||
Initializes an array of mbarriers using warp 0.
|
||||
"""
|
||||
def then_body():
|
||||
for index in range(self.num_stages):
|
||||
cute.arch.mbarrier_init_arrive_cnt(
|
||||
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), self.arrive_count
|
||||
)
|
||||
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
||||
|
||||
if_generate(warp_idx == 0, then_body)
|
||||
|
||||
def arrive(self, index: int, dst: int):
|
||||
"""
|
||||
Select the arrive corresponding to this MbarrierArray's PipelineOp
|
||||
:param index: Index of the mbarrier in the array to arrive on
|
||||
:type index: int
|
||||
:param dst: Destination parameter for selective arrival, which can be either a mask or destination cta rank. When None, both TCGen05Mma and AsyncThread will arrive on their local mbarrier.
|
||||
- For TCGen05Mma, dst serves as a multicast mask (e.g., 0b1011 allows arrive signal to be multicast to CTAs in the cluster with rank = 0, 1, and 3).
|
||||
- For AsyncThread, dst serves as a destination cta rank (e.g., 3 means threads will arrive on the mbarrier with rank = 3 in the cluster).
|
||||
:type dst: int | None
|
||||
"""
|
||||
if self.op_type is _PipelineOp.AsyncThread:
|
||||
self.arrive_mbarrier(index, dst)
|
||||
elif self.op_type is _PipelineOp.TCGen05Mma:
|
||||
self.arrive_tcgen05mma(index, dst)
|
||||
elif self.op_type in [_PipelineOp.TmaLoad]:
|
||||
self.arrive_and_expect_tx(index, self.tx_count)
|
||||
else:
|
||||
print(_get_pipeline_op(self.op_type))
|
||||
assert False, "Error: MbarrierArray is not supported for this PipelineOp."
|
||||
|
||||
def arrive_mbarrier(self, index: int, dst_rank: int):
|
||||
if dst_rank is None:
|
||||
cute.arch.mbarrier_arrive(_mbarrier_i64_to_ptr(self.mbarrier_array[index]))
|
||||
else:
|
||||
cute.arch.mbarrier_arrive(
|
||||
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), dst_rank
|
||||
)
|
||||
|
||||
def arrive_tcgen05mma(self, index: int, mask: int):
|
||||
if mask is None:
|
||||
with cute.arch.elect_one():
|
||||
cute.nvgpu.tcgen05.commit(
|
||||
_mbarrier_i64_to_ptr(self.mbarrier_array[index])
|
||||
)
|
||||
else:
|
||||
with cute.arch.elect_one():
|
||||
cute.nvgpu.tcgen05.commit(
|
||||
_mbarrier_i64_to_ptr(self.mbarrier_array[index]),
|
||||
mask,
|
||||
cute.nvgpu.tcgen05.CtaGroup.TWO,
|
||||
)
|
||||
|
||||
def arrive_and_expect_tx(self, index: int, tx_count: int):
|
||||
with cute.arch.elect_one():
|
||||
cute.arch.mbarrier_init_tx_bytes(
|
||||
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), tx_count
|
||||
)
|
||||
|
||||
def try_wait(self, index: int, phase: int):
|
||||
return cute.arch.mbarrier_try_wait(
|
||||
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), phase
|
||||
)
|
||||
|
||||
def wait(self, index: int, phase: int):
|
||||
cute.arch.mbarrier_wait(_mbarrier_i64_to_ptr(self.mbarrier_array[index]), phase)
|
||||
|
||||
def get_barrier(self, index: int) -> cute.Pointer:
|
||||
return _mbarrier_i64_to_ptr(self.mbarrier_array[index])
|
||||
|
||||
|
||||
class TmaStoreFence(SyncObjectArray):
|
||||
"""
|
||||
TmaStoreFence is used for a multi-stage epilogue buffer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_stages: int = 0,
|
||||
):
|
||||
if num_stages <= 0:
|
||||
raise ValueError("Mbarrier stage count must be greater than 0.")
|
||||
|
||||
self.num_stages = num_stages
|
||||
|
||||
def arrive(self):
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
|
||||
def wait(self):
|
||||
cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True)
|
||||
|
||||
# TmaStoreFence doesn't have mbarriers
|
||||
def get_barrier(self):
|
||||
assert (
|
||||
False
|
||||
), "Error: TmaStoreFence doesn't use mbarriers and cannot return a barrier."
|
||||
|
||||
def tail(self):
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# PipelineState class
|
||||
##############################################################################
|
||||
|
||||
|
||||
class PipelineUserType(enum.Enum):
|
||||
Producer = enum.auto()
|
||||
Consumer = enum.auto()
|
||||
|
||||
|
||||
class PipelineState:
|
||||
"""
|
||||
Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
|
||||
"""
|
||||
|
||||
def __init__(self, stages: int, count, index, phase):
|
||||
self._stages = stages
|
||||
self._count = count
|
||||
self._index = index
|
||||
self._phase = phase
|
||||
|
||||
def clone(self) -> "PipelineState":
|
||||
return PipelineState(self.stages, self._count, self.index, self.phase)
|
||||
|
||||
@property
|
||||
def index(self) -> Int32:
|
||||
return self._index
|
||||
|
||||
@property
|
||||
def count(self) -> Int32:
|
||||
return self._count
|
||||
|
||||
@property
|
||||
def stages(self) -> int:
|
||||
return self._stages
|
||||
|
||||
@property
|
||||
def phase(self) -> Int32:
|
||||
return self._phase
|
||||
|
||||
def reset_count(self):
|
||||
self._count = Int32(0)
|
||||
|
||||
def advance(self):
|
||||
self._index += 1
|
||||
self._count += 1
|
||||
|
||||
def then_body(index, phase):
|
||||
new_index = Int32(0)
|
||||
new_phase = phase ^ 1
|
||||
return new_index, new_phase
|
||||
|
||||
def else_body(index, phase):
|
||||
return index, phase
|
||||
|
||||
self._index, self._phase = if_generate(
|
||||
self._index == self.stages,
|
||||
then_body,
|
||||
else_body,
|
||||
[self.index, self.phase],
|
||||
[Int32, Int32],
|
||||
)
|
||||
|
||||
def reverse(self):
|
||||
self._index -= 1
|
||||
self._count -= 1
|
||||
|
||||
def then_body(index, phase):
|
||||
new_index = Int32(self.stages - 1)
|
||||
new_phase = phase ^ 1
|
||||
return new_index, new_phase
|
||||
|
||||
def else_body(index, phase):
|
||||
return index, phase
|
||||
|
||||
self._index, self._phase = if_generate(
|
||||
self._index == -1,
|
||||
then_body,
|
||||
else_body,
|
||||
[self.index, self.phase],
|
||||
[Int32, Int32],
|
||||
)
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
return [self._count.type, self._index.type, self._phase.type]
|
||||
|
||||
def __extract_mlir_values__(self):
|
||||
count = self._count
|
||||
index = self._index
|
||||
phase = self._phase
|
||||
return [count.ir_value(), index.ir_value(), phase.ir_value()]
|
||||
|
||||
# This can be overridden by derived classes
|
||||
def __new_from_mlir_values__(self, values):
|
||||
return PipelineState(
|
||||
self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2])
|
||||
)
|
||||
|
||||
|
||||
def make_pipeline_state(type: PipelineUserType, stages: int):
|
||||
"""
|
||||
Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
|
||||
"""
|
||||
if type is PipelineUserType.Producer:
|
||||
return PipelineState(
|
||||
stages,
|
||||
Int32(0),
|
||||
Int32(0),
|
||||
Int32(1),
|
||||
)
|
||||
elif type is PipelineUserType.Consumer:
|
||||
return PipelineState(
|
||||
stages,
|
||||
Int32(0),
|
||||
Int32(0),
|
||||
Int32(0),
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
False
|
||||
), "Error: invalid PipelineUserType specified for make_pipeline_state."
|
||||
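A small sketch of the state machine above (editorial addition; inside a CuTe DSL kernel
the index/phase updates trace through if_generate rather than executing eagerly):

    producer_state = make_pipeline_state(PipelineUserType.Producer, 4)  # index 0, phase 1
    consumer_state = make_pipeline_state(PipelineUserType.Consumer, 4)  # index 0, phase 0
    for _ in range(4):
        producer_state.advance()
    # After advancing through all 4 stages the index wraps to 0 and the phase flips to 0.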
|
||||
|
||||
##############################################################################
|
||||
# Pipeline classes
|
||||
##############################################################################
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineAsync:
|
||||
"""
|
||||
PipelineAsync is a generic pipeline class where both the producer and consumer are
|
||||
AsyncThreads. It also serves as a base class for specialized pipeline classes.
|
||||
"""
|
||||
sync_object_array_full: SyncObjectArray
|
||||
sync_object_array_empty: SyncObjectArray
|
||||
num_stages: Int32
|
||||
producer_mask: Int32
|
||||
consumer_mask: Int32
|
||||
|
||||
@staticmethod
|
||||
def _make_sync_object_array(
|
||||
barrier_storage: cute.Pointer,
|
||||
num_stages: Int32,
|
||||
agent: tuple[_PipelineOp, CooperativeGroup],
|
||||
tx_count: int = 0,
|
||||
) -> SyncObjectArray:
|
||||
"""
|
||||
Returns a SyncObjectArray corresponding to an agent's PipelineOp.
|
||||
"""
|
||||
if agent[0] in [
|
||||
_PipelineOp.AsyncThread,
|
||||
_PipelineOp.TmaLoad,
|
||||
_PipelineOp.TCGen05Mma,
|
||||
]:
|
||||
return MbarrierArray(
|
||||
barrier_storage=barrier_storage,
|
||||
num_stages=num_stages,
|
||||
agent=agent,
|
||||
tx_count=tx_count,
|
||||
)
|
||||
elif agent[0] is _PipelineOp.TmaStore:
|
||||
# Path taken for AsyncTmaStore
|
||||
return TmaStoreFence(num_stages=num_stages)
|
||||
else:
|
||||
assert False, "Error: Invalid PipelineOp specified."
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
barrier_storage: cute.Pointer,
|
||||
num_stages: Int32,
|
||||
producer_group: CooperativeGroup,
|
||||
consumer_group: CooperativeGroup,
|
||||
producer_mask: Int32 = None,
|
||||
consumer_mask: Int32 = None,
|
||||
):
|
||||
"""
|
||||
This helper function computes any necessary attributes and returns an instance of PipelineAsync.
|
||||
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
|
||||
:type barrier_storage: cute.Pointer
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
:param consumer_group: CooperativeGroup for the consumer agent
|
||||
:type consumer_group: CooperativeGroup
|
||||
:param producer_mask: Mask for signaling arrives for the producer agent
|
||||
:type producer_mask: Int32 | None
|
||||
:param consumer_mask: Mask for signaling arrives for the consumer agent
|
||||
:type consumer_mask: Int32 | None
|
||||
"""
|
||||
producer_type = _PipelineOp.AsyncThread
|
||||
consumer_type = _PipelineOp.AsyncThread
|
||||
|
||||
producer = (producer_type, producer_group)
|
||||
consumer = (consumer_type, consumer_group)
|
||||
|
||||
sync_object_array_full = PipelineAsync._make_sync_object_array(
|
||||
barrier_storage.align(min_align=8), num_stages, producer
|
||||
)
|
||||
sync_object_array_empty = PipelineAsync._make_sync_object_array(
|
||||
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
|
||||
)
|
||||
|
||||
pipeline_init_wait()
|
||||
|
||||
return PipelineAsync(
|
||||
sync_object_array_full,
|
||||
sync_object_array_empty,
|
||||
num_stages,
|
||||
producer_mask,
|
||||
consumer_mask,
|
||||
)
|
||||
|
||||
def producer_acquire(
|
||||
self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
|
||||
):
|
||||
if_generate(
|
||||
try_acquire_token is None or try_acquire_token == 0,
|
||||
lambda: self.sync_object_array_empty.wait(state.index, state.phase),
|
||||
)
|
||||
|
||||
def producer_try_acquire(self, state: PipelineState):
|
||||
return self.sync_object_array_empty.try_wait(state.index, state.phase)
|
||||
|
||||
def producer_commit(self, state: PipelineState):
|
||||
self.sync_object_array_full.arrive(state.index, self.producer_mask)
|
||||
|
||||
def consumer_wait(
|
||||
self, state: PipelineState, try_wait_token: Optional[Boolean] = None
|
||||
):
|
||||
if_generate(
|
||||
try_wait_token is None or try_wait_token == 0,
|
||||
lambda: self.sync_object_array_full.wait(state.index, state.phase),
|
||||
)
|
||||
|
||||
def consumer_try_wait(self, state: PipelineState):
|
||||
return self.sync_object_array_full.try_wait(state.index, state.phase)
|
||||
|
||||
def consumer_release(self, state: PipelineState):
|
||||
self.sync_object_array_empty.arrive(state.index, self.consumer_mask)
|
||||
|
||||
def producer_get_barrier(self, state: PipelineState) -> cute.Pointer:
|
||||
return self.sync_object_array_full.get_barrier(state.index)
|
||||
|
||||
def producer_tail(self, state: PipelineState):
|
||||
"""
|
||||
Make sure the empty signal of the last used buffer is visible to the producer.
Producer tail is usually executed by the producer before exit to avoid dangling
mbarrier arrive signals after kernel exit.
|
||||
|
||||
:param state: The pipeline state that points to next useful buffer
|
||||
:type state: PipelineState
|
||||
"""
|
||||
# Assume state points to the next useful buffer,
# so we only need to advance num_stages - 1 times to reach the last used buffer
|
||||
for i in range(self.num_stages - 1):
|
||||
state.advance()
|
||||
self.producer_acquire(state)
|
||||
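A hedged producer/consumer sketch for PipelineAsync (editorial addition; smem_barrier_ptr,
the stage count, and the thread-group sizes are placeholders, and the calls are meant to
run inside a CuTe DSL kernel):

    producer_grp = CooperativeGroup(Agent.Thread, 128, 128)   # one warpgroup produces
    consumer_grp = CooperativeGroup(Agent.Thread, 128, 128)   # one warpgroup consumes
    pipeline = PipelineAsync.create(smem_barrier_ptr, 4, producer_grp, consumer_grp)

    p_state = make_pipeline_state(PipelineUserType.Producer, 4)
    c_state = make_pipeline_state(PipelineUserType.Consumer, 4)

    # Producer: wait for an empty stage, fill it, then signal it full.
    pipeline.producer_acquire(p_state)
    # ... write stage p_state.index ...
    pipeline.producer_commit(p_state)
    p_state.advance()

    # Consumer: wait for a full stage, read it, then signal it empty.
    pipeline.consumer_wait(c_state)
    # ... read stage c_state.index ...
    pipeline.consumer_release(c_state)
    c_state.advance()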
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineTmaAsync(PipelineAsync):
|
||||
"""
|
||||
PipelineTmaAsync is used for TMA producers and AsyncThread consumers (e.g. Hopper mainloops).
|
||||
"""
|
||||
is_signalling_thread: bool
|
||||
|
||||
@staticmethod
|
||||
def init_empty_barrier_arrive_signal(cta_layout_vmnk: cute.Layout):
|
||||
"""
|
||||
Initialize the empty barrier arrive signal
|
||||
This function returns the destination cta rank and a boolean indicating whether the current thread is a signalling thread.
|
||||
"""
|
||||
# Logic to optimally schedule Empty Arrives
|
||||
cluster_shape_mnk = cta_layout_vmnk.shape
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
||||
cute.arch.block_idx_in_cluster()
|
||||
)
|
||||
|
||||
is_signalling_thread = tidx < cute.size(cluster_shape_mnk)
|
||||
dst_rank = tidx % cute.size(cluster_shape_mnk)
|
||||
m = cluster_shape_mnk[0]
|
||||
|
||||
# Check if same row
|
||||
is_same_row_l = dst_rank % m
|
||||
is_same_row_r = cta_rank_in_cluster % m
|
||||
is_same_row = is_same_row_l == is_same_row_r
|
||||
|
||||
# Check if same column
|
||||
is_same_col_l = dst_rank // m
|
||||
is_same_col_r = cta_rank_in_cluster // m
|
||||
|
||||
is_same_col = is_same_col_l == is_same_col_r
|
||||
|
||||
is_same_row_or_col = or_(is_same_row, is_same_col)
|
||||
is_signalling_thread_final = and_(is_signalling_thread, is_same_row_or_col)
|
||||
|
||||
return dst_rank, is_signalling_thread_final
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
barrier_storage: cute.Pointer,
|
||||
num_stages: Int32,
|
||||
producer_group: CooperativeGroup,
|
||||
consumer_group: CooperativeGroup,
|
||||
tx_count: int,
|
||||
cta_layout_vmnk: Optional[cute.Layout] = None,
|
||||
):
|
||||
"""
|
||||
This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.
|
||||
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
|
||||
:type barrier_storage: cute.Pointer
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
:param consumer_group: CooperativeGroup for the consumer agent
|
||||
:type consumer_group: CooperativeGroup
|
||||
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
|
||||
:type tx_count: int
|
||||
:param cta_layout_vmnk: Layout of the cluster shape
|
||||
:type cta_layout_vmnk: cute.Layout | None
|
||||
"""
|
||||
producer_type = _PipelineOp.TmaLoad
|
||||
consumer_type = _PipelineOp.AsyncThread
|
||||
|
||||
producer = (producer_type, producer_group)
|
||||
consumer = (consumer_type, consumer_group)
|
||||
|
||||
sync_object_array_full = PipelineAsync._make_sync_object_array(
|
||||
barrier_storage.align(min_align=8), num_stages, producer, tx_count
|
||||
)
|
||||
sync_object_array_empty = PipelineAsync._make_sync_object_array(
|
||||
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
|
||||
)
|
||||
|
||||
dst_rank, is_signalling_thread = (
|
||||
PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk)
|
||||
)
|
||||
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
|
||||
dst_rank = None
|
||||
else:
|
||||
dst_rank = dst_rank
|
||||
|
||||
is_signalling_thread = is_signalling_thread
|
||||
producer_mask = None
|
||||
|
||||
pipeline_init_wait(cta_layout_vmnk)
|
||||
|
||||
return PipelineTmaAsync(
|
||||
sync_object_array_full,
|
||||
sync_object_array_empty,
|
||||
num_stages,
|
||||
producer_mask,
|
||||
dst_rank,
|
||||
is_signalling_thread,
|
||||
)
|
||||
|
||||
def producer_acquire(
|
||||
self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
|
||||
):
|
||||
"""
|
||||
TMA producer acquire conditionally waits on buffer empty and arms the transaction barrier for the TMA load.
|
||||
"""
|
||||
if_generate(
|
||||
try_acquire_token is None or try_acquire_token == 0,
|
||||
lambda: self.sync_object_array_empty.wait(state.index, state.phase),
|
||||
)
|
||||
self.sync_object_array_full.arrive(state.index, self.producer_mask)
|
||||
|
||||
|
||||
def producer_commit(self, state: PipelineState):
|
||||
"""
|
||||
TMA producer commit is a NOP. The transaction barrier signals the commit upon completion of the TMA.
|
||||
"""
|
||||
pass
|
||||
|
||||
def consumer_release(self, state: PipelineState):
|
||||
"""
|
||||
TMA consumer release conditionally signals the empty buffer to the producer.
|
||||
"""
|
||||
if_generate(
|
||||
self.is_signalling_thread,
|
||||
lambda: self.sync_object_array_empty.arrive(
|
||||
state.index, self.consumer_mask
|
||||
),
|
||||
)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineTmaUmma(PipelineAsync):
|
||||
"""
|
||||
PipelineTmaUmma is used for TMA producers and UMMA consumers (e.g. Blackwell mainloops).
|
||||
"""
|
||||
is_leader_cta: bool
|
||||
|
||||
@staticmethod
|
||||
def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout):
|
||||
"""
|
||||
Computes a mask for signaling arrivals to multicasting threadblocks.
|
||||
"""
|
||||
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
||||
cute.arch.block_idx_in_cluster()
|
||||
)
|
||||
cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
|
||||
|
||||
tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask(
|
||||
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2
|
||||
)
|
||||
tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask(
|
||||
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1
|
||||
)
|
||||
|
||||
block_in_cluster_coord_vmnk_peer = (
|
||||
cta_in_cluster_coord_vmnk[0] ^ 1,
|
||||
*cta_in_cluster_coord_vmnk[1:],
|
||||
)
|
||||
tma_mcast_mask_a_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
|
||||
cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2
|
||||
)
|
||||
tma_mcast_mask_b_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
|
||||
cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1
|
||||
)
|
||||
return (
|
||||
tma_mcast_mask_a
|
||||
| tma_mcast_mask_b
|
||||
| tma_mcast_mask_a_peer
|
||||
| tma_mcast_mask_b_peer
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout):
|
||||
"""
|
||||
Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders.
|
||||
"""
|
||||
bidx, bidy, _ = cute.arch.block_idx()
|
||||
|
||||
mma_coord_vmnk = (
|
||||
bidx % cute.size(cta_layout_vmnk, mode=[0]),
|
||||
bidx // cute.size(cta_layout_vmnk, mode=[0]),
|
||||
bidy,
|
||||
None,
|
||||
)
|
||||
return mma_coord_vmnk[0] == 0
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
barrier_storage: cute.Pointer,
|
||||
num_stages: Int32,
|
||||
producer_group: CooperativeGroup,
|
||||
consumer_group: CooperativeGroup,
|
||||
tx_count: int,
|
||||
cta_layout_vmnk: Optional[cute.Layout] = None,
|
||||
):
|
||||
"""
|
||||
This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
|
||||
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
|
||||
:type barrier_storage: cute.Pointer
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
:param consumer_group: CooperativeGroup for the consumer agent
|
||||
:type consumer_group: CooperativeGroup
|
||||
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
|
||||
:type tx_count: int
|
||||
:param cta_layout_vmnk: Layout of the cluster shape
|
||||
:type cta_layout_vmnk: cute.Layout | None
|
||||
"""
|
||||
producer_type = _PipelineOp.TmaLoad
|
||||
consumer_type = _PipelineOp.TCGen05Mma
|
||||
|
||||
producer = (producer_type, producer_group)
|
||||
consumer = (consumer_type, consumer_group)
|
||||
|
||||
sync_object_array_full = PipelineAsync._make_sync_object_array(
|
||||
barrier_storage.align(min_align=8), num_stages, producer, tx_count
|
||||
)
|
||||
sync_object_array_empty = PipelineAsync._make_sync_object_array(
|
||||
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
|
||||
)
|
||||
|
||||
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
|
||||
# No mcast mask if not using clusters
|
||||
producer_mask = None
|
||||
# All threadblocks are leaders if not using clusters
|
||||
is_leader_cta = True
|
||||
else:
|
||||
producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
|
||||
is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
|
||||
|
||||
consumer_mask = producer_mask
|
||||
|
||||
pipeline_init_wait(cta_layout_vmnk)
|
||||
|
||||
return PipelineTmaUmma(
|
||||
sync_object_array_full,
|
||||
sync_object_array_empty,
|
||||
num_stages,
|
||||
producer_mask,
|
||||
consumer_mask,
|
||||
is_leader_cta,
|
||||
)
|
||||
|
||||
def producer_acquire(
|
||||
self, state: PipelineState, try_acquire_token: Optional[Boolean] = None
|
||||
):
|
||||
"""
|
||||
TMA producer acquire conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
|
||||
"""
|
||||
if_generate(
|
||||
try_acquire_token is None or try_acquire_token == 0,
|
||||
lambda: self.sync_object_array_empty.wait(state.index, state.phase),
|
||||
)
|
||||
if_generate(
|
||||
self.is_leader_cta,
|
||||
lambda: self.sync_object_array_full.arrive(state.index, self.producer_mask),
|
||||
)
|
||||
|
||||
def producer_commit(self, state: PipelineState):
|
||||
"""
|
||||
TMA producer commit is a NOP. The transaction barrier signals the commit upon completion of the TMA.
|
||||
"""
|
||||
pass
|
||||
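A hedged construction sketch for a Blackwell mainloop pipeline (editorial addition;
smem_barrier_ptr, the tile sizes, and the stage count are placeholders):

    # tx_count is the number of bytes one TMA load deposits into a single stage
    # (here a 128x64 BF16 A tile plus a 256x64 BF16 B tile).
    a_bytes = 128 * 64 * 2
    b_bytes = 256 * 64 * 2
    mainloop_pipeline = PipelineTmaUmma.create(
        barrier_storage=smem_barrier_ptr,
        num_stages=4,
        producer_group=CooperativeGroup(Agent.Thread),   # single TMA-issuing thread
        consumer_group=CooperativeGroup(Agent.Thread),   # single MMA-issuing thread
        tx_count=a_bytes + b_bytes,
        cta_layout_vmnk=None,                            # 1-CTA kernel: no multicast mask
    )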
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineUmmaAsync(PipelineAsync):
|
||||
"""
|
||||
PipelineUmmaAsync is used for UMMA producers and AsyncThread consumers (e.g. Blackwell accumulator pipelines).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout):
|
||||
"""
|
||||
Computes a mask to signal completion of tmem buffers for 2CTA kernels.
|
||||
"""
|
||||
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
||||
cute.arch.block_idx_in_cluster()
|
||||
)
|
||||
cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
|
||||
return cute.make_layout_image_mask(
|
||||
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mode=0
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _compute_peer_cta_rank():
|
||||
"""
|
||||
Computes a mask to signal release of tmem buffers for 2CTA kernels.
|
||||
"""
|
||||
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
||||
cute.arch.block_idx_in_cluster()
|
||||
)
|
||||
return cta_rank_in_cluster // 2 * 2
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
barrier_storage: cute.Pointer,
|
||||
num_stages: Int32,
|
||||
producer_group: CooperativeGroup,
|
||||
consumer_group: CooperativeGroup,
|
||||
cta_layout_vmnk: Optional[cute.Layout] = None,
|
||||
):
|
||||
"""
|
||||
This helper function computes any necessary attributes and returns an instance of PipelineUmmaAsync.
|
||||
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
|
||||
:type barrier_storage: cute.Pointer
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
:param consumer_group: CooperativeGroup for the consumer agent
|
||||
:type consumer_group: CooperativeGroup
|
||||
:param cta_layout_vmnk: Layout of the cluster shape
|
||||
:type cta_layout_vmnk: cute.Layout | None
|
||||
"""
|
||||
producer_type = _PipelineOp.TCGen05Mma
|
||||
consumer_type = _PipelineOp.AsyncThread
|
||||
|
||||
producer = (producer_type, producer_group)
|
||||
consumer = (consumer_type, consumer_group)
|
||||
|
||||
sync_object_array_full = PipelineAsync._make_sync_object_array(
|
||||
barrier_storage.align(min_align=8), num_stages, producer
|
||||
)
|
||||
sync_object_array_empty = PipelineAsync._make_sync_object_array(
|
||||
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
|
||||
)
|
||||
|
||||
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
|
||||
# Set mask to None if not using clusters (i.e. 1CTA kernels)
|
||||
producer_mask = None
|
||||
else:
|
||||
producer_mask = PipelineUmmaAsync._compute_tmem_sync_mask(cta_layout_vmnk)
|
||||
|
||||
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1:
|
||||
# Set mask to None if not using 2CTA instructions
|
||||
consumer_mask = None
|
||||
else:
|
||||
consumer_mask = PipelineUmmaAsync._compute_peer_cta_rank()
|
||||
|
||||
pipeline_init_wait(cta_layout_vmnk)
|
||||
|
||||
return PipelineUmmaAsync(
|
||||
sync_object_array_full,
|
||||
sync_object_array_empty,
|
||||
num_stages,
|
||||
producer_mask,
|
||||
consumer_mask,
|
||||
)
|
||||
|
||||
def producer_tail(self, state: PipelineState):
|
||||
"""
|
||||
Make sure the empty signal of the last used buffer is visible to the producer.
Producer tail is usually executed by the producer before exit to avoid dangling
mbarrier arrive signals after kernel exit.
|
||||
|
||||
:param state: The pipeline state that points to next useful buffer
|
||||
:type state: PipelineState
|
||||
"""
|
||||
cta_rank_in_cluster = cute.arch.make_warp_uniform(
|
||||
cute.arch.block_idx_in_cluster()
|
||||
)
|
||||
is_leader_cta = cta_rank_in_cluster % 2 == 0
|
||||
|
||||
def then_body():
|
||||
# Assume state points to the next useful buffer,
# so we only need to advance num_stages - 1 times to reach the last used buffer
|
||||
for i in range(self.num_stages - 1):
|
||||
state.advance()
|
||||
self.producer_acquire(state)
|
||||
|
||||
if_generate(is_leader_cta, then_body)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineTmaStore(PipelineAsync):
|
||||
"""
|
||||
PipelineTmaStore is used for synchronizing TMA stores in the epilogue. It does not use mbarriers.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
num_stages: Int32,
|
||||
producer_group: CooperativeGroup,
|
||||
):
|
||||
"""
|
||||
This helper function computes any necessary attributes and returns an instance of PipelineTmaStore.
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
"""
|
||||
producer_type = _PipelineOp.TmaStore
|
||||
|
||||
producer = (producer_type, producer_group)
|
||||
|
||||
sync_object_array_full = PipelineAsync._make_sync_object_array(
|
||||
None, num_stages, producer
|
||||
)
|
||||
|
||||
return PipelineTmaStore(sync_object_array_full, None, num_stages, None, None)
|
||||
|
||||
def producer_acquire(self):
|
||||
self.sync_object_array_full.wait()
|
||||
|
||||
def producer_commit(self):
|
||||
self.sync_object_array_full.arrive()
|
||||
|
||||
def consumer_wait(self):
|
||||
assert False, "Error: PipelineTmaStore does not have a consumer agent."
|
||||
|
||||
def consumer_release(self):
|
||||
assert False, "Error: PipelineTmaStore does not have a consumer agent."
|
||||
|
||||
def producer_tail(self):
|
||||
self.sync_object_array_full.tail()
|
||||
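A hedged epilogue sketch for PipelineTmaStore (editorial addition; the store body is a
placeholder):

    store_pipeline = PipelineTmaStore.create(
        num_stages=2,
        producer_group=CooperativeGroup(Agent.Thread, 128, 128),
    )
    # Per epilogue stage: stage the accumulators in smem, issue the TMA store, then
    # commit the bulk-async group and throttle so at most num_stages - 1 groups are
    # still in flight before the stage's smem is reused.
    # ... copy accumulators to smem and issue the TMA store for this stage ...
    store_pipeline.producer_commit()
    store_pipeline.producer_acquire()
    # Before kernel exit, drain every outstanding store group.
    store_pipeline.producer_tail()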
|
||||
|
||||
##############################################################################
|
||||
# Helper functions
|
||||
##############################################################################
|
||||
|
||||
|
||||
def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
|
||||
"""
|
||||
Fences the mbarrier init and syncs the threadblock or cluster
|
||||
"""
|
||||
cute.arch.mbarrier_init_fence()
|
||||
|
||||
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
|
||||
# If not using clusters, sync the threadblock
|
||||
_sync(Agent.ThreadBlock)
|
||||
else:
|
||||
# If using clusters, sync the cluster
|
||||
_sync(Agent.ThreadBlockCluster)
|
||||
|
||||
|
||||
def _sync(group: Agent):
|
||||
"""
|
||||
Syncs all threads within an agent.
|
||||
"""
|
||||
if group is Agent.Thread:
|
||||
assert False, "Error: Not supported."
|
||||
elif group is Agent.ThreadBlock:
|
||||
cute.arch.sync_threads()
|
||||
elif group is Agent.ThreadBlockCluster:
|
||||
cute.arch.cluster_arrive()
|
||||
cute.arch.cluster_wait()
|
||||
else:
|
||||
assert (
|
||||
False
|
||||
), "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead."
|
||||
|
||||
|
||||
def _mbarrier_i64_to_ptr(val: Int64) -> cute.Pointer:
|
||||
"""
|
||||
Converts a smem pointer of type Int64 to cute.Pointer with 8B alignment
|
||||
"""
|
||||
return cute.make_ptr(
|
||||
Int64,
|
||||
val.ir_value(),
|
||||
mem_space=_cute_ir.AddressSpace.smem,
|
||||
assumed_align=8,
|
||||
)
|
||||
217
python/CuTeDSL/cutlass/utils/smem_allocator.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import Type, Union, overload
|
||||
|
||||
from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.arch import get_dyn_smem
|
||||
|
||||
|
||||
class SmemAllocator:
|
||||
"""
|
||||
A class for managing shared memory allocation on GPU.
|
||||
|
||||
This class manages a chunk of shared memory and provides APIs for sub-allocation
|
||||
inside the chunk.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
_base : cute.Pointer as i8 typed dynamic value
|
||||
The current base address of the shared memory.
|
||||
|
||||
_allocated_bytes:
|
||||
The bytes allocated in shared memory.
|
||||
|
||||
Methods
|
||||
-------
|
||||
allocate(num_bytes, alignment)
|
||||
Allocates num_bytes in the shared memory with the given byte alignment.
|
||||
|
||||
allocate_value(value_ty, num_elems)
|
||||
Allocates num_elems of value_ty values in the shared memory.
|
||||
|
||||
allocate_tensor(value_ty, layout, alignment)
|
||||
Allocates a tensor in the shared memory with given layout and byte alignment.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This class is responsible for managing the allocation of tensors in shared memory.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the SmemAllocator instance with dynamic smem base ptr,
|
||||
which is of i8 type and aligned to 1024 bytes.
|
||||
|
||||
"""
|
||||
self._base = get_dyn_smem(Int8, alignment=1024)
|
||||
self._allocated_bytes = 0
|
||||
|
||||
@overload
|
||||
def allocate(self, size_or_type: int, byte_alignment: int): ...
|
||||
|
||||
@overload
|
||||
def allocate(self, size_or_type: cute.struct, byte_alignment: int): ...
|
||||
|
||||
def allocate(self, size_or_type, byte_alignment: int = 1) -> int:
|
||||
"""
|
||||
Allocates a block of memory with the specified size and byte alignment.
|
||||
|
||||
This method adjusts the base cute.Pointer to ensure that the allocated memory
|
||||
is aligned according to the specified byte alignment. It updates the internal
|
||||
state to reflect the new base cute.Pointer and the total allocated bytes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
size_or_type : int or struct
|
||||
The number of bytes to allocate, or a cute.struct class.
|
||||
byte_alignment : int
|
||||
The byte alignment requirement for the allocation. Defaults to 1 (no alignment).
|
||||
|
||||
Returns
|
||||
----------
|
||||
A cute.Pointer to the start of the allocated memory block or struct instance.
|
||||
|
||||
Raises
|
||||
----------
|
||||
ValueError
|
||||
If num_bytes is negative or if byte_alignment is less than 1.
|
||||
"""
|
||||
|
||||
if isinstance(size_or_type, cute.struct):
|
||||
alignment = max(byte_alignment, size_or_type.__alignof__())
|
||||
base_ptr = self.allocate(size_or_type.__sizeof__(), alignment)
|
||||
return size_or_type(base_ptr)
|
||||
|
||||
num_bytes = size_or_type
|
||||
if num_bytes < 0:
|
||||
raise ValueError("num_bytes must be non-negative")
|
||||
if byte_alignment < 1:
|
||||
raise ValueError("byte_alignment must be at least 1")
|
||||
|
||||
self._base = self._base.align(byte_alignment)
|
||||
ptr = self._base
|
||||
self._base += num_bytes
|
||||
if self._allocated_bytes % byte_alignment != 0:
|
||||
self._allocated_bytes += (
|
||||
byte_alignment - self._allocated_bytes % byte_alignment
|
||||
)
|
||||
self._allocated_bytes += num_bytes
|
||||
return ptr
|
||||
|
||||
def allocate_array(self, element_type: Type[Numeric], num_elems: int = 1):
|
||||
"""
|
||||
Allocates num_elems values of element_type in shared memory.
|
||||
|
||||
This method calls allocate() to obtain a byte pointer to the start of the allocated
block, then calls cute.recast_ptr() to recast this byte cute.Pointer to element_type.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
element_type : Type[Numeric]
|
||||
The type of the values in the tensor.
|
||||
num_elems : int, optional
|
||||
The number of elements for each allocation. Defaults to 1.
|
||||
|
||||
Returns
|
||||
----------
|
||||
A value_type cute.Pointer to the start of the allocated memory block.
|
||||
|
||||
Raises
|
||||
----------
|
||||
ValueError
|
||||
If num_elems is less than 1.
|
||||
"""
|
||||
if num_elems < 1:
|
||||
raise ValueError("num_elems must be at least 1")
|
||||
if not isinstance(element_type, NumericMeta):
|
||||
raise TypeError(
|
||||
f"value_ty must be a type of Numeric, but got {element_type}"
|
||||
)
|
||||
|
||||
ptr = self.allocate(
|
||||
element_type.width // 8 * num_elems, element_type.width // 8
|
||||
)
|
||||
|
||||
return cute.recast_ptr(ptr, dtype=element_type)
|
||||
|
||||
def allocate_tensor(
|
||||
self,
|
||||
element_type: Type[Numeric],
|
||||
layout: Union[int, cute.Layout, cute.ComposedLayout],
|
||||
byte_alignment: int = 1,
|
||||
swizzle: cute.Swizzle = None,
|
||||
):
|
||||
"""
|
||||
Allocates a tensor in the shared memory with value type, layout and byte alignment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
element_type : Type[Numeric]
|
||||
The type of the values in the tensor.
|
||||
layout : int | DynamicInt | cute.Layout | cute.ComposedLayout
|
||||
The layout of the tensor.
|
||||
byte_alignment : int, optional
|
||||
The byte alignment requirement for the allocation. Defaults to 1 (no alignment).
|
||||
swizzle : cute.Swizzle
|
||||
A swizzle for the iterator (for position-dependent swizzling).
|
||||
|
||||
Returns
|
||||
-------
|
||||
tensor : cute.Tensor
|
||||
The allocated tensor with specified value type, layout and byte alignment.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The base address is updated to point to the next available memory location.
|
||||
"""
|
||||
if not isinstance(element_type, NumericMeta):
|
||||
raise TypeError(
|
||||
f"value_ty must be a type of Numeric, but got {element_type}"
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(layout, cute.ComposedLayout)
|
||||
and isinstance(layout.inner, cute.Swizzle)
|
||||
) and (swizzle is not None):
|
||||
raise TypeError(
|
||||
f"iterator swizzle with swizzle layout is currently not supported"
|
||||
)
|
||||
|
||||
if isinstance(layout, int):
|
||||
layout = cute.make_layout(layout)
|
||||
|
||||
profile = layout(0)
|
||||
if isinstance(profile, tuple):
|
||||
raise TypeError(
|
||||
f"cannot allocate a shared memory tensor with a non-integer iterator"
|
||||
)
|
||||
|
||||
if not cute.is_static(layout.type):
|
||||
raise NotImplementedError(f"dynamic layout is not supported: {layout.type}")
|
||||
|
||||
# At least align the allocation to the natural alignment given by the element type
|
||||
if element_type.width // 8 > byte_alignment:
|
||||
byte_alignment = element_type.width // 8
|
||||
|
||||
# Relevant only for sub-byte data types: verify that the entire allocation is byte-aligned
|
||||
cosize_in_bits = cute.cosize(layout) * element_type.width
|
||||
assert isinstance(cosize_in_bits, int)
|
||||
if cosize_in_bits % 8 != 0:
|
||||
raise ValueError("invalid allocation that is not byte-aligned")
|
||||
|
||||
num_bytes = cosize_in_bits // 8
|
||||
ptr = self.allocate(num_bytes, byte_alignment)
|
||||
ptr = cute.recast_ptr(ptr, swizzle, dtype=element_type)
|
||||
res = cute.make_tensor(ptr, layout)
|
||||
return res
|
||||
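A short allocation sketch (editorial addition; the import path, tile shape, and alignment
are assumptions, and the calls are meant to run inside a kernel launched with enough
dynamic shared memory):

    import cutlass
    import cutlass.cute as cute
    from cutlass.utils.smem_allocator import SmemAllocator  # path assumed from this file

    smem = SmemAllocator()
    # A 128 x 64 Float16 tile, 128-byte aligned.
    a_smem = smem.allocate_tensor(cutlass.Float16, cute.make_layout((128, 64)), byte_alignment=128)
    # A small scratch array of 32 Int32 values; returns an Int32-typed smem pointer.
    counters = smem.allocate_array(cutlass.Int32, 32)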
384
python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from cutlass.cutlass_dsl import (
|
||||
Boolean,
|
||||
Integer,
|
||||
Int32,
|
||||
min,
|
||||
extract_mlir_values,
|
||||
new_from_mlir_values,
|
||||
dsl_user_op,
|
||||
)
|
||||
from cutlass._mlir import ir
|
||||
import cutlass.cute as cute
|
||||
|
||||
##############################################################################
|
||||
# Static persistent tile scheduler
|
||||
##############################################################################
|
||||
|
||||
|
||||
class WorkTileInfo:
|
||||
"""A class to represent information about a work tile.
|
||||
|
||||
:ivar tile_idx: The index of the tile.
|
||||
:type tile_idx: cute.Coord
|
||||
:ivar is_valid_tile: Whether the tile is valid.
|
||||
:type is_valid_tile: Boolean
|
||||
"""
|
||||
|
||||
def __init__(self, tile_idx: cute.Coord, is_valid_tile: Boolean):
|
||||
self._tile_idx = tile_idx
|
||||
self._is_valid_tile = Boolean(is_valid_tile)
|
||||
|
||||
def __extract_mlir_values__(self) -> list[ir.Value]:
|
||||
values = extract_mlir_values(self.tile_idx)
|
||||
values.extend(extract_mlir_values(self.is_valid_tile))
|
||||
return values
|
||||
|
||||
def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo":
|
||||
assert len(values) == 4
|
||||
new_tile_idx = new_from_mlir_values(self._tile_idx, values[:-1])
|
||||
new_is_valid_tile = new_from_mlir_values(self._is_valid_tile, [values[-1]])
|
||||
return WorkTileInfo(new_tile_idx, new_is_valid_tile)
|
||||
|
||||
@property
|
||||
def is_valid_tile(self) -> Boolean:
|
||||
"""Check latest tile returned by the scheduler is valid or not. Any scheduling
|
||||
requests after all tasks completed will return an invalid tile.
|
||||
|
||||
:return: The validity of the tile.
|
||||
:rtype: Boolean
|
||||
"""
|
||||
return self._is_valid_tile
|
||||
|
||||
@property
|
||||
def tile_idx(self) -> cute.Coord:
|
||||
"""
|
||||
Get the index of the tile.
|
||||
|
||||
:return: The index of the tile.
|
||||
:rtype: cute.Coord
|
||||
"""
|
||||
return self._tile_idx
|
||||
|
||||
|
||||
class PersistentTileSchedulerParams:
|
||||
"""A class to represent parameters for a persistent tile scheduler.
|
||||
|
||||
This class is designed to manage and compute the layout of clusters and tiles
|
||||
in a batched gemm problem.
|
||||
|
||||
:ivar cluster_shape_mn: Shape of the cluster in (m, n) dimensions (K dimension cta count must be 1).
|
||||
:type cluster_shape_mn: tuple
|
||||
:ivar problem_layout_ncluster_mnl: Layout of the problem in terms of
|
||||
number of clusters in (m, n, l) dimensions.
|
||||
:type problem_layout_ncluster_mnl: cute.Layout
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
problem_shape_ntile_mnl: cute.Shape,
|
||||
cluster_shape_mnk: cute.Shape,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""
|
||||
Initializes the PersistentTileSchedulerParams with the given parameters.
|
||||
|
||||
:param problem_shape_ntile_mnl: The shape of the problem in terms of
|
||||
number of CTA tiles in the (m, n, l) dimensions.
|
||||
:type problem_shape_ntile_mnl: cute.Shape
|
||||
:param cluster_shape_mnk: The shape of the cluster in the (m, n, k) dimensions; k must be 1.
|
||||
:type cluster_shape_mnk: cute.Shape
|
||||
|
||||
:raises ValueError: If cluster_shape_k is not 1.
|
||||
"""
|
||||
|
||||
if cluster_shape_mnk[2] != 1:
|
||||
raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}")
|
||||
|
||||
self.problem_shape_ntile_mnl = problem_shape_ntile_mnl
|
||||
# cluster_shape_mnk is kept for reconstruction
|
||||
self._cluster_shape_mnk = cluster_shape_mnk
|
||||
self.cluster_shape_mn = cluster_shape_mnk[:2]
|
||||
self._loc = loc
|
||||
|
||||
# By default, we follow m major (col-major) raster order, so make a col-major layout
|
||||
self.problem_layout_ncluster_mnl = cute.make_layout(
|
||||
cute.ceil_div(
|
||||
self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip
|
||||
),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
def __extract_mlir_values__(self):
|
||||
values, self._values_pos = [], []
|
||||
for obj in [self.problem_shape_ntile_mnl, self._cluster_shape_mnk]:
|
||||
obj_values = extract_mlir_values(obj)
|
||||
values += obj_values
|
||||
self._values_pos.append(len(obj_values))
|
||||
return values
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
obj_list = []
|
||||
for obj, n_items in zip(
|
||||
[self.problem_shape_ntile_mnl, self._cluster_shape_mnk], self._values_pos
|
||||
):
|
||||
obj_list.append(new_from_mlir_values(obj, values[:n_items]))
|
||||
values = values[n_items:]
|
||||
return PersistentTileSchedulerParams(*(tuple(obj_list)), loc=self._loc)
|
||||
|
||||
@dsl_user_op
|
||||
def get_grid_shape(
|
||||
self, max_active_clusters: Int32, *, loc=None, ip=None
|
||||
) -> Tuple[Integer, Integer, Integer]:
|
||||
"""
|
||||
Computes the grid shape based on the maximum active clusters allowed.
|
||||
|
||||
:param max_active_clusters: The maximum number of active clusters that
|
||||
can run in one wave.
|
||||
:type max_active_clusters: Int32
|
||||
|
||||
:return: A tuple containing the grid shape in (m, n, persistent_clusters).
|
||||
- m: the cluster extent in M (cluster_shape_mn[0]).
- n: the cluster extent in N (cluster_shape_mn[1]).
- persistent_clusters: number of persistent clusters that can run concurrently.
|
||||
"""
|
||||
|
||||
# Total ctas in problem size
|
||||
num_ctas_mnl = tuple(
|
||||
x * y
|
||||
for x, y in zip(
|
||||
self.problem_layout_ncluster_mnl.shape, self.cluster_shape_mn
|
||||
)
|
||||
) + (self.problem_layout_ncluster_mnl.shape[2],)
|
||||
|
||||
num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
|
||||
|
||||
num_ctas_per_cluster = cute.size(self.cluster_shape_mn, loc=loc, ip=ip)
|
||||
# Total ctas that can run in one wave
|
||||
num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster
|
||||
|
||||
num_persistent_ctas = min(num_ctas_in_problem, num_ctas_per_wave)
|
||||
num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster
|
||||
|
||||
return (*self.cluster_shape_mn, num_persistent_clusters)
|
||||
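    # Worked example (illustrative numbers, editorial addition): with
    # problem_shape_ntile_mnl = (8, 8, 1) and cluster_shape_mnk = (2, 1, 1) the cluster
    # layout is (4, 8, 1), i.e. 64 CTAs in total. For max_active_clusters = 16 one wave
    # holds 16 * 2 = 32 CTAs, so num_persistent_clusters = min(64, 32) // 2 = 16 and the
    # launched grid is (2, 1, 16).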
|
||||
|
||||
class StaticPersistentTileScheduler:
|
||||
"""A scheduler for static persistent tile execution in CUTLASS/CuTe kernels.
|
||||
|
||||
:ivar params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl
|
||||
:type params: PersistentTileSchedulerParams
|
||||
:ivar num_persistent_clusters: Number of persistent clusters that can be launched
|
||||
:type num_persistent_clusters: Int32
|
||||
:ivar cta_id_in_cluster: ID of the CTA within its cluster
|
||||
:type cta_id_in_cluster: cute.Coord
|
||||
:ivar _num_tiles_executed: Counter for executed tiles
|
||||
:type _num_tiles_executed: Int32
|
||||
:ivar _current_work_linear_idx: Current cluster index
|
||||
:type _current_work_linear_idx: Int32
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: PersistentTileSchedulerParams,
|
||||
num_persistent_clusters: Int32,
|
||||
current_work_linear_idx: Int32,
|
||||
cta_id_in_cluster: cute.Coord,
|
||||
num_tiles_executed: Int32,
|
||||
):
|
||||
"""
|
||||
Initializes the StaticPersistentTileScheduler with the given parameters.
|
||||
|
||||
:param params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl.
|
||||
:type params: PersistentTileSchedulerParams
|
||||
:param num_persistent_clusters: Number of persistent clusters that can be launched.
|
||||
:type num_persistent_clusters: Int32
|
||||
:param current_work_linear_idx: Current cluster index.
|
||||
:type current_work_linear_idx: Int32
|
||||
:param cta_id_in_cluster: ID of the CTA within its cluster.
|
||||
:type cta_id_in_cluster: cute.Coord
|
||||
:param num_tiles_executed: Counter for executed tiles.
|
||||
:type num_tiles_executed: Int32
|
||||
"""
|
||||
self.params = params
|
||||
self.num_persistent_clusters = num_persistent_clusters
|
||||
self._current_work_linear_idx = current_work_linear_idx
|
||||
self.cta_id_in_cluster = cta_id_in_cluster
|
||||
self._num_tiles_executed = num_tiles_executed
|
||||
|
||||
def __extract_mlir_values__(self) -> list[ir.Value]:
|
||||
values = extract_mlir_values(self.num_persistent_clusters)
|
||||
values.extend(extract_mlir_values(self._current_work_linear_idx))
|
||||
values.extend(extract_mlir_values(self.cta_id_in_cluster))
|
||||
values.extend(extract_mlir_values(self._num_tiles_executed))
|
||||
return values
|
||||
|
||||
def __new_from_mlir_values__(
|
||||
self, values: list[ir.Value]
|
||||
) -> "StaticPersistentTileScheduler":
|
||||
assert len(values) == 6
|
||||
new_num_persistent_clusters = new_from_mlir_values(
|
||||
self.num_persistent_clusters, [values[0]]
|
||||
)
|
||||
new_current_work_linear_idx = new_from_mlir_values(
|
||||
self._current_work_linear_idx, [values[1]]
|
||||
)
|
||||
new_cta_id_in_cluster = new_from_mlir_values(
|
||||
self.cta_id_in_cluster, values[2:5]
|
||||
)
|
||||
new_num_tiles_executed = new_from_mlir_values(
|
||||
self._num_tiles_executed, [values[5]]
|
||||
)
|
||||
return StaticPersistentTileScheduler(
|
||||
self.params,
|
||||
new_num_persistent_clusters,
|
||||
new_current_work_linear_idx,
|
||||
new_cta_id_in_cluster,
|
||||
new_num_tiles_executed,
|
||||
)
|
||||
|
||||
# called by host
|
||||
@dsl_user_op
|
||||
@staticmethod
|
||||
def create(
|
||||
params: PersistentTileSchedulerParams,
|
||||
block_idx: Tuple[Integer, Integer, Integer],
|
||||
grid_dim: Tuple[Integer, Integer, Integer],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Initialize the static persistent tile scheduler.
|
||||
|
||||
:param params: Parameters for the persistent
|
||||
tile scheduler.
|
||||
:type params: PersistentTileSchedulerParams
|
||||
:param block_idx: The 3d block index in the format (bidx, bidy, bidz).
|
||||
:type block_idx: Tuple[Integer, Integer, Integer]
|
||||
:param grid_dim: The 3d grid dimensions for kernel launch.
|
||||
:type grid_dim: Tuple[Integer, Integer, Integer]
|
||||
|
||||
:return: A StaticPersistentTileScheduler object.
|
||||
:rtype: StaticPersistentTileScheduler
|
||||
"""
|
||||
params = params
|
||||
|
||||
# Calculate the number of persistent clusters by dividing the total grid size
|
||||
# by the number of CTAs per cluster
|
||||
num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size(
|
||||
params.cluster_shape_mn, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
bidx, bidy, bidz = block_idx
|
||||
|
||||
# Initialize the workload index to the cluster index in the grid
|
||||
current_work_linear_idx = Int32(bidz)
|
||||
|
||||
# CTA id in the cluster
|
||||
cta_id_in_cluster = (
|
||||
Int32(bidx % params.cluster_shape_mn[0]),
|
||||
Int32(bidy % params.cluster_shape_mn[1]),
|
||||
Int32(0),
|
||||
)
|
||||
# Initialize number of tiles executed to zero
|
||||
num_tiles_executed = Int32(0)
|
||||
return StaticPersistentTileScheduler(
|
||||
params,
|
||||
num_persistent_clusters,
|
||||
current_work_linear_idx,
|
||||
cta_id_in_cluster,
|
||||
num_tiles_executed,
|
||||
)
|
||||
|
||||
# called by host
|
||||
@staticmethod
|
||||
def get_grid_shape(
|
||||
params: PersistentTileSchedulerParams,
|
||||
max_active_clusters: Int32,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tuple[Integer, Integer, Integer]:
|
||||
"""Calculates the grid shape to be launched on GPU using problem shape,
|
||||
threadblock shape, and active cluster size.
|
||||
|
||||
:param params: Parameters for grid shape calculation.
|
||||
:type params: PersistentTileSchedulerParams
|
||||
:param max_active_clusters: Maximum active clusters allowed.
|
||||
:type max_active_clusters: Int32
|
||||
|
||||
:return: The calculated 3d grid shape.
|
||||
:rtype: Tuple[Integer, Integer, Integer]
|
||||
"""
|
||||
|
||||
return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip)
|
||||
|
||||
# private method
|
||||
def _get_current_work_for_linear_idx(
|
||||
self, current_work_linear_idx: Int32, *, loc=None, ip=None
|
||||
) -> WorkTileInfo:
|
||||
"""Compute current tile coord given current_work_linear_idx and cta_id_in_cluster.
|
||||
|
||||
:param current_work_linear_idx: The linear index of the current work.
|
||||
:type current_work_linear_idx: Int32
|
||||
|
||||
:return: An object containing information about the current tile coordinates
|
||||
and validity status.
|
||||
:rtype: WorkTileInfo
|
||||
"""
|
||||
|
||||
is_valid = current_work_linear_idx < cute.size(
|
||||
self.params.problem_layout_ncluster_mnl, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
cur_cluster_coord = self.params.problem_layout_ncluster_mnl.get_hier_coord(
|
||||
current_work_linear_idx, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
# cur_tile_coord is a tuple of i32 values
|
||||
cur_tile_coord = tuple(
|
||||
Int32(x) * Int32(z) + Int32(y)
|
||||
for x, y, z in zip(
|
||||
cur_cluster_coord,
|
||||
self.cta_id_in_cluster,
|
||||
(*self.params.cluster_shape_mn, Int32(1)),
|
||||
)
|
||||
)
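# i.e. tile_coord[i] = cluster_coord[i] * cluster_shape[i] + cta_id_in_cluster[i],
# with the L mode passing through unchanged (shape 1, CTA id 0). For example, a
# hypothetical cluster coord (3, 5, 2) with cluster_shape_mn (2, 2) and
# cta_id_in_cluster (1, 0, 0) maps to tile coord (7, 10, 2).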
|
||||
|
||||
return WorkTileInfo(cur_tile_coord, is_valid)
|
||||
|
||||
@dsl_user_op
|
||||
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
|
||||
return self._get_current_work_for_linear_idx(
|
||||
self._current_work_linear_idx, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
@dsl_user_op
|
||||
def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo:
|
||||
return self.get_current_work(loc=loc, ip=ip)
|
||||
|
||||
@dsl_user_op
|
||||
def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None):
|
||||
self._current_work_linear_idx += Int32(advance_count) * Int32(
|
||||
self.num_persistent_clusters
|
||||
)
|
||||
self._num_tiles_executed += Int32(1)
|
||||
|
||||
@property
|
||||
def num_tiles_executed(self) -> Int32:
|
||||
return self._num_tiles_executed
|
||||
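# A minimal device-side usage sketch (not part of the original file; the helper
# `process_tile` and the WorkTileInfo attribute names are assumed for illustration):
#
#   tile_sched = StaticPersistentTileScheduler.create(
#       tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
#   )
#   work_tile = tile_sched.initial_work_tile_info()
#   while work_tile.is_valid_tile:
#       process_tile(work_tile.tile_idx)       # user-defined work on this CTA tile
#       tile_sched.advance_to_next_work()      # stride by num_persistent_clusters
#       work_tile = tile_sched.get_current_work()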
140
python/CuTeDSL/cutlass/utils/tensormap_manager.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Tuple
|
||||
|
||||
from cutlass.cutlass_dsl import const_expr
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
|
||||
import cutlass.cute as cute
|
||||
|
||||
|
||||
class TensorMapUpdateMode(Enum):
|
||||
"""
|
||||
Enum class defining tensor map update modes.
|
||||
|
||||
Modes:
|
||||
GMEM: Update tensormap in global memory
|
||||
SMEM: Load tensormap from global memory to shared memory,
|
||||
update it in shared memory, then store back to global memory
|
||||
"""
|
||||
|
||||
GMEM = auto() # Update tensormap in global memory
|
||||
SMEM = auto() # Update tensormap in shared memory
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TensorMapManager:
|
||||
"""
|
||||
Manages TensorMap operations including initialization and updates.
|
||||
Provides utilities to convert a tensormap pointer across different memory spaces.
|
||||
"""
|
||||
|
||||
tensormap_update_mode: TensorMapUpdateMode
|
||||
bytes_per_tensormap: int
|
||||
|
||||
# Convert the given cute.Pointer or cutlass.Int64 to a cute.Pointer to a tensormap.
|
||||
# address_space: the address space of the resulting tensormap pointer. It could be generic or gmem
|
||||
def get_tensormap_ptr(
|
||||
self,
|
||||
ptr: cute.Pointer,
|
||||
address_space=_cute_ir.AddressSpace.gmem,
|
||||
) -> cute.Pointer:
|
||||
if address_space not in [
|
||||
_cute_ir.AddressSpace.gmem,
|
||||
_cute_ir.AddressSpace.generic,
|
||||
]:
|
||||
raise ValueError(f"Invalid address space: {address_space} for tensormap")
|
||||
|
||||
gmem_ptr_i64 = ptr.toint().ir_value()
|
||||
gmem_ptr_i64_align_ty = _cute_ir.ConstrainedIntType.get(
|
||||
self.bytes_per_tensormap, gmem_ptr_i64.type.width
|
||||
)
|
||||
gmem_ptr_i64_align = _cute_ir.assume(gmem_ptr_i64_align_ty, gmem_ptr_i64)
|
||||
gmem_ptr_ty = _cute_ir.PtrType.get(
|
||||
_cute_nvgpu_ir.TmaDescriptorTiledType.get(),
|
||||
address_space,
|
||||
self.bytes_per_tensormap,
|
||||
)
|
||||
return _cute_ir.inttoptr(gmem_ptr_ty, gmem_ptr_i64_align)
|
||||
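# Example (a sketch; `tensormap_workspace_ptr` is a hypothetical cute.Pointer to a
# 128-byte-aligned tensormap slot in global memory):
#
#   manager = TensorMapManager(TensorMapUpdateMode.SMEM, bytes_per_tensormap=128)
#   tensormap_gmem_ptr = manager.get_tensormap_ptr(tensormap_workspace_ptr)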
|
||||
# Initialize the tensormap pointed to by dst_ptr with the one embedded in copy_atom.
|
||||
# dst_ptr should point to a global memory or shared memory location
|
||||
# warp_id specifies which warp performs the initialization
|
||||
@cute.jit
|
||||
def init_tensormap_from_atom(
|
||||
self, copy_atom: cute.CopyAtom, dst_ptr: cute.Pointer, warp_id: int
|
||||
) -> None:
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
||||
if warp_idx == warp_id:
|
||||
with cute.arch.elect_one():
|
||||
cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr)
|
||||
cute.arch.sync_warp()
|
||||
return
|
||||
|
||||
# Perform a fence operation to ensure previous `init_tensormap_from_atom` calls have been completed
|
||||
def fence_tensormap_initialization(
|
||||
self,
|
||||
) -> None:
|
||||
if self.tensormap_update_mode == TensorMapUpdateMode.GMEM:
|
||||
cute.arch.fence_acq_rel_cta()
|
||||
return
|
||||
|
||||
# Perform a fence operation to ensure previous `update_tensormap` calls have been completed
|
||||
def fence_tensormap_update(
|
||||
self,
|
||||
tensormap_ptr: cute.Pointer,
|
||||
) -> None:
|
||||
cute.nvgpu.cpasync.fence_tma_desc_acquire(tensormap_ptr)
|
||||
return
|
||||
|
||||
@cute.jit
|
||||
def update_tensormap(
|
||||
self,
|
||||
tensor_gmem: Tuple[cute.Tensor, ...],
|
||||
tma_copy_atom: Tuple[cute.CopyAtom, ...],
|
||||
tensormap_gmem_ptr: Tuple[cute.Pointer, ...],
|
||||
warp_id: int,
|
||||
tensormap_smem_ptr: Tuple[cute.Pointer, ...],
|
||||
) -> None:
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
# updates before touching tensormap in global memory
|
||||
if warp_idx == warp_id:
|
||||
if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
|
||||
for copy_atom, tensor, smem_ptr in zip(
|
||||
tma_copy_atom, tensor_gmem, tensormap_smem_ptr
|
||||
):
|
||||
cute.nvgpu.cpasync.update_tma_descriptor(
|
||||
copy_atom, tensor, smem_ptr
|
||||
)
|
||||
# wait until it's safe to update tensormap in global memory
|
||||
with cute.arch.elect_one():
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
cute.arch.sync_warp()
|
||||
# updates to tensormap in global memory
|
||||
if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
|
||||
for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
|
||||
cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
|
||||
else:
|
||||
for copy_atom, tensor, gmem_ptr in zip(
|
||||
tma_copy_atom, tensor_gmem, tensormap_gmem_ptr
|
||||
):
|
||||
cute.nvgpu.cpasync.update_tma_descriptor(
|
||||
copy_atom, tensor, gmem_ptr
|
||||
)
|
||||
cute.arch.sync_warp()
|
||||
cute.nvgpu.cpasync.fence_tma_desc_release()
|
||||
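# A typical device-side call sequence for the manager above (a sketch; the copy
# atom, tensors, and pointer names are placeholders):
#
#   manager.init_tensormap_from_atom(copy_atom, tensormap_gmem_ptr, warp_id)
#   manager.fence_tensormap_initialization()
#   manager.update_tensormap((tensor_a,), (copy_atom,), (tensormap_gmem_ptr,),
#                            warp_id, (tensormap_smem_ptr,))
#   manager.fence_tensormap_update(tensormap_gmem_ptr)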
37
python/CuTeDSL/cutlass_dsl/__init__.py
Normal file
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from .cutlass import *

from ..base_dsl.ast_helpers import (
    loop_selector,
    if_selector,
    if_executor,
    while_selector,
    while_executor,
    range_constexpr,
    range_dynamic,
    const_expr,
    dynamic_expr,
    assert_executor,
    bool_cast,
)

from ..base_dsl import *
from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values
from ..base_dsl.typing import _binary_op_type_promote
from ..base_dsl._mlir_helpers.gpu import *
from ..base_dsl._mlir_helpers.op import dsl_user_op
from ..base_dsl.runtime import *
from ..base_dsl.runtime import cuda as cuda_helpers
from ..base_dsl.compiler import compile
from ..base_dsl.runtime.dlpack_runtime import *
from ..base_dsl.runtime.jit_arg_adapters import *
1322
python/CuTeDSL/cutlass_dsl/cutlass.py
Normal file
File diff suppressed because it is too large
515
python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py
Normal file
@@ -0,0 +1,515 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import List, Tuple
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import scf, arith
|
||||
from cutlass._mlir.extras import types as T
|
||||
|
||||
from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values
|
||||
from ..base_dsl.ast_helpers import *
|
||||
from ..base_dsl.utils.logger import log
|
||||
from ..base_dsl import typing as t
|
||||
from ..base_dsl.typing import Int32, Float32, Boolean, Numeric, get_mlir_types
|
||||
from . import cutlass as cutlass_dsl
|
||||
|
||||
# =============================================================================
|
||||
# AST Helpers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class LoopUnroll(ir.Attribute):
|
||||
def __init__(self, **kwargs):
|
||||
valid_keys = set(["count", "full"])
|
||||
def to_mlir_attr(val):
|
||||
if isinstance(val, bool):
|
||||
return "true" if val else "false"
|
||||
elif isinstance(val, int):
|
||||
return f"{val} : i32"
|
||||
else:
|
||||
raise DSLNotImplemented(f"{type(val)} is not supported")
|
||||
|
||||
cfg = {key: to_mlir_attr(kwargs[key]) for key in valid_keys if key in kwargs}
|
||||
if kwargs.get("count", None) == 1:
|
||||
cfg["disable"] = "true"
|
||||
|
||||
unroll = "<" + ", ".join(f"{key} = {value}" for key, value in cfg.items()) + ">"
|
||||
|
||||
super().__init__(
|
||||
ir.Attribute.parse(f"#llvm.loop_annotation<unroll = {unroll}>")
|
||||
)
|
||||
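# For example, LoopUnroll(count=4) parses to
# #llvm.loop_annotation<unroll = <count = 4 : i32>>, while LoopUnroll(count=1)
# additionally emits disable = true to turn unrolling off.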
|
||||
|
||||
class ScfGenerator:
|
||||
"""
|
||||
Encapsulates common scf dialect functionality: pack, unpack, and SCF execution.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def fill_none(ir_values, unpacked_values):
|
||||
i = 0
|
||||
for idx, item in enumerate(unpacked_values):
|
||||
if item is not None:
|
||||
unpacked_values[idx] = ir_values[i]
|
||||
i += 1
|
||||
|
||||
@staticmethod
|
||||
def _normalize_region_result_to_list(region_result: Any) -> List[Any]:
|
||||
"""
|
||||
Convert region_result to a list if it is not already a list
|
||||
If region_result is a list, return it as is.
|
||||
If region_result is None, return an empty list.
|
||||
If region_result is not a list, return a list containing region_result as the only element.
|
||||
"""
|
||||
if region_result is None:
|
||||
region_result_list = []
|
||||
elif not isinstance(region_result, list):
|
||||
region_result_list = [region_result]
|
||||
else:
|
||||
region_result_list = region_result
|
||||
return region_result_list
|
||||
|
||||
@staticmethod
|
||||
def check_region_result(region_values, ir_values):
|
||||
for i, (expected_value, actual_value) in enumerate(
|
||||
zip(ir_values, region_values)
|
||||
):
|
||||
expected_value_type = get_mlir_types(expected_value)
|
||||
actual_value_type = get_mlir_types(actual_value)
|
||||
if expected_value_type != actual_value_type:
|
||||
return False, i, expected_value_type, actual_value_type
|
||||
return True, -1, None, None
|
||||
|
||||
def scf_execute_dynamic(
|
||||
self,
|
||||
op_type_name: str,
|
||||
used_args: List[Any],
|
||||
mix_iter_args: List[Any],
|
||||
mix_iter_arg_names: List[str],
|
||||
create_op_func: Callable[
|
||||
[List[ir.Value], Dict[int, Tuple[int, int]], List[Any]], ir.Operation
|
||||
],
|
||||
region_builders: List[
|
||||
Callable[
|
||||
[
|
||||
"ir.Operation",
|
||||
List["ir.Value"], # block_args
|
||||
List[Any], # used_args
|
||||
List["ir.Value"], # dyn_yield_ops
|
||||
Dict[int, Tuple[int, int]],
|
||||
List[Any],
|
||||
],
|
||||
Any,
|
||||
]
|
||||
],
|
||||
# block_term_op_builder[region_builder] = scf_op_builder
|
||||
# e.g. scf.ConditionOp for while loop
|
||||
block_term_op_builder: Dict[Callable, Callable] = {},
|
||||
) -> Any:
|
||||
# 1) Unpack
|
||||
ir_values, dyn_unpacked_values, dyn_indices, dyn_class_types = (
|
||||
cutlass_dsl.unpack_to_irvalue(mix_iter_args, op_type_name)
|
||||
)
|
||||
# 2) Create the SCF op
|
||||
op = create_op_func(ir_values, dyn_indices, dyn_class_types)
|
||||
log().debug("Generated scf.%s \n[%s]", op_type_name, op)
|
||||
|
||||
# 3) Build the regions
|
||||
for i, builder in enumerate(region_builders):
|
||||
region = op.regions[i]
|
||||
block = region.blocks[0]
|
||||
with ir.InsertionPoint(block):
|
||||
block_args = list(block.arguments)
|
||||
region_result = builder(
|
||||
op,
|
||||
block_args,
|
||||
used_args,
|
||||
dyn_unpacked_values,
|
||||
dyn_indices,
|
||||
dyn_class_types,
|
||||
)
|
||||
|
||||
# Use custom terminator if provided for this builder, otherwise use default YieldOp
|
||||
if builder in block_term_op_builder:
|
||||
# Use the provided terminator generator
|
||||
block_term_op_builder[builder](region_result)
|
||||
else:
|
||||
# Normalize region_result
|
||||
region_result_list = ScfGenerator._normalize_region_result_to_list(
|
||||
region_result
|
||||
)
|
||||
# Default behavior - generate YieldOp
|
||||
region_values, unpacked_values, _, _ = (
|
||||
cutlass_dsl.unpack_to_irvalue(region_result_list, op_type_name)
|
||||
)
|
||||
|
||||
is_match, mismatch_idx, expected_type, actual_type = (
|
||||
ScfGenerator.check_region_result(region_values, ir_values)
|
||||
)
|
||||
|
||||
if not is_match:
|
||||
# From unpacked index, we need to find the original index
|
||||
original_idx = -1
|
||||
for unpacked_idx, (original_idx, length) in dyn_indices.items():
|
||||
if (
|
||||
mismatch_idx >= original_idx
|
||||
and mismatch_idx < original_idx + length
|
||||
):
|
||||
original_idx = unpacked_idx
|
||||
break
|
||||
raise DSLRuntimeError(
|
||||
f"`{op_type_name}` expects {expected_type} type for varible `{mix_iter_arg_names[original_idx]}`, but got {actual_type}.",
|
||||
suggestion=f"Please make sure `{mix_iter_arg_names[original_idx]}` type is not changed inside of `{op_type_name}`.",
|
||||
)
|
||||
scf.YieldOp(region_values)
|
||||
|
||||
log().debug("Completed scf.%s \n[%s]", op_type_name, op)
|
||||
ScfGenerator.fill_none(op.results, unpacked_values)
|
||||
|
||||
# 4) Pack final results
|
||||
final_results = cutlass_dsl.pack_from_irvalue(
|
||||
unpacked_values, dyn_indices, dyn_class_types
|
||||
)
|
||||
|
||||
# 5) Return in a nice pattern
|
||||
if not final_results:
|
||||
return
|
||||
if len(final_results) == 1:
|
||||
return final_results[0]
|
||||
return final_results
|
||||
|
||||
|
||||
def _loop_execute_range_dynamic(
|
||||
func: Callable,
|
||||
start: Any,
|
||||
stop: Any,
|
||||
step: Any,
|
||||
used_args: List[Any] = [],
|
||||
mix_iter_args: List[Any] = [],
|
||||
mix_iter_arg_names: List[str] = [],
|
||||
unroll: int = -1,
|
||||
unroll_full: bool = False,
|
||||
):
|
||||
"""
|
||||
Example: build an scf.for with optional unroll, using our universal helper.
|
||||
"""
|
||||
scf_gen = ScfGenerator()
|
||||
|
||||
def create_for_op(
|
||||
dyn_yield_ops: List[ir.Value],
|
||||
dyn_indices: Dict[int, Tuple[int, int]],
|
||||
dyn_class_types: List[Any],
|
||||
):
|
||||
for d in dyn_yield_ops:
|
||||
if not isinstance(d, ir.Value):
|
||||
raise DSLRuntimeError(
|
||||
f"Invalid dyn_yield_ops: {dyn_yield_ops} \n\tExpected ir.Value, got {type(d)}"
|
||||
)
|
||||
|
||||
# Convert Python ints or values to IR constants if needed
|
||||
start_ = t.as_numeric(start)
|
||||
stop_ = t.as_numeric(stop)
|
||||
step_ = t.as_numeric(step)
|
||||
assert start_ is not t.Int32, "Start is required for scf.for"
|
||||
assert stop_ is not t.Int32, "Stop is required for scf.for"
|
||||
assert step_ is not t.Int32, "Step is required for scf.for"
|
||||
start_ = start_.ir_value()
|
||||
stop_ = stop_.ir_value()
|
||||
step_ = step_.ir_value()
|
||||
|
||||
# Possibly attach unroll attributes
|
||||
unroll_attr = None
|
||||
if unroll_full:
|
||||
unroll_attr = LoopUnroll(full=True)
|
||||
elif unroll != -1:
|
||||
unroll_attr = LoopUnroll(count=unroll)
|
||||
log().debug("Unroll attribute: %s", unroll_attr)
|
||||
|
||||
log().debug(
|
||||
"Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s",
|
||||
start_,
|
||||
type(start_),
|
||||
stop_,
|
||||
type(stop_),
|
||||
step_,
|
||||
type(step_),
|
||||
)
|
||||
# Create scf.ForOp, passing iteration args if any
|
||||
try:
|
||||
if not dyn_yield_ops:
|
||||
for_op = scf.ForOp(start_, stop_, step_)
|
||||
else:
|
||||
for_op = scf.ForOp(start_, stop_, step_, list(dyn_yield_ops))
|
||||
except Exception as e:
|
||||
yield_ops = "\n".join(
|
||||
f"\t\t{i} => {d} : type : {type(d)}"
|
||||
for i, d in enumerate(dyn_yield_ops)
|
||||
)
|
||||
raise DSLRuntimeError(
|
||||
f"Failed to create scf.ForOp \n\t\tstart={start_}: type : {type(start_)}"
|
||||
f"\n\t\tstop={stop_}: type : {type(stop_)}\n\t\tstep={step_}: type : {type(step_)}"
|
||||
f", \n\tdyn_yield_ops:\n{yield_ops}"
|
||||
) from e
|
||||
|
||||
if unroll_attr is not None:
|
||||
for_op.attributes["loop_annotation"] = unroll_attr
|
||||
|
||||
return for_op
|
||||
|
||||
def for_body_builder(
|
||||
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
|
||||
):
|
||||
# Insert induction variable at the beginning
|
||||
dyn_yield_ops.insert(0, block_args[0])
|
||||
ScfGenerator.fill_none(block_args, dyn_yield_ops)
|
||||
block_args = dyn_yield_ops
|
||||
# scf.ForOp block_args are typically [induction_var, iter_args...]
|
||||
# But MLIR also gives you op.induction_variable
|
||||
iv = t.as_numeric(op.induction_variable)
|
||||
log().debug(
|
||||
"For body builder: %s block_args: %s used_args: %s",
|
||||
iv,
|
||||
block_args,
|
||||
used_args,
|
||||
)
|
||||
if len(block_args) <= 1:
|
||||
# No iteration arguments, or only the induction var
|
||||
func(iv, *used_args)
|
||||
return [] # yield nothing
|
||||
else:
|
||||
# block_args[1:] are iteration variables
|
||||
func_args = [*used_args]
|
||||
func_args.extend(
|
||||
cutlass_dsl.pack_from_irvalue(
|
||||
block_args[1:], dyn_indices, dyn_class_types
|
||||
)
|
||||
)
|
||||
updated_func_args = func(iv, *func_args)
|
||||
return updated_func_args
|
||||
|
||||
# Now call the universal SCF executor with a single region builder
|
||||
return scf_gen.scf_execute_dynamic(
|
||||
op_type_name="for",
|
||||
used_args=used_args,
|
||||
mix_iter_args=mix_iter_args,
|
||||
mix_iter_arg_names=mix_iter_arg_names,
|
||||
create_op_func=create_for_op,
|
||||
region_builders=[for_body_builder],
|
||||
)
|
||||
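# A minimal usage sketch (assumes an active MLIR insertion point, as set up by the
# DSL machinery that normally calls this helper; `acc` is a hypothetical
# loop-carried Int32):
#
#   def body(iv, acc):
#       return [acc + iv]          # yield the updated loop-carried value
#
#   acc = _loop_execute_range_dynamic(
#       body, start=0, stop=8, step=1,
#       mix_iter_args=[Int32(0)], mix_iter_arg_names=["acc"],
#       unroll=2,
#   )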
|
||||
|
||||
def _if_execute_dynamic(
|
||||
pred: "ir.Value",
|
||||
then_block: Callable,
|
||||
else_block: Callable = None,
|
||||
used_args: List[Any] = [],
|
||||
mix_yield_args: List[Any] = [],
|
||||
mix_yield_arg_names: List[str] = [],
|
||||
if_constexpr=None, # ignoring for brevity
|
||||
):
|
||||
"""
|
||||
Build an scf.if with optional else, using our universal helper.
|
||||
"""
|
||||
scf_gen = ScfGenerator()
|
||||
|
||||
def create_if_op(
|
||||
dyn_yield_ops: List[ir.Value],
|
||||
dyn_indices: Dict[int, Tuple[int, int]],
|
||||
dyn_class_types: List[Any],
|
||||
):
|
||||
# Assume final result types match the dynamic yields
|
||||
result_types = [arg.type for arg in dyn_yield_ops]
|
||||
|
||||
pred_ = t.as_numeric(pred)
|
||||
|
||||
if not isinstance(pred_, Boolean):
|
||||
# Convert to Boolean through comparison
|
||||
pred_ = pred_ == True
|
||||
|
||||
try:
|
||||
if_op = scf.IfOp(
|
||||
pred_.ir_value(),
|
||||
hasElse=(else_block is not None),
|
||||
results_=result_types,
|
||||
)
|
||||
except Exception as e:
|
||||
raise DSLRuntimeError(
|
||||
f"Failed to create scf.IfOp \n\t\tpred={pred_}: type : {type(pred_)}"
|
||||
) from e
|
||||
return if_op
|
||||
|
||||
def then_builder(
|
||||
if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
|
||||
):
|
||||
flat_args = [*used_args]
|
||||
flat_args.extend(
|
||||
cutlass_dsl.pack_from_irvalue(dyn_yield_ops, dyn_indices, dyn_class_types)
|
||||
)
|
||||
return then_block(*flat_args)
|
||||
|
||||
region_builders = [then_builder]
|
||||
|
||||
if else_block is not None:
|
||||
|
||||
def else_builder(
|
||||
if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
|
||||
):
|
||||
flat_args = [*used_args]
|
||||
flat_args.extend(
|
||||
cutlass_dsl.pack_from_irvalue(
|
||||
dyn_yield_ops, dyn_indices, dyn_class_types
|
||||
)
|
||||
)
|
||||
return else_block(*flat_args)
|
||||
|
||||
region_builders.append(else_builder)
|
||||
|
||||
return scf_gen.scf_execute_dynamic(
|
||||
op_type_name="if",
|
||||
used_args=used_args,
|
||||
mix_iter_args=mix_yield_args,
|
||||
mix_iter_arg_names=mix_yield_arg_names,
|
||||
create_op_func=create_if_op,
|
||||
region_builders=region_builders,
|
||||
)
|
||||
|
||||
|
||||
def _while_execute_dynamic(
|
||||
while_before_block: Callable,
|
||||
while_after_block: Callable = None,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
yield_arg_names=[],
|
||||
):
|
||||
"""
|
||||
Create and return an SCF WhileOp for dynamic loops.
|
||||
Generate the dynamic loop body using SCF WhileOp.
|
||||
|
||||
Args:
|
||||
while_before_block: Function that returns (condition, updated_values)
|
||||
while_after_block: Function that returns updated values
|
||||
used_args: Additional arguments used in the loop body
|
||||
yield_args: Values that are updated in the loop
|
||||
|
||||
See create_while_function in ast_preprocessor.py for details on the input structure.
|
||||
"""
|
||||
log().debug("_while_execute_dynamic")
|
||||
while_op_type_name = "while"
|
||||
scf_gen = ScfGenerator()
|
||||
|
||||
def create_while_op(
|
||||
dyn_yield_ops: List[ir.Value],
|
||||
dyn_indices: Dict[int, Tuple[int, int]],
|
||||
dyn_class_types: List[Any],
|
||||
):
|
||||
# Create the while operation with the types from yield_args
|
||||
result_types = [arg.type for arg in dyn_yield_ops]
|
||||
try:
|
||||
while_op = scf.WhileOp(result_types, dyn_yield_ops)
|
||||
while_op.before.blocks.append(*result_types)
|
||||
while_op.after.blocks.append(*result_types)
|
||||
log().debug("[%s]", while_op)
|
||||
return while_op
|
||||
except Exception as e:
|
||||
yield_ops = "\n".join(
|
||||
f"\t\t{i} => {d} : type : {type(d)}"
|
||||
for i, d in enumerate(dyn_yield_ops)
|
||||
)
|
||||
raise DSLRuntimeError(
|
||||
f"Failed to create scf.WhileOp with yield_ops:\n{yield_ops}"
|
||||
) from e
|
||||
|
||||
def before_block_builder(
|
||||
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
|
||||
):
|
||||
# Build the before (condition) block
|
||||
ScfGenerator.fill_none(block_args, dyn_yield_ops)
|
||||
block_args = dyn_yield_ops
|
||||
flat_args = [*used_args]
|
||||
flat_args.extend(
|
||||
cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types)
|
||||
)
|
||||
|
||||
log().debug("before block args: %s", flat_args)
|
||||
|
||||
cond, before_results = while_before_block(*flat_args)
|
||||
|
||||
if not isinstance(before_results, (list, ir.OpResultList)):
|
||||
before_results = [before_results]
|
||||
|
||||
log().debug("cond [%s]", cond)
|
||||
log().debug(
|
||||
"before_results [%s]",
|
||||
before_results,
|
||||
)
|
||||
|
||||
return cond, before_results
|
||||
|
||||
def before_block_terminator(cond_and_results):
|
||||
# Generate a condition op instead of yield op
|
||||
cond = cond_and_results[0]
|
||||
before_result_list = ScfGenerator._normalize_region_result_to_list(
|
||||
cond_and_results[1]
|
||||
)
|
||||
ir_cond_list, _, _, _ = cutlass_dsl.unpack_to_irvalue(
|
||||
[cond], while_op_type_name
|
||||
)
|
||||
ir_cond = ir_cond_list[0]
|
||||
ir_results_list, _, _, _ = cutlass_dsl.unpack_to_irvalue(
|
||||
before_result_list, while_op_type_name
|
||||
)
|
||||
log().debug(
|
||||
"creating scf.ConditionOp with [%s], [%s]",
|
||||
ir_cond,
|
||||
ir_results_list,
|
||||
)
|
||||
scf.ConditionOp(ir_cond, ir_results_list)
|
||||
|
||||
def after_block_builder(
|
||||
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
|
||||
):
|
||||
# Build the after (body) block
|
||||
ScfGenerator.fill_none(block_args, dyn_yield_ops)
|
||||
block_args = dyn_yield_ops
|
||||
flat_args = [*used_args]
|
||||
flat_args.extend(
|
||||
cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types)
|
||||
)
|
||||
|
||||
log().debug("after block args: %s", flat_args)
|
||||
|
||||
after_results = while_after_block(*flat_args)
|
||||
|
||||
if not isinstance(after_results, (list, ir.OpResultList)):
|
||||
after_results = [after_results]
|
||||
|
||||
log().debug(
|
||||
"after_results [%s]",
|
||||
after_results,
|
||||
)
|
||||
|
||||
return after_results
|
||||
|
||||
# Call the universal SCF executor with two region builders
|
||||
return scf_gen.scf_execute_dynamic(
|
||||
op_type_name=while_op_type_name,
|
||||
used_args=used_args,
|
||||
mix_iter_args=yield_args,
|
||||
mix_iter_arg_names=yield_arg_names,
|
||||
create_op_func=create_while_op,
|
||||
region_builders=[before_block_builder, after_block_builder],
|
||||
block_term_op_builder={
|
||||
before_block_builder: before_block_terminator
|
||||
}, # Only customize the before block
|
||||
)
|
||||
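# A minimal usage sketch for the while helper above (assumes an active MLIR
# insertion point; `i` is a hypothetical loop-carried Int32 counter):
#
#   def before(i):                 # runs each iteration; returns (condition, carried values)
#       return i < Int32(10), [i]
#
#   def after(i):                  # loop body; returns updated carried values
#       return [i + Int32(1)]
#
#   i = _while_execute_dynamic(before, after, yield_args=[Int32(0)], yield_arg_names=["i"])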
3
python/CuTeDSL/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
# Use `pip install -r requirements.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl==4.0.0.dev1
|
||||
@@ -133,7 +133,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
this.__version__ = '3.9.2'
|
||||
this.__version__ = '4.0.0'
|
||||
|
||||
from cutlass.backend import create_memory_pool
|
||||
from cutlass.emit.pytorch import pytorch
|
||||
|
||||
@@ -111,6 +111,7 @@
|
||||
|
||||
args.sync()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
|
||||
@@ -1,3 +1,34 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
@@ -8,4 +39,3 @@ def lazy_import(mod_name: str) -> Any:
|
||||
return getattr(module, name)
|
||||
|
||||
return Lazy()
|
||||
|
||||
|
||||
@@ -193,3 +193,4 @@ class CUDAEventProfiler:
|
||||
flops_ += m * n * batch_count * 2
|
||||
|
||||
return flops_
|
||||
|
||||
|
||||
@@ -75,15 +75,10 @@ audit_csv_runtime_fields = [
|
||||
]
|
||||
|
||||
def hash_cutlass_string(input_string):
|
||||
# Regex pattern to match instruction shape
|
||||
instruction_shape_pattern = r"[a-zA-Z]\d+x\d+x\d+" # Matches '_s128x128x64', '_h64x128x16', etc.
|
||||
mma_cluster_shape_pattern = r"_\d+x\d+x\d+" # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')
|
||||
|
||||
# Remove instruction shape (e.g., '_s128x128x64', '_h64x128x16')
|
||||
output = re.sub(instruction_shape_pattern, "", input_string)
|
||||
|
||||
# Remove MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')
|
||||
output = re.sub(mma_cluster_shape_pattern, "", output)
|
||||
output = re.sub(mma_cluster_shape_pattern, "", input_string)
|
||||
|
||||
return output
|
||||
|
||||
@@ -288,7 +283,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
# TODO: randomize beta values for wider coverage
|
||||
beta_values = [0.5]
|
||||
|
||||
is_supported_arch = (arch in ["100a", "101a", "120a"])
|
||||
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"])
|
||||
|
||||
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
|
||||
|
||||
@@ -300,23 +295,23 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
#
|
||||
|
||||
sm100_mma_data_type_general = [
|
||||
'x16gemm_f16_f16_f16_f16_f16',
|
||||
'x16gemm_f16_f16_f16_void_f16',
|
||||
'x16gemm_f16_f16_f32_f16_f16',
|
||||
'x8tf32gemm_f32_f32_f32_f32_f32',
|
||||
'x16bf16gemm_f32_f32_f32_f32_f32',
|
||||
'gemm_f16_f16_f16_f16_f16',
|
||||
'gemm_f16_f16_f16_void_f16',
|
||||
'gemm_f16_f16_f32_f16_f16',
|
||||
'tf32gemm_f32_f32_f32_f32_f32',
|
||||
'bf16gemm_f32_f32_f32_f32_f32',
|
||||
]
|
||||
|
||||
sm100_mma_data_type_runtime_dtype = [
|
||||
'x32gemm_f4_f4_f32_f32_f32',
|
||||
'x32gemm_f6_f6_f32_f32_f32',
|
||||
'x32gemm_f8_f8_f32_f32_f32',
|
||||
'gemm_f4_f4_f32_f32_f32',
|
||||
'gemm_f6_f6_f32_f32_f32',
|
||||
'gemm_f8_f8_f32_f32_f32',
|
||||
]
|
||||
|
||||
sm100_mma_data_type_mergeable = [
|
||||
'x32gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification
|
||||
'x32gemm_e2m1_e2m1_f32_f32_f32',
|
||||
'x32gemm_e3m2_e3m2_f32_f32_f32',
|
||||
'gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification
|
||||
'gemm_e2m1_e2m1_f32_f32_f32',
|
||||
'gemm_e3m2_e3m2_f32_f32_f32',
|
||||
]
|
||||
|
||||
sm100_mma_cluster_size = [
|
||||
@@ -331,22 +326,15 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'ntn'
|
||||
]
|
||||
|
||||
sm100_mma_instruction_shape = [
|
||||
# [0] .1CTA, General
|
||||
['64x128', '128x128', '128x256'],
|
||||
# [1] .2CTA, General
|
||||
['128x128', '256x128', '256x256'],
|
||||
]
|
||||
|
||||
# regex list must be in kernel procedural name order
|
||||
mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
|
||||
#
|
||||
# Block Scale Gemm
|
||||
@@ -354,19 +342,19 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
block_scaled_data_type_base = [
|
||||
# runtime datatypes
|
||||
'x32gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'x32gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
|
||||
'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
|
||||
'x32gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
|
||||
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
block_scaled_data_type_mergeable = [
|
||||
'x32gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'x32gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
|
||||
'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
|
||||
'x32gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
|
||||
'gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
block_scaled_data_type = block_scaled_data_type_base + block_scaled_data_type_mergeable
|
||||
@@ -377,56 +365,43 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
]
|
||||
|
||||
block_scaled_layouts = ['tnt']
|
||||
block_scaled_instruction_shape = [
|
||||
# .1CTA
|
||||
['128x128', '128x192', '128x256'],
|
||||
# .2CTA
|
||||
['256x128', '256x192', '256x256'],
|
||||
]
|
||||
# regex list must be in kernel procedural name order
|
||||
mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
if arch == "100a":
|
||||
if arch == "100a" or arch == "100f":
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch == "101a":
|
||||
elif arch == "101a" or arch == "101f":
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch == "120a":
|
||||
elif arch == "120a" or arch == "120f":
|
||||
|
||||
# blockscaled sm120_mma kernels
|
||||
blockscaled_sm120_mma_kernel_cta_tiles = [
|
||||
[ '128x128' ]
|
||||
]
|
||||
|
||||
# sm120 MMA instruction shapes
|
||||
blockscaled_sm120_mma_instruction_shapes = [
|
||||
[ 's16x8x64gemm',
|
||||
's16x8x32gemm'
|
||||
]
|
||||
]
|
||||
|
||||
# Restrict to two layouts to reduce L0 build and test time.
|
||||
blockscaled_sm120_mma_layouts = [ 'tn' ]
|
||||
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_instruction_shapes[0], blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
|
||||
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
|
||||
|
||||
problem_waves = [0.5, 1.25, 2.5]
|
||||
|
||||
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
|
||||
else:
|
||||
error_message = "unsupported arch, only support sm100a, sm101a, sm120a"
|
||||
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
|
||||
raise Exception(error_message)
|
||||
|
||||
# Statically encoded kernels are still added to generated_kernels
|
||||
@@ -445,14 +420,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
]
|
||||
# Restrict to two layouts to reduce L1 build and test time.
|
||||
sm100_mma_layouts = ['tnt', 'ntn']
|
||||
sm100_mma_instruction_shape = [
|
||||
# .1CTA
|
||||
['64x128', '128x128', '128x256'],
|
||||
# .2CTA
|
||||
['128x128', '256x128', '256x256']
|
||||
]
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
||||
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
||||
block_scaled_data_type = [
|
||||
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
|
||||
@@ -463,15 +432,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
|
||||
block_scaled_layouts = ['tnt']
|
||||
block_scaled_instruction_shape = [
|
||||
# .1CTA
|
||||
['128x128', '128x192', '128x256'],
|
||||
# .2CTA
|
||||
['256x128', '256x192', '256x256'],
|
||||
]
|
||||
|
||||
# regex list must be in kernel procedural name order
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
|
||||
@@ -183,10 +183,7 @@ class GemmOperation:
|
||||
math_op = self.tile_description.math_instruction.math_operation
|
||||
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
|
||||
|
||||
if self.is_3x:
|
||||
inst_shape = "{0}x{1}x{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
|
||||
else:
|
||||
inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
|
||||
inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) if not self.is_3x else ""
|
||||
|
||||
inst_shape += math_op_string
|
||||
|
||||
@@ -194,7 +191,9 @@ class GemmOperation:
|
||||
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
||||
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
||||
|
||||
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
|
||||
short_math_name = self.short_math_name() if not self.is_3x else ""
|
||||
|
||||
return "%s%s%s%s" % (short_math_name, inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
|
||||
|
||||
# Generates a string representing the MMA instruction.
|
||||
def extended_name(self):
|
||||
@@ -337,18 +336,36 @@ class GemmOperation:
|
||||
def opcode_class_name(self):
|
||||
return OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
||||
|
||||
def get_collective_tile_shape(self):
|
||||
"""
|
||||
Get the tile shape passed to the collective builder.
|
||||
On Blackwell, this differs from operation.tile_description.tile_shape.
|
||||
"""
|
||||
is_sm100_kernel = (self.arch == 100)
|
||||
if not is_sm100_kernel:
|
||||
return self.tile_description.tile_shape
|
||||
|
||||
opcode_class_main = self.tile_description.math_instruction.opcode_class
|
||||
instruction_shape = self.tile_description.math_instruction.instruction_shape
|
||||
tile_shape_m, tile_shape_n, tile_shape_k = self.tile_description.tile_shape
|
||||
if opcode_class_main in [OpcodeClass.TensorOp, OpcodeClass.BlockScaledTensorOp, OpcodeClass.SparseTensorOp]:
|
||||
tile_shape_m = instruction_shape[0]
|
||||
tile_shape_n = instruction_shape[1]
|
||||
return (tile_shape_m, tile_shape_n, tile_shape_k)
|
||||
|
||||
# Generates the full kernel function name
|
||||
def procedural_name(self):
|
||||
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
||||
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
||||
if self.arch >= 90:
|
||||
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}"
|
||||
tile_shape = self.get_collective_tile_shape()
|
||||
return kernel_name_template.format(
|
||||
p = self.prefix,
|
||||
ar = self.arch,
|
||||
op = opcode_class_name,
|
||||
ex = self.extended_name_3x(),
|
||||
ct = '_' + 'x'.join([str(i) for i in self.tile_description.tile_shape]) if self.tile_description.tile_shape[0] > 0 else "",
|
||||
ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "",
|
||||
cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]),
|
||||
l = self.tile_description.stages,
|
||||
s = self.layout_name_3x(),
|
||||
@@ -920,28 +937,8 @@ ${compile_guard_end}
|
||||
instruction_shape = operation.tile_description.math_instruction.instruction_shape
|
||||
cluster_m = operation.tile_description.cluster_shape[0]
|
||||
cluster_n = operation.tile_description.cluster_shape[1]
|
||||
|
||||
tile_shape_m, tile_shape_n, tile_shape_k = tile_shape
|
||||
|
||||
# account for static/dynamic cluster shapes
|
||||
cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0]
|
||||
cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1]
|
||||
|
||||
|
||||
# Shape passed to epilogue builder
|
||||
is_sm100_kernel = (operation.arch == 100)
|
||||
if is_sm100_kernel:
|
||||
cta_m_per_mma_instruction = 2 if "2sm" in operation.procedural_name() else 1
|
||||
if cluster_m <= 0:
|
||||
cta_m = cta_m // cta_m_per_mma_instruction
|
||||
|
||||
if opcode_class_main in [OpcodeClass.TensorOp
|
||||
, OpcodeClass.BlockScaledTensorOp
|
||||
, OpcodeClass.SparseTensorOp
|
||||
]:
|
||||
tile_shape_m = instruction_shape[0]
|
||||
tile_shape_n = instruction_shape[1]
|
||||
|
||||
tile_shape_m, tile_shape_n, tile_shape_k = operation.get_collective_tile_shape()
|
||||
|
||||
# stage count set to zero indicates builder automatic stage selection
|
||||
if operation.tile_description.stages > 0:
|
||||
|
||||
@@ -1003,14 +1003,11 @@ class ConvOperation3x:
|
||||
math_op = self.tile_description.math_instruction.math_operation
|
||||
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
|
||||
|
||||
inst_shape = "{0}x{1}x{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape))
|
||||
inst_shape += math_op_string
|
||||
|
||||
if self.tile_description.math_instruction.element_a != self.A.element and \
|
||||
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
||||
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
||||
|
||||
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, ConvKindNames[self.conv_kind])
|
||||
return "%s%s%s" % (math_op_string, intermediate_type, ConvKindNames[self.conv_kind])
|
||||
|
||||
def extended_name(self):
|
||||
'''Generates a string representing the MMA atom. Assumes accumulator type is C type.'''
|
||||
@@ -5997,8 +5994,8 @@ def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version):
|
||||
|
||||
math_instructions = generate_mixed_dtype_math_instructions_sm90(instantiation_level, valid_types_for_a_b_acc)
|
||||
|
||||
valid_types_for_d = [DataType.f32]
|
||||
valid_types_for_c = [DataType.f32]
|
||||
valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2]
|
||||
valid_types_for_c = copy.deepcopy(valid_types_for_d)
|
||||
|
||||
tile_descriptions = generate_tile_descriptions_sm90(
|
||||
math_instructions=math_instructions,
|
||||
@@ -6009,6 +6006,12 @@ def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version):
|
||||
math_inst = tile_desc.math_instruction
|
||||
data_types = []
|
||||
|
||||
# Limit C/D types to avoid a giant number of instantiations.
|
||||
# A typical use case for mixed dtype in DL is weight quantization (tensor A),
|
||||
# therefore we can limit the output type to that of activation (tensor B).
|
||||
valid_types_for_c = [math_inst.element_b]
|
||||
valid_types_for_d = [math_inst.element_b]
|
||||
|
||||
for c_type, d_type in product(valid_types_for_c, valid_types_for_d):
|
||||
data_types.append(
|
||||
generate_data_types_from_math_instruction(
|
||||
@@ -6791,6 +6794,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default
|
||||
]
|
||||
@@ -6838,6 +6846,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
@@ -6937,6 +6950,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default
|
||||
]
|
||||
@@ -7090,6 +7108,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -7247,6 +7270,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@@ -7456,6 +7484,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -7916,6 +7949,13 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [
|
||||
[2,1,1],
|
||||
[1,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
# 1xSM MMA kernels
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
@@ -7985,6 +8025,12 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -8138,6 +8184,13 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_1sm = [
|
||||
[1,1,1],
|
||||
[2,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
# 1xSM MMA kernels
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
@@ -8211,6 +8264,13 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
if 101 in manifest.compute_capabilities :
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1],
|
||||
[4,1,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
@@ -8417,6 +8477,13 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
     , DynamicClusterShape
   ]

+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_1sm = [
+      [1,1,1],
+      [2,1,1]
+      , DynamicClusterShape
+    ]
+
   # 1xSM MMA kernels
   for math_inst in math_instructions_1sm:
     tile_descriptions = []
@@ -8537,6 +8604,13 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
     , DynamicClusterShape
   ]

+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_2sm = [
+      [2,1,1],
+      [4,1,1]
+      , DynamicClusterShape
+    ]
+
   for math_inst in math_instructions_2sm:
     tile_descriptions = []
     for cluster_shape in cluster_shapes_2sm:
@@ -8689,6 +8763,11 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
     , DynamicClusterShape
   ]

+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1]
+      , DynamicClusterShape
+    ]
+
   tile_schedulers = [
     TileSchedulerType.Default,
   ]
@@ -8788,6 +8867,11 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
     , DynamicClusterShape
   ]

+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]
+      , DynamicClusterShape
+    ]
+
   for math_inst in math_instructions_2sm:
     tile_descriptions = []
     for cluster_shape in cluster_shapes_2sm:
@@ -8925,6 +9009,9 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
   for math_inst in math_instructions_1sm:
     tile_descriptions = []
     for cluster_shape in sm100_cluster_shape_1sm:
+      if 101 in manifest.compute_capabilities :
+        if cluster_shape == [4,4,1] :
+          continue
       multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
       tile_descriptions.append(
         TileDescription([
@@ -8953,6 +9040,9 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
   for math_inst in math_instructions_2sm:
     tile_descriptions = []
     for cluster_shape in sm100_cluster_shape_2sm:
+      if 101 in manifest.compute_capabilities :
+        if cluster_shape == [4,4,1] :
+          continue
       multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
       tile_descriptions.append(
         TileDescription([
@@ -9044,6 +9134,9 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
   for math_inst in math_instructions_1sm:
     tile_descriptions = []
     for cluster_shape in sm100_cluster_shape_1sm:
+      if 101 in manifest.compute_capabilities :
+        if cluster_shape == [4,4,1] :
+          continue
       multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
       tile_descriptions.append(
         TileDescription([
@@ -9072,6 +9165,9 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
   for math_inst in math_instructions_2sm:
     tile_descriptions = []
     for cluster_shape in sm100_cluster_shape_2sm:
+      if 101 in manifest.compute_capabilities :
+        if cluster_shape == [4,4,1] :
+          continue
       multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
       tile_descriptions.append(
         TileDescription([
@@ -9163,6 +9259,9 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
   for math_inst in math_instructions_1sm:
     tile_descriptions = []
     for cluster_shape in sm100_cluster_shape_1sm:
+      if 101 in manifest.compute_capabilities :
+        if cluster_shape == [4,4,1] :
+          continue
       multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
       tile_descriptions.append(
         TileDescription([
@@ -9191,6 +9290,9 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
   for math_inst in math_instructions_2sm:
     tile_descriptions = []
     for cluster_shape in sm100_cluster_shape_2sm:
+      if 101 in manifest.compute_capabilities :
+        if cluster_shape == [4,4,1] :
+          continue
       multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
       tile_descriptions.append(
         TileDescription([
@@ -9287,6 +9389,9 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
   for math_inst in math_instructions_1sm:
     tile_descriptions = []
     for cluster_shape in sm100_cluster_shape_1sm:
+      if 101 in manifest.compute_capabilities :
+        if cluster_shape == [4,4,1] :
+          continue
       multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
       tile_descriptions.append(
         TileDescription([
@@ -9319,6 +9424,9 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
   for math_inst in math_instructions_2sm:
     tile_descriptions = []
     for cluster_shape in sm100_cluster_shape_2sm:
+      if 101 in manifest.compute_capabilities :
+        if cluster_shape == [4,4,1] :
+          continue
       multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
       tile_descriptions.append(
         TileDescription([
@@ -9417,6 +9525,9 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
   for math_inst in math_instructions_1sm:
     tile_descriptions = []
     for cluster_shape in sm100_cluster_shape_1sm:
+      if 101 in manifest.compute_capabilities :
+        if cluster_shape == [4,4,1] :
+          continue
       multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
       tile_descriptions.append(
         TileDescription([
@@ -9476,6 +9587,9 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
   for math_inst in math_instructions_2sm:
     tile_descriptions = []
     for cluster_shape in sm100_cluster_shape_2sm:
+      if 101 in manifest.compute_capabilities :
+        if cluster_shape == [4,4,1] :
+          continue
       multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
       tile_descriptions.append(
         TileDescription([
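In the sparse hunks above, the cluster shape also determines the threadblock-tile multiplier: a 1-SM kernel uses the full cluster shape, a 2-SM kernel halves the first dimension because one MMA spans two SMs, and DynamicClusterShape falls back to (1, 1, 1). The multiplier is then applied to the instruction shape when building the TileDescription entries (as in the SM90 code later in this diff). A hedged sketch of that arithmetic with illustrative values; the real TileDescription carries more fields:

# Sketch of how a cluster shape becomes a tile multiplier in the hunks above.
DynamicClusterShape = [0, 0, 0]   # stand-in for the library's sentinel

def tile_multiplier(cluster_shape, two_sm):
    if cluster_shape == DynamicClusterShape:
        return (1, 1, 1)
    if two_sm:
        # A 2-SM MMA already spans two SMs along the first cluster dimension.
        return (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
    return tuple(cluster_shape)

instruction_shape = (128, 128, 16)   # illustrative MMA instruction shape
mult = tile_multiplier([2, 2, 1], two_sm=True)
print([dim * m for dim, m in zip(instruction_shape, mult)])   # [128, 256, 16]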
@@ -9578,6 +9692,12 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
     , DynamicClusterShape
   ]

+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_1sm = [
+      [1,2,1], [1,1,1], [1,4,1]
+      , DynamicClusterShape
+    ]
+
   tile_schedulers = [
     TileSchedulerType.StreamK,
   ]
@@ -9612,6 +9732,12 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
     , DynamicClusterShape
   ]

+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_2sm = [
+      [2,1,1], [2,2,1], [2,4,1], [4,1,1]
+      , DynamicClusterShape
+    ]
+
   for math_inst in math_instructions_2sm:
     tile_descriptions = []
     for cluster_shape in cluster_shapes_2sm:
@@ -9658,6 +9784,12 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
     , DynamicClusterShape
   ]

+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_1sm = [
+      [1,2,1], [1,1,1]
+      , DynamicClusterShape
+    ]
+
   tile_schedulers = [
     TileSchedulerType.StreamK
   ]
@@ -9726,6 +9858,12 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
     , DynamicClusterShape
   ]

+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_2sm = [
+      [2,1,1], [2,2,1], [2,4,1], [4,1,1]
+      , DynamicClusterShape
+    ]
+
   for math_inst in math_instructions_2sm:
     tile_descriptions = []
     for cluster_shape in cluster_shapes_2sm:
@@ -9809,6 +9947,12 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
     , DynamicClusterShape
   ]

+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_1sm = [
+      [1,2,1], [2,1,1], [1,1,1]
+      , DynamicClusterShape
+    ]
+
   tile_schedulers = [
     TileSchedulerType.StreamK,
   ]
@@ -9861,6 +10005,12 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
     , DynamicClusterShape
   ]

+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_2sm = [
+      [2,1,1], [2,2,1], [2,4,1], [4,1,1]
+      , DynamicClusterShape
+    ]
+
   for math_inst in math_instructions_2sm:
     tile_descriptions = []
     for cluster_shape in cluster_shapes_2sm:
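The stream-K hunks above apply the same SM101 trimming but pair the cluster-shape lists with TileSchedulerType.StreamK rather than TileSchedulerType.Default, and each generator presumably fans variants out over the scheduler list and the cluster-shape list. A small illustrative sketch of that fan-out; the names below are stand-ins, not the generator's API:

from itertools import product

# Illustrative fan-out: one kernel variant per (tile scheduler, cluster shape) pair.
tile_schedulers = ["StreamK"]
cluster_shapes_1sm = [[1, 2, 1], [1, 1, 1], "DynamicClusterShape"]

for scheduler, cluster in product(tile_schedulers, cluster_shapes_1sm):
    print(f"emit GEMM variant: scheduler={scheduler}, cluster={cluster}")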
@@ -9960,6 +10110,9 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,

   cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]]

+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]]
+
   # tile_descriptions is a 2-level list.
   # Each inner list is for each cluster shape.
   for math_inst, output_type in math_instructions_w_output_1sm:
@@ -10023,6 +10176,8 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
       data_types_and_instruction_shapes_2sm)

   cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]]
+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]]

   for math_inst, output_type in math_instructions_w_output_2sm:
     tile_descriptions = []
@@ -10103,6 +10258,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
       data_types_and_instruction_shapes_1sm)

   cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]]
+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]]

   for math_inst, output_type in math_instructions_w_output_1sm:
     tile_descriptions = []
@@ -10166,6 +10323,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
       data_types_and_instruction_shapes_2sm)

   cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]]
+  if 101 in manifest.compute_capabilities :
+    cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]]

   for math_inst, output_type in math_instructions_w_output_2sm:
     tile_descriptions = []
@@ -10629,6 +10788,8 @@ def GenerateSM100(manifest, cuda_version):
   #
   # Dense Gemm
   #
+  architectures = manifest.args.architectures.split(';') if len(args.architectures) else ['50',]
+
   GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version)

   GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version)
@@ -10636,7 +10797,8 @@ def GenerateSM100(manifest, cuda_version):

   GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version)

-  GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version)
+  if '100f' not in architectures and '101f' not in architectures:
+    GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version)

   GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version)
   # grouped GEMM
@@ -10657,7 +10819,8 @@ def GenerateSM100(manifest, cuda_version):
   #
   GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version)
   GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version)
-  GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version)
+  if '100f' not in architectures and '101f' not in architectures:
+    GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version)
   GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version)
   GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version)

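Two things change in the GenerateSM100 hunks above: the architecture list is now parsed inside the dispatcher, and the int8 dense and sparse generators run only when the build is not restricted to the family-conditional '100f'/'101f' targets. A minimal sketch of that gate, using a stand-in list rather than the manifest's parsed arguments:

# Sketch of the gating added in GenerateSM100 above.
architectures = '100f;101f'.split(';')   # stand-in for the parsed target list

def generate_int8_kernels():
    print('emitting SM100 int8 kernels')

if '100f' not in architectures and '101f' not in architectures:
    generate_int8_kernels()
else:
    print('skipping int8 kernels for family-conditional (f) targets')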
@@ -11166,7 +11329,7 @@ if __name__ == "__main__":
   GenerateSM89(manifest, args.cuda_version)
   GenerateSM90(manifest, args.cuda_version)

-  blackwell_enabled_arch = any(arch in ["100a", "101a", "120a"] for arch in archs)
+  blackwell_enabled_arch = any(arch in ["100a", "100f", "101a", "101f", "120a", "120f"] for arch in archs)
   if blackwell_enabled_arch:
     GenerateSM100(manifest, args.cuda_version)
     GenerateSM120(manifest, args.cuda_version)

@@ -523,10 +523,14 @@ class Manifest:
     arch_conditional_cc = [
       '90a',
       '100a',
+      '100f',
       '101a',
-      '120a'
+      '101f',
+      '120a',
+      '120f'
     ]
     architectures = [x if x not in arch_conditional_cc else x.split('a')[0] for x in architectures]
+    architectures = [x if x not in arch_conditional_cc else x.split('f')[0] for x in architectures]

     self.compute_capabilities = [int(x) for x in architectures]

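The Manifest change above extends the list of conditional architecture names and strips the new 'f' (family-conditional) suffix the same way the existing 'a' suffix is stripped, so both forms reduce to a plain integer compute capability. A self-contained sketch of exactly that reduction, with illustrative inputs:

# Sketch of the suffix handling in the Manifest hunk above.
arch_conditional_cc = ['90a', '100a', '100f', '101a', '101f', '120a', '120f']

architectures = ['90a', '100f', '120']
architectures = [x if x not in arch_conditional_cc else x.split('a')[0] for x in architectures]
architectures = [x if x not in arch_conditional_cc else x.split('f')[0] for x in architectures]
print([int(x) for x in architectures])   # [90, 100, 120]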
@@ -375,6 +375,13 @@ def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level:
   mma_multipliers, cluster_sizes = get_mma_multipliers(level), get_cluster_sizes(level, is_aligned)
   for math_inst, mma_mul, cluster_size in product(math_instructions, mma_multipliers, cluster_sizes):

+    # generator can stamp out duplicate kernels, because it doesn't explicitly set instruction
+    # shape for SM90 kernels, and the 3.X collective API doesn't directly expose them when using
+    # the auto kernel schedule.
+
+    math_inst_stub = copy.deepcopy(math_inst)
+    math_inst_stub.instruction_shape = [0, 0, 0]
+
     tile_desc = TileDescription(
       threadblock_shape=[
         math_inst.instruction_shape[0] * mma_mul[0],
@@ -383,7 +390,7 @@ def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level:
       ],
       stages=0,
       warp_count=[4, 1, 1],
-      math_instruction=math_inst,
+      math_instruction=math_inst_stub,
       min_compute=90,
       max_compute=90,
       cluster_shape=cluster_size)
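The comment and math_inst_stub lines added above address duplicate kernels: SM90 kernels built through the 3.x collective API with the auto schedule do not encode an explicit instruction shape, so math instructions that differ only in that shape would otherwise yield identically configured kernels. Using a deep copy with a zeroed shape in the TileDescription presumably lets those repeats collapse to a single entry. A hedged sketch with a simplified stand-in for MathInstruction:

import copy
from dataclasses import dataclass

@dataclass
class MathInstruction:          # simplified stand-in, not the library class
    instruction_shape: list
    element_a: str = 'f16'
    element_b: str = 'f16'

def dedup_key(math_inst):
    # Build the key from a copy whose instruction shape is zeroed,
    # mirroring math_inst_stub in the hunk above.
    stub = copy.deepcopy(math_inst)
    stub.instruction_shape = [0, 0, 0]
    return (tuple(stub.instruction_shape), stub.element_a, stub.element_b)

print(dedup_key(MathInstruction([64, 128, 16])) == dedup_key(MathInstruction([64, 256, 16])))  # True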
@@ -551,6 +558,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
   b_type_size = DataTypeSize[data_types["b_type"]]
   if a_type_size != b_type_size and CudaToolkitVersionSatisfies(cuda_version, 12, 1):
     schedules = []
+    stream_k_schedules = []
     epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
     if a_type_size > b_type_size:
       epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
@@ -579,7 +587,11 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types,
         KernelScheduleType.TmaWarpSpecializedCooperative,
         epilogue_schedule
       ])
-      return schedules, []
+      stream_k_schedules.append([
+        KernelScheduleType.TmaWarpSpecializedCooperative,
+        epilogue_schedule
+      ])
+      return schedules, stream_k_schedules

   if not is_aligned and not is_blockwise(gemm_kind):
     schedules = [[KernelScheduleType.CpAsyncWarpSpecialized,

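Before this change, the mixed-input branch of get_valid_schedules (where a_type and b_type have different widths) returned an empty stream-K list; the hunks above initialize stream_k_schedules and return it alongside the regular cooperative schedule, so stream-K variants are emitted for mixed-input GEMMs as well. A minimal sketch of the changed return shape, with stand-ins for the schedule enums:

# Sketch of the control-flow change in get_valid_schedules above.
KERNEL_COOP = 'TmaWarpSpecializedCooperative'   # stand-in enum values
EPILOGUE_TMA = 'TmaWarpSpecialized'

def mixed_input_schedules():
    schedules = []
    stream_k_schedules = []                      # newly initialized in this change
    schedules.append([KERNEL_COOP, EPILOGUE_TMA])
    stream_k_schedules.append([KERNEL_COOP, EPILOGUE_TMA])
    return schedules, stream_k_schedules         # previously: return schedules, []

print(mixed_input_schedules())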
@@ -36,7 +36,7 @@ from setuptools import setup
 def perform_setup():
   setup(
     name='cutlass_library',
-    version='3.9.2',
+    version='4.0.0',
     description='CUTLASS library generation scripts',
     packages=['cutlass_library']
   )

@@ -36,7 +36,7 @@ from setuptools import setup
 def perform_setup():
   setup(
     name='pycute',
-    version='3.9.2',
+    version='4.0.0',
     description='Python implementation of CuTe',
     packages=['pycute'],
   )