diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp index d3b2d0880..71d342e06 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp @@ -399,9 +399,33 @@ public: ptr_C_l = params.ptr_C[l_coord]; } + auto [stride_c, stride_d] = [&, l = l_coord]() { + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + return make_tuple( + detail::get_epilogue_stride(params.dC[l]), + detail::get_epilogue_stride(params.dD[l]) + ); + } + else { + return make_tuple( + InternalStrideC{}, + detail::get_epilogue_stride(params.dD[l]) + ); + } + } + else { + return make_tuple( + detail::get_epilogue_stride(params.dC), + detail::get_epilogue_stride(params.dD) + ); + } + }(); + // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) - Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, stride_d); // (M,N,L) Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) @@ -572,12 +596,7 @@ public: can_implement( [[maybe_unused]] ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { - - bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); - if (!fusion_implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); - } - return fusion_implementable; + return true; } diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 063e8fb1c..a300ea4a0 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -8046,8 +8046,12 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped) fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped) - nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] - fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule]] + if (data_type["sfd_type"]["type"] == DataType.void): + nvfp4_schedules.append([nvfp4_kernel_schedule, epi_nosmem_schedule]) + fp4_schedules.append([fp4_kernel_schedule, epi_nosmem_schedule]) + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind ) @@ -8170,8 +8174,12 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped) fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped) - nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] - fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule]] + if (data_type["sfd_type"]["type"] == DataType.void): + nvfp4_schedules.append([nvfp4_kernel_schedule, epi_nosmem_schedule]) + fp4_schedules.append([fp4_kernel_schedule, epi_nosmem_schedule]) + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) if isFp4: