mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-04 05:31:17 +00:00
Release v4.0.0 (#2294)
This commit is contained in:
64
python/CuTeDSL/base_dsl/_mlir_helpers/gpu.py
Normal file
64
python/CuTeDSL/base_dsl/_mlir_helpers/gpu.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
"""
|
||||
This module provides MLIR GPU Dialect helper functions
|
||||
"""
|
||||
|
||||
|
||||
from ..._mlir import ir
|
||||
from ..._mlir.dialects import gpu, arith, scf
|
||||
from ..._mlir.extras import types as T
|
||||
|
||||
from ..common import *
|
||||
|
||||
# =============================================================================
|
||||
# GPU Dialect Helper functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_async_token():
|
||||
token_ty = gpu.AsyncTokenType.get()
|
||||
token = gpu.wait(token_ty, [])
|
||||
return token
|
||||
|
||||
|
||||
def printf(fmt, *args, threadNumber=-1):
|
||||
"""Generate gpu.printf OP predicated on threadNumber"""
|
||||
type_formats = []
|
||||
for arg in args:
|
||||
ty_format = None
|
||||
if ir.IndexType.isinstance(arg.type):
|
||||
ty_format = "%llu"
|
||||
if ir.IntegerType.isinstance(arg.type):
|
||||
width = ir.IntegerType(arg.type).width
|
||||
if width == 64:
|
||||
ty_format = "%llu"
|
||||
elif width == 32:
|
||||
ty_format = "%d"
|
||||
elif width == 1:
|
||||
ty_format = "%i"
|
||||
if ir.F32Type.isinstance(arg.type):
|
||||
ty_format = "%f"
|
||||
if ty_format is None:
|
||||
raise DSLNotImplemented(arg.type)
|
||||
type_formats.append(ty_format)
|
||||
if threadNumber == -1:
|
||||
gpu.printf(fmt.format(*type_formats) + "\n", args)
|
||||
if threadNumber != -1:
|
||||
tidx = gpu.thread_id(gpu.Dimension.x)
|
||||
predicate = arith.cmpi(
|
||||
arith.CmpIPredicate.eq, tidx, arith.constant(_T.index(), threadNumber)
|
||||
)
|
||||
if_op = scf.IfOp(predicate)
|
||||
with ir.InsertionPoint(if_op.then_block):
|
||||
gpu.printf(fmt.format(*type_formats) + "\n", args)
|
||||
scf.yield_([])
|
||||
Reference in New Issue
Block a user