Fix for blockwise gg nosmem epi && no sfd with nosmem GG epilogues

This commit is contained in:
Haicheng Wu
2026-06-20 14:27:27 -07:00
parent f93e1cea92
commit bed71224ed
2 changed files with 39 additions and 12 deletions

View File

@@ -8046,8 +8046,12 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped)
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped)
nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]]
fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]]
nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule]]
fp4_schedules = [[fp4_kernel_schedule, epi_schedule]]
if (data_type["sfd_type"]["type"] == DataType.void):
nvfp4_schedules.append([nvfp4_kernel_schedule, epi_nosmem_schedule])
fp4_schedules.append([fp4_kernel_schedule, epi_nosmem_schedule])
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind
)
@@ -8170,8 +8174,12 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped)
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped)
nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]]
fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]]
nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule]]
fp4_schedules = [[fp4_kernel_schedule, epi_schedule]]
if (data_type["sfd_type"]["type"] == DataType.void):
nvfp4_schedules.append([nvfp4_kernel_schedule, epi_nosmem_schedule])
fp4_schedules.append([fp4_kernel_schedule, epi_nosmem_schedule])
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
if isFp4: