Fix for blockwise gg nosmem epi && no sfd with nosmem GG epilogues

This commit is contained in:
Haicheng Wu
2026-06-20 14:27:27 -07:00
parent f93e1cea92
commit bed71224ed
2 changed files with 39 additions and 12 deletions

View File

@@ -399,9 +399,33 @@ public:
ptr_C_l = params.ptr_C[l_coord];
}
auto [stride_c, stride_d] = [&, l = l_coord]() {
if constexpr (!cute::is_same_v<InternalStrideC, StrideC>) {
// If grouped gemm
if (epilogue_op.is_source_needed()) {
return make_tuple(
detail::get_epilogue_stride<DispatchPolicy>(params.dC[l]),
detail::get_epilogue_stride<DispatchPolicy>(params.dD[l])
);
}
else {
return make_tuple(
InternalStrideC{},
detail::get_epilogue_stride<DispatchPolicy>(params.dD[l])
);
}
}
else {
return make_tuple(
detail::get_epilogue_stride<DispatchPolicy>(params.dC),
detail::get_epilogue_stride<DispatchPolicy>(params.dD)
);
}
}();
// Represent the full output tensor, slice to get the tile this CTA is responsible for
Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L)
Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L)
Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); // (M,N,L)
Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, stride_d); // (M,N,L)
Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
@@ -572,12 +596,7 @@ public:
can_implement(
[[maybe_unused]] ProblemShape const& problem_shape,
[[maybe_unused]] Arguments const& args) {
bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread);
if (!fusion_implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
}
return fusion_implementable;
return true;
}

View File

@@ -8046,8 +8046,12 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped)
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped)
nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]]
fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]]
nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule]]
fp4_schedules = [[fp4_kernel_schedule, epi_schedule]]
if (data_type["sfd_type"]["type"] == DataType.void):
nvfp4_schedules.append([nvfp4_kernel_schedule, epi_nosmem_schedule])
fp4_schedules.append([fp4_kernel_schedule, epi_nosmem_schedule])
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind
)
@@ -8170,8 +8174,12 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped)
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped)
nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]]
fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]]
nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule]]
fp4_schedules = [[fp4_kernel_schedule, epi_schedule]]
if (data_type["sfd_type"]["type"] == DataType.void):
nvfp4_schedules.append([nvfp4_kernel_schedule, epi_nosmem_schedule])
fp4_schedules.append([fp4_kernel_schedule, epi_nosmem_schedule])
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
if isFp4: