mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-06-29 19:07:07 +00:00
Fixes https://github.com/NVIDIA/cutlass/issues/3268

A `@cute.struct` instance captured into an `scf.if` branch or `scf.while` body fails the DSL trace with:

    DSLRuntimeError: The 'if' statement encountered a user-defined Python object, which cannot be automatically converted into an dynamic expression.

This blocks the natural warp-specialization pattern, where each `if warp_idx == <role>:` branch reads its tile from a shared storage struct.

A struct instance is fully described by its `base` pointer (already DynamicExpression-aware via `_Pointer`); every field instance is re-derived from `base + static offsets` on construction. Implement the DynamicExpression protocol on each decorated class by forwarding `__get_mlir_types__` / `__extract_mlir_values__` to `base`, and `__new_from_mlir_values__` to a fresh decorator invocation that re-derives the fields from a rebuilt base pointer.

Tested in Docker on cutlass-dsl 4.5.1 with six new unit tests in test/python/CuTeDSL/test_struct_in_if.py covering:

* the original failing case (storage.get_tensor inside dynamic if),
* regression: plain non-branched struct usage still works,
* nested struct (struct-of-struct) inside a dynamic if,
* if/else with both branches accessing the struct,
* if/elif/elif/else (the actual warp-specialization shape),
* scf.while body capturing the struct.
193 lines
6.8 KiB
Python
193 lines
6.8 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
|
#
|
|
# Use of this software is governed by the terms and conditions of the
|
|
# NVIDIA End User License Agreement (EULA), available at:
|
|
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
|
|
#
|
|
# Any use, reproduction, disclosure, or distribution of this software
|
|
# and related documentation outside the scope permitted by the EULA
|
|
# is strictly prohibited.
|
|
|
|
"""
Unit tests for using @cute.struct instances across DSL control flow.

A struct instance is implicitly captured into the value set of an
`scf.if` whenever its fields are accessed inside the branch (the natural
warp-specialization pattern). For this to work, the struct must
implement the DynamicExpression protocol so the DSL can flatten/unflatten
it across the branch boundary. The same holds for a struct accessed
inside the body of a dynamic `while` loop.
"""
|
|
|
|
import unittest
|
|
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
import cutlass.utils as utils
|
|
from cutlass import Float32, Int32
|
|
from cutlass.cute.runtime import make_fake_tensor
|
|
|
|
|
|
# Storage struct with a single field: a MemRange of Float32 with size 128.
# The tests allocate it through utils.SmemAllocator and read `sA` from it.
@cute.struct
class _OneTile:
    sA: cute.struct.MemRange[Float32, 128]
|
|
|
|
|
|
# Struct with one MemRange field (Float32, size 64); it is embedded as the
# `inner` field of `_Outer` to form a struct-of-struct.
@cute.struct
class _Inner:
    sX: cute.struct.MemRange[Float32, 64]
|
|
|
|
|
|
# Struct-of-struct: a nested `_Inner` field followed by its own MemRange
# field (Float32, size 64).
@cute.struct
class _Outer:
    inner: _Inner
    sY: cute.struct.MemRange[Float32, 64]
