mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-06-28 18:37:05 +00:00
Fix for blockwise gg nosmem epi && no sfd with nosmem GG epilogues
This commit is contained in:
@@ -399,9 +399,33 @@ public:
|
||||
ptr_C_l = params.ptr_C[l_coord];
|
||||
}
|
||||
|
||||
auto [stride_c, stride_d] = [&, l = l_coord]() {
|
||||
if constexpr (!cute::is_same_v<InternalStrideC, StrideC>) {
|
||||
// If grouped gemm
|
||||
if (epilogue_op.is_source_needed()) {
|
||||
return make_tuple(
|
||||
detail::get_epilogue_stride<DispatchPolicy>(params.dC[l]),
|
||||
detail::get_epilogue_stride<DispatchPolicy>(params.dD[l])
|
||||
);
|
||||
}
|
||||
else {
|
||||
return make_tuple(
|
||||
InternalStrideC{},
|
||||
detail::get_epilogue_stride<DispatchPolicy>(params.dD[l])
|
||||
);
|
||||
}
|
||||
}
|
||||
else {
|
||||
return make_tuple(
|
||||
detail::get_epilogue_stride<DispatchPolicy>(params.dC),
|
||||
detail::get_epilogue_stride<DispatchPolicy>(params.dD)
|
||||
);
|
||||
}
|
||||
}();
|
||||
|
||||
// Represent the full output tensor, slice to get the tile this CTA is responsible for
|
||||
Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L)
|
||||
Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); // (M,N,L)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, stride_d); // (M,N,L)
|
||||
Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
|
||||
Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N)
|
||||
|
||||
@@ -572,12 +596,7 @@ public:
|
||||
can_implement(
|
||||
[[maybe_unused]] ProblemShape const& problem_shape,
|
||||
[[maybe_unused]] Arguments const& args) {
|
||||
|
||||
bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread);
|
||||
if (!fusion_implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
|
||||
}
|
||||
return fusion_implementable;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -8046,8 +8046,12 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped)
|
||||
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped)
|
||||
|
||||
nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]]
|
||||
fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]]
|
||||
nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule]]
|
||||
fp4_schedules = [[fp4_kernel_schedule, epi_schedule]]
|
||||
if (data_type["sfd_type"]["type"] == DataType.void):
|
||||
nvfp4_schedules.append([nvfp4_kernel_schedule, epi_nosmem_schedule])
|
||||
fp4_schedules.append([fp4_kernel_schedule, epi_nosmem_schedule])
|
||||
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind
|
||||
)
|
||||
@@ -8170,8 +8174,12 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped)
|
||||
fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped)
|
||||
|
||||
nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]]
|
||||
fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]]
|
||||
nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule]]
|
||||
fp4_schedules = [[fp4_kernel_schedule, epi_schedule]]
|
||||
if (data_type["sfd_type"]["type"] == DataType.void):
|
||||
nvfp4_schedules.append([nvfp4_kernel_schedule, epi_nosmem_schedule])
|
||||
fp4_schedules.append([fp4_kernel_schedule, epi_nosmem_schedule])
|
||||
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules
|
||||
, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
if isFp4:
|
||||
|
||||
Reference in New Issue
Block a user