mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-21 21:38:18 +00:00
[Cutlass profiler] Fix SM100 FP8 nosmem epilogue shape_div “Divisibility Condition” for non‑multiple‑of‑64 N tiles (#2946)
* . * . * . * . * . * . * .
This commit is contained in:
@@ -7488,6 +7488,17 @@ def GenerateSM100_TensorOp_fp8_UMMA_alignx_gemm(manifest, cuda_version, gemm_kin
|
||||
TileSchedulerType.Default
|
||||
]
|
||||
|
||||
# Some SM100 NoSmem epilogue instantiations rely on CUTE's shape_div, which enforces a compile-time
|
||||
# divisibility condition between CTA N and the epilogue N tile. Keep this conservative and scoped:
|
||||
# only apply the divisibility filter for selected common (c_type, d_type) pairs.
|
||||
#
|
||||
# Map (c_type, d_type) -> required divisor for CTA N when CTA N > divisor.
|
||||
# (If CTA N <= divisor, the epilogue N tile equals CTA N and is always divisible.)
|
||||
_sm100_epilogue_tile_n_divisibility = {
|
||||
(DataType.void, DataType.f16): 64,
|
||||
(DataType.void, DataType.bf16): 64,
|
||||
}
|
||||
|
||||
# 1xSM MMA kernels
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
@@ -7607,7 +7618,24 @@ def GenerateSM100_TensorOp_fp8_UMMA_alignx_gemm(manifest, cuda_version, gemm_kin
|
||||
|
||||
kernel_schedule = KernelScheduleType.WarpSpecialized1SmSm100
|
||||
epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized1Sm
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
|
||||
# SM100 NoSmem epilogue uses EpilogueTileAuto with N-tile = min(64, cta_n).
|
||||
# CUTE's shape_div then requires a compile-time divisibility condition between cta_n and 64.
|
||||
# Only instantiate kernels where cta_n <= 64 or cta_n is an exact multiple of 64 to avoid
|
||||
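# violating that "Divisibility Condition" static_assert. This filter is applied only to the (c_type, d_type) pairs listed in _sm100_epilogue_tile_n_divisibility; tiles for any other pair are kept unfiltered.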
|
||||
filtered_tile_descriptions = []
|
||||
for tile_description in tile_descriptions:
|
||||
div_n = _sm100_epilogue_tile_n_divisibility.get((data_type["c_type"], data_type["d_type"]))
|
||||
if div_n is not None:
|
||||
cta_n = tile_description.threadblock_shape[1]
|
||||
if cta_n > div_n and (cta_n % div_n != 0):
|
||||
continue
|
||||
filtered_tile_descriptions.append(tile_description)
|
||||
|
||||
if not filtered_tile_descriptions:
|
||||
continue
|
||||
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, filtered_tile_descriptions, data_type,
|
||||
[[kernel_schedule, epi_schedule]],
|
||||
tile_schedulers=tile_schedulers, gemm_kind=gemm_kind)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user