|
|
|
|
|
|
class TestStructInIf(unittest.TestCase):
    """Compile-only checks for ``@cute.struct`` instances in dynamic control flow.

    Every test defines a kernel, wraps it in a ``@cute.jit`` launcher and
    hands the launcher to ``cute.compile`` together with a fake tensor. A test
    passes when that call completes without raising.
    """

    def test_get_tensor_inside_dynamic_if(self):
        """Original failing case: ``storage.<field>.get_tensor()`` in a dynamic if.

        This used to raise DSLRuntimeError because the struct instance could
        not be flattened across the scf.if boundary.
        """
        fake_input = make_fake_tensor(Float32, (128,), stride=(1,), assumed_align=4)

        @cute.kernel
        def k(gA: cute.Tensor):
            tidx, _, _ = cute.arch.thread_idx()
            allocator = utils.SmemAllocator()
            shared = allocator.allocate(_OneTile)
            tile_layout = cute.make_layout((128,), stride=(1,))
            # The struct is first touched inside the dynamic branch.
            if tidx == Int32(0):
                tile = shared.sA.get_tensor(tile_layout)
                tile[0] = gA[0]

        @cute.jit
        def entry(gA: cute.Tensor):
            k(gA).launch(grid=(1, 1, 1), block=(32, 1, 1), smem=512)

        cute.compile(entry, fake_input)

    def test_struct_use_outside_if_still_works(self):
        """Regression: struct usage without any dynamic if keeps working.

        Guards against the DynamicExpression-protocol injection breaking the
        existing non-branched path.
        """
        fake_input = make_fake_tensor(Float32, (128,), stride=(1,), assumed_align=4)

        @cute.kernel
        def k(gA: cute.Tensor):
            tidx, _, _ = cute.arch.thread_idx()
            allocator = utils.SmemAllocator()
            shared = allocator.allocate(_OneTile)
            tile_layout = cute.make_layout((128,), stride=(1,))
            tile = shared.sA.get_tensor(tile_layout)
            tile[tidx] = gA[tidx]

        @cute.jit
        def entry(gA: cute.Tensor):
            k(gA).launch(grid=(1, 1, 1), block=(32, 1, 1), smem=512)

        cute.compile(entry, fake_input)

    def test_nested_struct_inside_dynamic_if(self):
        """A struct-of-struct captured into a dynamic branch.

        Verifies that the flatten round-trip still rebuilds the inner field
        correctly.
        """
        fake_input = make_fake_tensor(Float32, (128,), stride=(1,), assumed_align=4)

        @cute.kernel
        def k(gA: cute.Tensor):
            tidx, _, _ = cute.arch.thread_idx()
            allocator = utils.SmemAllocator()
            shared = allocator.allocate(_Outer)
            nested_layout = cute.make_layout((64,), stride=(1,))
            if tidx == Int32(0):
                # Reach through the nested struct from inside the branch.
                x_tile = shared.inner.sX.get_tensor(nested_layout)
                x_tile[0] = gA[0]

        @cute.jit
        def entry(gA: cute.Tensor):
            k(gA).launch(grid=(1, 1, 1), block=(32, 1, 1), smem=1024)

        cute.compile(entry, fake_input)

    def test_if_else_branches(self):
        """Both arms of an if/else access struct fields.

        Covers the case where the struct is live on both edges of the scf.if.
        """
        fake_input = make_fake_tensor(Float32, (128,), stride=(1,), assumed_align=4)

        @cute.kernel
        def k(gA: cute.Tensor):
            tidx, _, _ = cute.arch.thread_idx()
            allocator = utils.SmemAllocator()
            shared = allocator.allocate(_OneTile)
            tile_layout = cute.make_layout((128,), stride=(1,))
            # The same local name is bound in both arms, as in the original.
            if tidx == Int32(0):
                tile = shared.sA.get_tensor(tile_layout)
                tile[0] = gA[0]
            else:
                tile = shared.sA.get_tensor(tile_layout)
                tile[1] = gA[1]

        @cute.jit
        def entry(gA: cute.Tensor):
            k(gA).launch(grid=(1, 1, 1), block=(32, 1, 1), smem=512)

        cute.compile(entry, fake_input)

    def test_if_elif_else_warp_spec_pattern(self):
        """Warp-specialization shape: an if/elif/else chain over the warp index.

        The first role reads the nested ``inner.sX`` field; the remaining two
        roles read ``sY``.
        """
        fake_input = make_fake_tensor(Float32, (128,), stride=(1,), assumed_align=4)

        @cute.kernel
        def k(gA: cute.Tensor):
            tidx, _, _ = cute.arch.thread_idx()
            warp = tidx // Int32(32)
            allocator = utils.SmemAllocator()
            shared = allocator.allocate(_Outer)
            nested_layout = cute.make_layout((64,), stride=(1,))
            flat_layout = cute.make_layout((64,), stride=(1,))
            if warp == Int32(0):
                x_tile = shared.inner.sX.get_tensor(nested_layout)
                x_tile[0] = gA[0]
            elif warp == Int32(1):
                y_tile = shared.sY.get_tensor(flat_layout)
                y_tile[0] = gA[1]
            else:
                y_tile = shared.sY.get_tensor(flat_layout)
                y_tile[1] = gA[2]

        @cute.jit
        def entry(gA: cute.Tensor):
            k(gA).launch(grid=(1, 1, 1), block=(128, 1, 1), smem=1024)

        cute.compile(entry, fake_input)

    def test_struct_captured_in_while_body(self):
        """Struct access inside the body of a dynamic while loop.

        scf.while uses the same unpack_to_irvalue machinery as scf.if for
        carrying captured Python objects through the loop.
        """
        fake_input = make_fake_tensor(Float32, (128,), stride=(1,), assumed_align=4)

        @cute.kernel
        def k(gA: cute.Tensor):
            # The thread index is queried but not used by this kernel.
            tidx, _, _ = cute.arch.thread_idx()
            allocator = utils.SmemAllocator()
            shared = allocator.allocate(_OneTile)
            tile_layout = cute.make_layout((128,), stride=(1,))
            idx = Int32(0)
            while idx < Int32(4):
                tile = shared.sA.get_tensor(tile_layout)
                tile[idx] = gA[idx]
                idx = idx + Int32(1)

        @cute.jit
        def entry(gA: cute.Tensor):
            k(gA).launch(grid=(1, 1, 1), block=(1, 1, 1), smem=512)

        cute.compile(entry, fake_input)
|
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|