From 3f5bafb326e834b69b16fa6d10bc85b644dc1805 Mon Sep 17 00:00:00 2001 From: Aidan Do Date: Mon, 19 Jan 2026 23:27:34 -0800 Subject: [PATCH] =?UTF-8?q?[Cutlass=20profiler]=20Fix=20SM100=20FP8=20nosm?= =?UTF-8?q?em=20epilogue=20shape=5Fdiv=20=E2=80=9CDivisibility=20Condition?= =?UTF-8?q?=E2=80=9D=20for=20non=E2=80=91multiple=E2=80=91of=E2=80=9164=20?= =?UTF-8?q?N=20tiles=20(#2946)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * . * . * . * . * . * . * . --- python/cutlass_library/generator.py | 30 ++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 3e13a430b..350ebef53 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -7488,6 +7488,17 @@ def GenerateSM100_TensorOp_fp8_UMMA_alignx_gemm(manifest, cuda_version, gemm_kin TileSchedulerType.Default ] + # Some SM100 NoSmem epilogue instantiations rely on CUTE's shape_div, which enforces a compile-time + # divisibility condition between CTA N and the epilogue N tile. Keep this conservative and scoped: + # only apply the divisibility filter for selected common (c_type, d_type) pairs. + # + # Map (c_type, d_type) -> required divisor for CTA N when CTA N > divisor. + # (If CTA N <= divisor, the epilogue N tile equals CTA N and is always divisible.) + _sm100_epilogue_tile_n_divisibility = { + (DataType.void, DataType.f16): 64, + (DataType.void, DataType.bf16): 64, + } + # 1xSM MMA kernels for math_inst in math_instructions_1sm: tile_descriptions = [] @@ -7607,7 +7618,24 @@ def GenerateSM100_TensorOp_fp8_UMMA_alignx_gemm(manifest, cuda_version, gemm_kin kernel_schedule = KernelScheduleType.WarpSpecialized1SmSm100 epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized1Sm - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + + # SM100 NoSmem epilogue uses EpilogueTileAuto with N-tile = min(64, cta_n). + # CUTE's shape_div then requires a compile-time divisibility condition between cta_n and 64. + # Only instantiate kernels where cta_n <= 64 or cta_n is an exact multiple of 64 to avoid + # violating that "Divisibility Condition" static_assert. + filtered_tile_descriptions = [] + for tile_description in tile_descriptions: + div_n = _sm100_epilogue_tile_n_divisibility.get((data_type["c_type"], data_type["d_type"])) + if div_n is not None: + cta_n = tile_description.threadblock_shape[1] + if cta_n > div_n and (cta_n % div_n != 0): + continue + filtered_tile_descriptions.append(tile_description) + + if not filtered_tile_descriptions: + continue + + CreateGemmUniversal3xOperator(manifest, layouts, filtered_tile_descriptions, data_type, [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)