Files
cutlass/test/python/CuTeDSL/test_struct_in_if.py
George Karpenkov 1732ed7da3 [CuTeDSL] Make @cute.struct instances flattenable across scf.if / scf.while (#3270)
Fixes https://github.com/NVIDIA/cutlass/issues/3268

A `@cute.struct` instance captured into an `scf.if` branch or `scf.while`
body fails the DSL trace with:

    DSLRuntimeError: The 'if' statement encountered a user-defined Python
    object, which cannot be automatically converted into an dynamic
    expression.

This blocks the natural warp-specialization pattern, where each
`if warp_idx == <role>:` branch reads its tile from a shared storage
struct.

A struct instance is fully described by its `base` pointer (already
DynamicExpression-aware via `_Pointer`); every field instance is
re-derived from `base + static offsets` on construction. Implement the
DynamicExpression protocol on each decorated class by forwarding
`__get_mlir_types__` / `__extract_mlir_values__` to `base`, and
`__new_from_mlir_values__` to a fresh decorator invocation that
re-derives the fields from a rebuilt base pointer.

Tested in Docker on cutlass-dsl 4.5.1 with six new unit tests in
test/python/CuTeDSL/test_struct_in_if.py covering:
  * the original failing case (storage.get_tensor inside dynamic if),
  * regression: plain non-branched struct usage still works,
  * nested struct (struct-of-struct) inside a dynamic if,
  * if/else with both branches accessing the struct,
  * if/elif/elif/else (the actual warp-specialization shape),
  * scf.while body capturing the struct.
2026-05-29 08:34:48 +08:00

193 lines
6.8 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
Unit tests for using @cute.struct instances across DSL control flow.
A struct instance is implicitly captured into the value set of an
`scf.if` whenever its fields are accessed inside the branch (the natural
warp-specialization pattern). For this to work, the struct must
implement the DynamicExpression protocol so the DSL can flatten/unflatten
it across the branch boundary.
"""
import unittest
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass import Float32, Int32
from cutlass.cute.runtime import make_fake_tensor
@cute.struct
class _OneTile:
sA: cute.struct.MemRange[Float32, 128]
@cute.struct
class _Inner:
sX: cute.struct.MemRange[Float32, 64]
@cute.struct
class _Outer:
inner: _Inner
sY: cute.struct.MemRange[Float32, 64]
class TestStructInIf(unittest.TestCase):
def test_get_tensor_inside_dynamic_if(self):
"""The original failing case: storage.<field>.get_tensor() inside
a dynamic if. Used to raise DSLRuntimeError because the struct
instance was not flattenable across the scf.if boundary."""
@cute.kernel
def k(gA: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
smem_alloc = utils.SmemAllocator()
storage = smem_alloc.allocate(_OneTile)
sA_layout = cute.make_layout((128,), stride=(1,))
if tidx == Int32(0):
sA = storage.sA.get_tensor(sA_layout)
sA[0] = gA[0]
@cute.jit
def entry(gA: cute.Tensor):
k(gA).launch(grid=(1, 1, 1), block=(32, 1, 1), smem=512)
gA = make_fake_tensor(Float32, (128,), stride=(1,), assumed_align=4)
cute.compile(entry, gA)
def test_struct_use_outside_if_still_works(self):
"""Regression: plain struct usage (no dynamic if) must keep
working — i.e. the DynamicExpression-protocol injection must not
break the existing non-branched path."""
@cute.kernel
def k(gA: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
smem_alloc = utils.SmemAllocator()
storage = smem_alloc.allocate(_OneTile)
sA_layout = cute.make_layout((128,), stride=(1,))
sA = storage.sA.get_tensor(sA_layout)
sA[tidx] = gA[tidx]
@cute.jit
def entry(gA: cute.Tensor):
k(gA).launch(grid=(1, 1, 1), block=(32, 1, 1), smem=512)
gA = make_fake_tensor(Float32, (128,), stride=(1,), assumed_align=4)
cute.compile(entry, gA)
def test_nested_struct_inside_dynamic_if(self):
"""Nested struct (struct-of-struct) is also captured into the
branch — verify the flatten round-trip still rebuilds the inner
field correctly."""
@cute.kernel
def k(gA: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
smem_alloc = utils.SmemAllocator()
storage = smem_alloc.allocate(_Outer)
inner_layout = cute.make_layout((64,), stride=(1,))
if tidx == Int32(0):
sX = storage.inner.sX.get_tensor(inner_layout)
sX[0] = gA[0]
@cute.jit
def entry(gA: cute.Tensor):
k(gA).launch(grid=(1, 1, 1), block=(32, 1, 1), smem=1024)
gA = make_fake_tensor(Float32, (128,), stride=(1,), assumed_align=4)
cute.compile(entry, gA)
def test_if_else_branches(self):
"""Both branches of an if/else access struct fields. Covers the
case where the struct is live on both edges of the scf.if."""
@cute.kernel
def k(gA: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
smem_alloc = utils.SmemAllocator()
storage = smem_alloc.allocate(_OneTile)
sA_layout = cute.make_layout((128,), stride=(1,))
if tidx == Int32(0):
sA = storage.sA.get_tensor(sA_layout)
sA[0] = gA[0]
else:
sA = storage.sA.get_tensor(sA_layout)
sA[1] = gA[1]
@cute.jit
def entry(gA: cute.Tensor):
k(gA).launch(grid=(1, 1, 1), block=(32, 1, 1), smem=512)
gA = make_fake_tensor(Float32, (128,), stride=(1,), assumed_align=4)
cute.compile(entry, gA)
def test_if_elif_else_warp_spec_pattern(self):
"""The actual warp-specialization shape: an if/elif/elif chain
where each role accesses a different field of the storage."""
@cute.kernel
def k(gA: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
warp = tidx // Int32(32)
smem_alloc = utils.SmemAllocator()
storage = smem_alloc.allocate(_Outer)
inner_layout = cute.make_layout((64,), stride=(1,))
outer_layout = cute.make_layout((64,), stride=(1,))
if warp == Int32(0):
sX = storage.inner.sX.get_tensor(inner_layout)
sX[0] = gA[0]
elif warp == Int32(1):
sY = storage.sY.get_tensor(outer_layout)
sY[0] = gA[1]
else:
sY = storage.sY.get_tensor(outer_layout)
sY[1] = gA[2]
@cute.jit
def entry(gA: cute.Tensor):
k(gA).launch(grid=(1, 1, 1), block=(128, 1, 1), smem=1024)
gA = make_fake_tensor(Float32, (128,), stride=(1,), assumed_align=4)
cute.compile(entry, gA)
def test_struct_captured_in_while_body(self):
"""scf.while uses the same unpack_to_irvalue machinery as scf.if
for carrying captured Python objects through the loop. Cover
struct access inside a dynamic while body."""
@cute.kernel
def k(gA: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
smem_alloc = utils.SmemAllocator()
storage = smem_alloc.allocate(_OneTile)
sA_layout = cute.make_layout((128,), stride=(1,))
i = Int32(0)
while i < Int32(4):
sA = storage.sA.get_tensor(sA_layout)
sA[i] = gA[i]
i = i + Int32(1)
@cute.jit
def entry(gA: cute.Tensor):
k(gA).launch(grid=(1, 1, 1), block=(1, 1, 1), smem=512)
gA = make_fake_tensor(Float32, (128,), stride=(1,), assumed_align=4)
cute.compile(entry, gA)
if __name__ == "__main__":
unittest.main()