From 3f3db08a0a9beb5cda086b355158351afb68405e Mon Sep 17 00:00:00 2001 From: Nandor Licker Date: Fri, 17 Apr 2026 03:47:47 +0300 Subject: [PATCH] Add support for empty dataclass arguments (#3152) A dataclass with no fields exposed a bug in `extract_dataclass_members`: ``` @dataclass class Dummy: pass ``` The type/return path was inconsistent. This PR fixes the function to support empty dataclasses, which are useful in unions. --- .../cutlass/base_dsl/utils/tree_utils.py | 6 +- test/examples/CuTeDSL/test_dataclasses.py | 72 +++++++++++++++++++ 2 files changed, 75 insertions(+), 3 deletions(-) create mode 100644 test/examples/CuTeDSL/test_dataclasses.py diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/tree_utils.py b/python/CuTeDSL/cutlass/base_dsl/utils/tree_utils.py index 6e1e59f08..86fb2cd46 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/tree_utils.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/tree_utils.py @@ -192,7 +192,7 @@ class Leaf: # ============================================================================= -def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any]]: +def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any], list[Any]]: """ Extract non-method, non-function attributes from a dataclass instance. @@ -200,7 +200,7 @@ def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any]]: x: A dataclass instance Returns: - tuple: (field_names, field_values) lists + tuple: (field_names, field_values, constexpr_fields) lists """ fields = [field.name for field in dataclasses.fields(x)] @@ -213,7 +213,7 @@ def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any]]: ) if not fields: - return [], [] + return [], [], [] # record constexpr fields members = [] diff --git a/test/examples/CuTeDSL/test_dataclasses.py b/test/examples/CuTeDSL/test_dataclasses.py new file mode 100644 index 000000000..8e33fbc82 --- /dev/null +++ b/test/examples/CuTeDSL/test_dataclasses.py @@ -0,0 +1,72 @@ +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from dataclasses import dataclass + +import pytest +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack + + +@dataclass +class A: + pass + + +@dataclass +class B: + pass + + +@cute.kernel +def _test_empty_dataclass_kernel(out: cute.Tensor, tag: A | B): + tidx, _, _ = cute.arch.thread_idx() + if tidx == 0: + match tag: + case A(): + out[0] = 0 + case B(): + out[0] = 1 + + +@cute.jit +def _test_empty_dataclass_host(out: cute.Tensor, tag: A | B): + _test_empty_dataclass_kernel(out, tag).launch(grid=[1, 1, 1], block=[1, 1, 1]) + + +@pytest.mark.parametrize("tag,expected", [(A(), 0), (B(), 1)]) +def test_empty_dataclass_union(tag, expected): + out = torch.zeros(1, device="cuda", dtype=torch.int32) + out_cute = from_dlpack(out).mark_layout_dynamic() + compiled_fn = cute.compile(_test_empty_dataclass_host, out_cute, tag) + compiled_fn(out_cute, tag) + torch.cuda.synchronize() + assert out.item() == expected