diff --git a/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl b/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl
index 99b1323af..8c32c3220 100755
--- a/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl
+++ b/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl
@@ -199,9 +199,11 @@ struct CollectiveBuilder<
   using sSFA_stride = decltype(make_stride(sSFA_strideM{}, sSFA_strideK{}));
   using SmemLayoutAtomSFA = decltype(make_layout( sSFA_shape{}, sSFA_stride{}));

-  using sSFB_shapeN  = decltype(prepend(size<1>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{}));
+  using sSFBTileShape_N = Int<cute::max(size<1>(TileShape_MNK{}), 128)>;
+
+  using sSFB_shapeN  = decltype(prepend(sSFBTileShape_N{} / Blk_MN{}, mnBasicBlockShape{}));
   using sSFB_strideN = sSF_strideMN;
-  using sSFB_strideK = decltype(prepend(make_stride(Int<4>{}, size<1>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{}));
+  using sSFB_strideK = decltype(prepend(make_stride(Int<4>{}, sSFBTileShape_N{} / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{}));
   using sSFB_shape   = decltype(make_shape( sSFB_shapeN{}, sSF_shapeK{}));
   using sSFB_stride  = decltype(make_stride(sSFB_strideN{}, sSFB_strideK{}));
   using SmemLayoutAtomSFB = decltype(make_layout( sSFB_shape{}, sSFB_stride{}));
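The builder change above pins the SFB shared-memory atom to at least 128 along N: scale factors are stored in `Blk_MN = 128`-row basic blocks, so a CTA tile with `BLK_N < 128` must still reserve one full 128-wide SF block, which several CTA tiles then share. A minimal standalone sketch of that arithmetic (illustrative constants only, not CUTLASS code):

```cpp
// Why the SFB smem atom is padded: scale factors come in 128-wide basic
// blocks (Blk_MN = 128 in the Sm1xx block-scaled config), so the atom's N
// extent mirrors sSFBTileShape_N = Int<cute::max(size<1>(TileShape_MNK{}), 128)>.
#include <algorithm>
#include <cstdio>

constexpr int Blk_MN = 128;  // SF basic-block extent along N

constexpr int sfb_tile_n(int tile_n) { return std::max(tile_n, Blk_MN); }

int main() {
  for (int n : {32, 64, 128, 256}) {
    printf("TILE_N=%3d -> SFB atom N=%3d (%d CTA tile(s) share one SF block)\n",
           n, sfb_tile_n(n), sfb_tile_n(n) / n);
  }
}
```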
diff --git a/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp b/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp
index 458ee1af4..bb7f63327 100755
--- a/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp
+++ b/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp
@@ -203,11 +203,22 @@ struct CollectiveMma<
     append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{})))
   ));

-  using SmemLayoutSFB = decltype(make_layout(
+  using SmemLayoutSFB_ = decltype(make_layout(
     append(shape(SmemLayoutAtomSFB{}), Int<DispatchPolicy::Stages>{}),
     append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{})))
   ));

+  using TileShapeSFB = cute::conditional_t<size<1>(TileShape{}) < 128,
+    decltype(cute::make_shape(
+      shape<0>(TileShape{}),
+      Int<128>{},
+      shape<2>(TileShape{}))),
+    TileShape>;
+
+  using SmemLayoutSFB = cute::conditional_t<size<1>(TileShape{}) < 128,
+    decltype(cute::logical_divide(SmemLayoutSFB_{}, select<1,2>(TileShape{}))),
+    SmemLayoutSFB_>;
+
   static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
   static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
@@ -321,7 +332,7 @@ struct CollectiveMma<
         GmemTiledCopySFB{},
         make_tensor(static_cast<ElementSF const*>(nullptr), InternalLayoutSFB{}),
         SmemLayoutSFB{}(_,_,cute::Int<0>{}),
-        make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
+        make_shape(shape<1>(TileShapeSFB{}), shape<2>(TileShapeSFB{})),
         _1{})); // No programmatic multicast

   TMA_A tma_load_a;
@@ -417,7 +428,7 @@ struct CollectiveMma<
         GmemTiledCopySFB{},
         tensor_sfb,
         SmemLayoutSFB{}(_,_,cute::Int<0>{}),
-        make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
+        make_shape(shape<1>(TileShapeSFB{}), shape<2>(TileShapeSFB{})),
         _1{}); // No programmatic multicast

     return {
@@ -712,10 +723,15 @@ struct CollectiveMma<
     // Partition the inputs based on the current block coordinates.
     auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;

+    auto broadcast_n = make_layout(
+      make_shape(Int<size<1>(TileShapeSFB{}) / size<1>(TileShape{})>{},
+                 Int<cutlass::platform::numeric_limits<int32_t>::max()>{}),
+      make_stride(_0{}, size<1>(TileShapeSFB{}) / size<1>(TileShape{})));
+
     Tensor gA = gA_mkl(_,_,m_coord,_,l_coord);                     // (BLK_M,BLK_K,k)
     Tensor gB = gB_nkl(_,_,n_coord,_,l_coord);                     // (BLK_N,BLK_K,k)
     Tensor gSFA = gSFA_mkl(_,_,m_coord,_,l_coord);                 // (BLK_M,BLK_K,k)
-    Tensor gSFB = gSFB_nkl(_,_,n_coord,_,l_coord);                 // (BLK_N,BLK_K,k)
+    Tensor gSFB = gSFB_nkl(_,_,broadcast_n(n_coord),_,l_coord);    // (BLK_N,BLK_K,k)

     // Partition source and destination tensors for tma copies
     Tensor tAgA = block_tma_a.partition_S(gA);                     // (TMA,TMA_M,TMA_K,k)
@@ -778,7 +794,8 @@ struct CollectiveMma<
   /// Perform a collective-scoped matrix multiply-accumulate
   /// Consumer Perspective
   template <
-    class FrgTensorC
+    class FrgTensorC,
+    class BlockCoord
   >
   CUTLASS_DEVICE void
   mma(MainloopPipeline pipeline,
@@ -787,7 +804,8 @@ struct CollectiveMma<
       int k_tile_count,
       int thread_idx,
       TensorStorage& shared_tensors,
-      [[maybe_unused]] Params const& params) {
+      [[maybe_unused]] Params const& params,
+      BlockCoord const& blk_coord) {
     using namespace cute;

     static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
@@ -796,8 +814,17 @@ struct CollectiveMma<
     Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{});          // (BLK_M,BLK_K,PIPE)
     Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{});          // (BLK_N,BLK_K,PIPE)
-    Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{});    // (BLK_M,BLK_K,PIPE)
-    Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{});    // (BLK_N,BLK_K,PIPE)
+    Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{});    // (BLK_M,BLK_K,PIPE)
+    Tensor sSFB = [&]() {
+      if constexpr (size<1>(TileShape{}) >= 128) {
+        return make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{});       // (BLK_N,BLK_K,PIPE)
+      }
+      else {
+        Tensor temp = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); // (BLK_SFB_N,BLK_K,PIPE)
+        auto n = get<1>(blk_coord);
+        return temp(make_coord(_, n % (size<1>(TileShapeSFB{}) / size<1>(TileShape{}))), _, _);
+      }
+    }();

     //
     // Define C accumulators and A/B partitioning
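The `broadcast_n` layout is the producer-side half of the sharing scheme: with R = `size<1>(TileShapeSFB{}) / size<1>(TileShape{})` CTA tiles per 128-wide SFB block, a 1D n-coordinate fed through the layout `(R, INT_MAX):(0, R)` is snapped down to a multiple of R, so all R neighboring CTA tiles issue the TMA load for the same SFB block; the consumer then picks sub-tile `n % R` out of shared memory. A sketch of that index math with hypothetical helper functions (the real code does this through the CuTe layout, not scalar helpers):

```cpp
// Index math behind broadcast_n, assuming TILE_N = 32 under a 128-wide SFB
// block, i.e. R = 128/32 = 4. broadcast_n(n) = (n % R)*0 + (n / R)*R.
#include <cstdio>

constexpr int R = 4;  // size<1>(TileShapeSFB{}) / size<1>(TileShape{})

constexpr int sfb_load_coord(int n)   { return (n / R) * R; }  // producer: shared TMA source
constexpr int sfb_smem_subtile(int n) { return n % R; }        // consumer: sub-tile selection

int main() {
  for (int n = 0; n < 8; ++n) {
    printf("n=%d -> gSFB coord %d, smem sub-tile %d\n",
           n, sfb_load_coord(n), sfb_smem_subtile(n));
  }
}
```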
diff --git a/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp b/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp
index 9cb805188..6a8f9ec8b 100755
--- a/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp
+++ b/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp
@@ -198,11 +198,22 @@ struct CollectiveMma<
     append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{})))
   ));

-  using SmemLayoutSFB = decltype(make_layout(
+  using SmemLayoutSFB_ = decltype(make_layout(
     append(shape(SmemLayoutAtomSFB{}), Int<DispatchPolicy::Stages>{}),
     append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{})))
   ));

+  using TileShapeSFB = cute::conditional_t<size<1>(TileShape{}) < 128,
+    decltype(cute::make_shape(
+      shape<0>(TileShape{}),
+      Int<128>{},
+      shape<2>(TileShape{}))),
+    TileShape>;
+
+  using SmemLayoutSFB = cute::conditional_t<size<1>(TileShape{}) < 128,
+    decltype(cute::logical_divide(SmemLayoutSFB_{}, select<1,2>(TileShape{}))),
+    SmemLayoutSFB_>;
+
   static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
   static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
@@ -307,7 +318,7 @@ struct CollectiveMma<
         GmemTiledCopySFB{},
         make_tensor(static_cast<ElementSF const*>(nullptr), LayoutSFB{}),
         SmemLayoutSFB{}(_,_,cute::Int<0>{}),
-        make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
+        make_shape(shape<1>(TileShapeSFB{}), shape<2>(TileShapeSFB{})),
         _1{})); // No programmatic multicast

   TMA_A tma_load_a;
@@ -367,7 +378,7 @@ struct CollectiveMma<
         GmemTiledCopySFB{},
         tensor_sfb,
         SmemLayoutSFB{}(_,_,cute::Int<0>{}),
-        make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
+        make_shape(shape<1>(TileShapeSFB{}), shape<2>(TileShapeSFB{})),
         _1{}); // No programmatic multicast

     return {
@@ -627,10 +638,14 @@ struct CollectiveMma<
     // Partition the inputs based on the current block coordinates.
     auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;

-    Tensor gA = gA_mkl(_,_,m_coord,_,l_coord);                     // (BLK_M,BLK_K,k)
-    Tensor gB = gB_nkl(_,_,n_coord,_,l_coord);                     // (BLK_N,BLK_K,k)
-    Tensor gSFA = gSFA_mkl(_,_,m_coord,_,l_coord);                 // (BLK_M,BLK_K,k)
-    Tensor gSFB = gSFB_nkl(_,_,n_coord,_,l_coord);                 // (BLK_N,BLK_K,k)
+    auto broadcast_n = make_layout(
+      make_shape(Int<size<1>(TileShapeSFB{}) / size<1>(TileShape{})>{},
+                 Int<cutlass::platform::numeric_limits<int32_t>::max()>{}),
+      make_stride(_0{}, size<1>(TileShapeSFB{}) / size<1>(TileShape{})));
+    Tensor gA = gA_mkl(_,_,m_coord,_,l_coord);                     // (BLK_M,BLK_K,k)
+    Tensor gB = gB_nkl(_,_,n_coord,_,l_coord);                     // (BLK_N,BLK_K,k)
+    Tensor gSFA = gSFA_mkl(_,_,m_coord,_,l_coord);                 // (BLK_M,BLK_K,k)
+    Tensor gSFB = gSFB_nkl(_,_,broadcast_n(n_coord),_,l_coord);    // (BLK_N,BLK_K,k)

     // Partition source and destination tensors for tma copies
     Tensor tAgA = block_tma_a.partition_S(gA);                     // (TMA,TMA_M,TMA_K,k)
@@ -693,7 +708,8 @@ struct CollectiveMma<
   /// Perform a collective-scoped matrix multiply-accumulate
   /// Consumer Perspective
   template <
-    class FrgTensorC
+    class FrgTensorC,
+    class BlockCoord
   >
   CUTLASS_DEVICE void
   mma(MainloopPipeline pipeline,
@@ -702,7 +718,8 @@ struct CollectiveMma<
       int k_tile_count,
       int thread_idx,
       TensorStorage& shared_tensors,
-      [[maybe_unused]] Params const& params) {
+      [[maybe_unused]] Params const& params,
+      BlockCoord const& blk_coord) {
     using namespace cute;

     static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
@@ -712,7 +729,16 @@ struct CollectiveMma<
     Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{});          // (BLK_M,BLK_K,PIPE)
     Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{});          // (BLK_N,BLK_K,PIPE)
     Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{});    // (BLK_M,BLK_K,PIPE)
-    Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{});    // (BLK_N,BLK_K,PIPE)
+    Tensor sSFB = [&]() {
+      if constexpr (size<1>(TileShape{}) >= 128) {
+        return make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{});       // (BLK_N,BLK_K,PIPE)
+      }
+      else {
+        Tensor temp = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); // (BLK_SFB_N,BLK_K,PIPE)
+        auto n = get<1>(blk_coord);
+        return temp(make_coord(_, n % (size<1>(TileShapeSFB{}) / size<1>(TileShape{}))), _, _);
+      }
+    }();

     //
     // Define C accumulators and A/B partitioning
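Both collectives build `sSFB` with an immediately invoked lambda because the two `if constexpr` branches return tensors of different types (the full pipelined layout versus a sub-tile slice of the logically divided layout), and a plain ternary cannot mix types. A minimal sketch of the same C++ idiom, using `std::array` as a stand-in for the two tensor types:

```cpp
// if constexpr inside an immediately-invoked lambda lets one initializer
// choose between two *differently typed* results at compile time.
#include <array>
#include <cstdio>

template <int N>
auto make_view() {
  return [] {
    if constexpr (N >= 128) {
      return std::array<int, 4>{};  // "full" view
    } else {
      return std::array<int, 2>{};  // "sub-tile" view -- a different type
    }
  }();
}

int main() {
  auto a = make_view<128>();
  auto b = make_view<32>();
  printf("sizes: %zu %zu\n", a.size(), b.size());
}
```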
diff --git a/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp
index a7f481281..29547938c 100644
--- a/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp
+++ b/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp
@@ -614,8 +614,6 @@ struct CollectiveMma<
     }
   }

-  /// Perform a collective-scoped matrix multiply-accumulate
-  /// Consumer Perspective
   template <
     class FrgTensorC
   >
@@ -627,6 +625,27 @@ struct CollectiveMma<
       int thread_idx,
       TensorStorage& shared_tensors,
       Params const& mainloop_params) {
+
+    auto empty_tuple = make_tuple(_0{}, _0{}, _0{}, _0{});
+    mma(pipeline, smem_pipe_read, accum, k_tile_count,
+        thread_idx, shared_tensors, mainloop_params, empty_tuple);
+  }
+
+  /// Perform a collective-scoped matrix multiply-accumulate
+  /// Consumer Perspective
+  template <
+    class FrgTensorC,
+    class BlockCoord
+  >
+  CUTLASS_DEVICE void
+  mma(MainloopPipeline pipeline,
+      PipelineState smem_pipe_read,
+      FrgTensorC& accum,
+      int k_tile_count,
+      int thread_idx,
+      TensorStorage& shared_tensors,
+      Params const& mainloop_params,
+      [[maybe_unused]] BlockCoord& blk_crd) {
     using namespace cute;

     static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
diff --git a/include/cutlass/gemm/collective/sm120_mma_tma.hpp b/include/cutlass/gemm/collective/sm120_mma_tma.hpp
index 951b17937..173a4e8e4 100644
--- a/include/cutlass/gemm/collective/sm120_mma_tma.hpp
+++ b/include/cutlass/gemm/collective/sm120_mma_tma.hpp
@@ -439,7 +439,8 @@ struct CollectiveMma<
   /// Perform a collective-scoped matrix multiply-accumulate
   /// Consumer Perspective
   template <
-    class FrgTensorC
+    class FrgTensorC,
+    class BlockCoord
   >
   CUTLASS_DEVICE void
   mma(MainloopPipeline pipeline,
@@ -448,7 +449,8 @@ struct CollectiveMma<
       int k_tile_count,
       int thread_idx,
       TensorStorage& shared_tensors,
-      Params const& mainloop_params) {
+      Params const& mainloop_params,
+      [[maybe_unused]] BlockCoord& blk_crd) {
     using namespace cute;

     static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
diff --git a/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp
index bc22419a9..b140998b9 100644
--- a/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp
+++ b/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp
@@ -545,8 +545,6 @@ struct CollectiveMma<
     }
   }

-  /// Perform a collective-scoped matrix multiply-accumulate
-  /// Consumer Perspective
   template <
     class FrgTensorC
  >
@@ -558,6 +556,25 @@ struct CollectiveMma<
       int thread_idx,
       TensorStorage& shared_tensors,
       Params const& mainloop_params) {
+    auto empty_tuple = make_tuple(_0{}, _0{}, _0{}, _0{});
+    mma(pipeline, smem_pipe_read, accum, k_tile_count, thread_idx, shared_tensors, mainloop_params, empty_tuple);
+  }
+
+  /// Perform a collective-scoped matrix multiply-accumulate
+  /// Consumer Perspective
+  template <
+    class FrgTensorC,
+    class BlockCoord
+  >
+  CUTLASS_DEVICE void
+  mma(MainloopPipeline pipeline,
+      PipelineState smem_pipe_read,
+      FrgTensorC& accum,
+      int k_tile_count,
+      int thread_idx,
+      TensorStorage& shared_tensors,
+      Params const& mainloop_params,
+      [[maybe_unused]] BlockCoord& blk_crd) {
     using namespace cute;

     static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
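The blockwise-scaling collectives don't use the block coordinate, so they keep the old 7-argument `mma()` as a thin shim that forwards to the new 8-argument overload with a placeholder tuple; existing call sites keep compiling while SM120 block-scaled call sites pass the real coordinate. A sketch of the same compatibility pattern (names illustrative, not the CUTLASS types):

```cpp
// Backward-compatible overload forwarding: the legacy entry point supplies a
// placeholder coordinate that the templated overload marks [[maybe_unused]].
#include <cstdio>
#include <tuple>

struct Collective {
  template <class BlockCoord>
  void mma(int k_tiles, [[maybe_unused]] BlockCoord const& blk_coord) {
    printf("mma over %d k-tiles\n", k_tiles);
  }
  void mma(int k_tiles) {                       // legacy 1-arg entry point
    auto empty = std::make_tuple(0, 0, 0, 0);   // placeholder coordinate
    mma(k_tiles, empty);
  }
};

int main() {
  Collective c;
  c.mma(8);                               // old call sites
  c.mma(8, std::make_tuple(1, 2, 0, 0));  // new SM120 call sites
}
```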
diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp
index 3a5149d6e..3fec578ba 100644
--- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp
+++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp
@@ -168,10 +168,28 @@ public:
   static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;
   static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v<DispatchPolicy>;

+  /// Accumulator register requirement per Math thread
+  static constexpr int RegsPerThread =
+    2 * size<0>(TileShape{}) * size<1>(TileShape{}) / NumMmaThreads *
+    sizeof(ElementAccumulator) / sizeof(uint32_t);
+
+  // Detect whether this is an SM120 block-scaled kernel, which hits low
+  // register pressure on smaller tiles
+  template <class Policy>
+  struct IsSm120BlockScaled : cute::false_type {};
+
+  template <int Stages, class ClusterShape, class KernelSchedule>
+  struct IsSm120BlockScaled<MainloopSm120ArrayTmaWarpSpecializedBlockScaled<Stages, ClusterShape, KernelSchedule>>
+    : cute::true_type {};
+
+  static constexpr bool IsLowRegisterPressure = IsSm120BlockScaled<DispatchPolicy>::value && (RegsPerThread <= 64);
+
   /// Register requirement for Load and Math WGs
   static constexpr uint32_t LoadRegisterRequirement = 40;
   static constexpr uint32_t MmaRegisterRequirement = 232;

+  static constexpr bool IsSm120Family = cute::is_same_v<ArchTag, arch::Sm120>;
+
   // 1 stage ordered sequence between mainloop and epilogue producer load threads
   using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>;
@@ -606,7 +624,9 @@
     auto k_tile_count = size<3>(gA_mkl);

     if (warp_group_role == WarpGroupRole::Producer) {
-      cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
+      if constexpr (!IsLowRegisterPressure) {
+        cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
+      }

       if (producer_warp_role == ProducerWarpRole::Scheduler) {
         // GroupScheduler requires a producer warp to iterate over the group infos and push
@@ -882,7 +902,9 @@
     } // Producer Warp Group End

     else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
-      cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
+      if constexpr (!IsLowRegisterPressure) {
+        cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
+      }

       // Index of warp group within consumer warp groups
       int consumer_warp_group_idx = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1;
@@ -935,15 +957,29 @@

       if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
-        collective_mainloop.mma(
-          mainloop_pipeline,
-          mainloop_pipe_consumer_state,
-          accumulators,
-          work_k_tile_count,
-          mma_thread_idx,
-          shared_storage.tensors.mainloop,
-          params.mainloop
-        );
+        if constexpr (IsSm120Family) {
+          collective_mainloop.mma(
+            mainloop_pipeline,
+            mainloop_pipe_consumer_state,
+            accumulators,
+            work_k_tile_count,
+            mma_thread_idx,
+            shared_storage.tensors.mainloop,
+            params.mainloop,
+            blk_coord
+          );
+        }
+        else {
+          collective_mainloop.mma(
+            mainloop_pipeline,
+            mainloop_pipe_consumer_state,
+            accumulators,
+            work_k_tile_count,
+            mma_thread_idx,
+            shared_storage.tensors.mainloop,
+            params.mainloop
+          );
+        }

         // Make sure the math instructions are done and free buffers before entering the epilogue
         collective_mainloop.mma_tail(
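The heuristic above estimates the accumulator register footprint per Math thread and, when it is at or below 64 on an SM120 block-scaled kernel, skips the `setmaxnreg`-based register repartitioning entirely: shifting registers from producer to consumer warp groups only pays off when the consumers actually need more than the default budget. A worked check of the arithmetic (assuming a cooperative schedule with 2 consumer warp groups, so `NumMmaThreads = 256`, and a `float` accumulator):

```cpp
// RegsPerThread = 2 * TILE_M * TILE_N / NumMmaThreads
//               * sizeof(ElementAccumulator) / sizeof(uint32_t)
#include <cstdio>

constexpr int NumMmaThreads = 256;  // assumed: 2 math warp groups x 128 threads

constexpr int regs_per_thread(int m, int n) {
  return 2 * m * n / NumMmaThreads * (int)sizeof(float) / (int)sizeof(unsigned);
}

int main() {
  for (int n : {32, 64, 128}) {
    int r = regs_per_thread(128, n);
    printf("128x%-3d tile -> %3d regs/thread -> %s setmaxnreg\n",
           n, r, r <= 64 ? "skip" : "use");
  }
}
```

For 128x32 and 128x64 tiles this yields 32 and 64 registers per thread, so the repartitioning is skipped; at 128x128 it yields 128 and the existing dealloc/alloc path is kept.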
diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp
index e1fa1c86b..451e92511 100644
--- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp
+++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp
@@ -167,10 +167,28 @@ public:
   static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;
   static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v<DispatchPolicy>;

+  /// Accumulator register requirement per Math thread
+  static constexpr int RegsPerThread =
+    2 * size<0>(TileShape{}) * size<1>(TileShape{}) / (NumMmaWarpGroups * NumThreadsPerWarpGroup) *
+    sizeof(ElementAccumulator) / sizeof(uint32_t);
+
+  // Detect whether this is an SM120 block-scaled kernel, which hits low
+  // register pressure on smaller tiles
+  template <class Policy>
+  struct IsSm120BlockScaled : cute::false_type {};
+
+  template <int Stages, class ClusterShape, class KernelSchedule>
+  struct IsSm120BlockScaled<MainloopSm120ArrayTmaWarpSpecializedBlockScaled<Stages, ClusterShape, KernelSchedule>>
+    : cute::true_type {};
+
+  static constexpr bool IsLowRegisterPressure = IsSm120BlockScaled<DispatchPolicy>::value && (RegsPerThread <= 64);
+
   /// Register requirement for Load and Math WGs
   static constexpr uint32_t LoadRegisterRequirement = 40;
   static constexpr uint32_t MmaRegisterRequirement = 232;

+  static constexpr bool IsSm120Family = cute::is_same_v<ArchTag, arch::Sm120>;
+
   // 1 stage ordered sequence between mainloop and epilogue producer load threads
   using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>;
@@ -644,7 +662,9 @@
     auto k_tile_count = size<3>(gA_mkl);

     if (warp_group_role == WarpGroupRole::Producer) {
-      cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
+      if constexpr (!IsLowRegisterPressure) {
+        cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
+      }

       if (producer_warp_role == ProducerWarpRole::Scheduler) {
         // GroupScheduler requires a producer warp to iterate over the group infos and push
@@ -920,7 +940,9 @@
     } // Producer Warp Group End

     else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
-      cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
+      if constexpr (!IsLowRegisterPressure) {
+        cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
+      }

       // Index of warp group within consumer warp groups
       int consumer_warp_group_idx = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1;
@@ -975,15 +997,29 @@

       math_wg_order_barrier.wait();

-      collective_mainloop.mma(
-        mainloop_pipeline,
-        mainloop_pipe_consumer_state,
-        accumulators,
-        work_k_tile_count,
-        mma_thread_idx,
-        shared_storage.tensors.mainloop,
-        params.mainloop
-      );
+      if constexpr (IsSm120Family) {
+        collective_mainloop.mma(
+          mainloop_pipeline,
+          mainloop_pipe_consumer_state,
+          accumulators,
+          work_k_tile_count,
+          mma_thread_idx,
+          shared_storage.tensors.mainloop,
+          params.mainloop,
+          blk_coord
+        );
+      }
+      else {
+        collective_mainloop.mma(
+          mainloop_pipeline,
+          mainloop_pipe_consumer_state,
+          accumulators,
+          work_k_tile_count,
+          mma_thread_idx,
+          shared_storage.tensors.mainloop,
+          params.mainloop
+        );
+      }

       math_wg_order_barrier.arrive();
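At every consumer call site the kernels now pick the mainloop signature with `if constexpr (IsSm120Family)`, so the extra `blk_coord` argument is only ever instantiated against collectives that actually declare it. A minimal model of that tag-based dispatch (types and names here are illustrative stand-ins, not the CUTLASS classes):

```cpp
// Call-site dispatch on the architecture tag: the untaken branch is never
// instantiated, so non-SM120 mainloops keep their legacy 7-arg signature.
#include <cstdio>
#include <type_traits>

struct Sm90 {}; struct Sm120 {};

template <class ArchTag>
struct Mainloop {
  void mma(int k)            { printf("mma(%d)\n", k); }
  void mma(int k, int blk_n) { printf("mma(%d, n=%d)\n", k, blk_n); }
};

template <class ArchTag>
void run(Mainloop<ArchTag>& m, int k, int blk_n) {
  if constexpr (std::is_same_v<ArchTag, Sm120>) {
    m.mma(k, blk_n);  // SM120 collectives consume the block coordinate
  } else {
    m.mma(k);         // everyone else keeps the legacy signature
  }
}

int main() {
  Mainloop<Sm90> a; Mainloop<Sm120> b;
  run(a, 8, 3);
  run(b, 8, 3);
}
```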
diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp
index c6a94cbf5..3f669e1b2 100644
--- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp
+++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp
@@ -786,15 +786,29 @@
       // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
       auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape));  // (MMA,MMA_M,MMA_N)

       if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
-        collective_mainloop.mma(
-          mainloop_pipeline,
-          mainloop_pipe_consumer_state,
-          accumulators,
-          work_k_tile_count,
-          mma_thread_idx,
-          shared_storage.tensors.mainloop,
-          params.mainloop
-        );
+        if constexpr (IsSm120Family) {
+          collective_mainloop.mma(
+            mainloop_pipeline,
+            mainloop_pipe_consumer_state,
+            accumulators,
+            work_k_tile_count,
+            mma_thread_idx,
+            shared_storage.tensors.mainloop,
+            params.mainloop,
+            blk_coord
+          );
+        }
+        else {
+          collective_mainloop.mma(
+            mainloop_pipeline,
+            mainloop_pipe_consumer_state,
+            accumulators,
+            work_k_tile_count,
+            mma_thread_idx,
+            shared_storage.tensors.mainloop,
+            params.mainloop
+          );
+        }

         // Make sure the math instructions are done and free buffers before entering the epilogue
         collective_mainloop.mma_tail(
diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp
index 1fd9fb0ec..d86235aad 100644
--- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp
+++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp
@@ -849,15 +849,29 @@
       // Order two Math WG's MMA one after the other, helps hide Epilogue
       math_wg_order_barrier.wait();

-      collective_mainloop.mma(
-        mainloop_pipeline,
-        mainloop_pipe_consumer_state,
-        accumulators,
-        k_tile_count,
-        warp_group_thread_idx,
-        shared_storage.tensors.mainloop,
-        params.mainloop
-      );
+      if constexpr (IsSm120Family) {
+        collective_mainloop.mma(
+          mainloop_pipeline,
+          mainloop_pipe_consumer_state,
+          accumulators,
+          k_tile_count,
+          warp_group_thread_idx,
+          shared_storage.tensors.mainloop,
+          params.mainloop,
+          blk_coord
+        );
+      }
+      else {
+        collective_mainloop.mma(
+          mainloop_pipeline,
+          mainloop_pipe_consumer_state,
+          accumulators,
+          k_tile_count,
+          warp_group_thread_idx,
+          shared_storage.tensors.mainloop,
+          params.mainloop
+        );
+      }

       // Cue for next Math WG's MMA to start
       math_wg_order_barrier.arrive();
diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py
index 49b8c26a8..f5c9df283 100644
--- a/python/cutlass_library/generator.py
+++ b/python/cutlass_library/generator.py
@@ -11192,6 +11192,8 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
   ]

   tile_sizes = [
+    [128, 32, 128],
+    [128, 64, 128],
     [128, 128, 128]
   ]
@@ -11324,12 +11326,20 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
   ]

   tile_sizes_cooperative = [
+    [128, 32, 128],
+    [128, 32, 256],
+    [128, 64, 128],
+    [128, 64, 256],
     [128, 128, 128],
     [128, 128, 256],
     [256, 128, 128]
   ]

   tile_sizes_pingpong = [
+    [128, 32, 128],
+    [128, 32, 256],
+    [128, 64, 128],
+    [128, 64, 256],
     [128, 128, 128],
     [128, 128, 256]
   ]
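The new unit tests below pin down alignment constants for the FP4 operands and the float C/D tensors. These are element counts derived from a byte target: 16 bytes for the TMA-loaded A/B operands and 128 bits for C/D. A quick check of the arithmetic, assuming 4-bit `e2m1` inputs and 32-bit `float` outputs:

```cpp
// Alignment constants in the tests are element counts, not bytes:
// AlignmentA = 16B * 8 / bits(e2m1), AlignmentC = 128b / bits(float).
#include <cstdio>

int main() {
  constexpr int bits_e2m1 = 4, bits_f32 = 32;
  constexpr int AlignmentA = 16 * 8 / bits_e2m1;  // 128 bits / 4  = 32 elements
  constexpr int AlignmentC = 128 / bits_f32;      // 128 bits / 32 =  4 elements
  printf("AlignmentA = %d elements, AlignmentC = %d elements\n",
         AlignmentA, AlignmentC);
}
```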
diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf4_mxf4_f32_f32.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf4_mxf4_f32_f32.cu
index 3f229c906..9c88a368d 100644
--- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf4_mxf4_f32_f32.cu
+++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf4_mxf4_f32_f32.cu
@@ -120,4 +120,132 @@ TEST(SM120_Device_Blockscaled_Gemm_mxf4t_mxf4n_f32n_tensor_op_f32, 128x128x256)
   EXPECT_TRUE(result);
 }

+namespace kernel_2 {
+  using ElementA = cutlass::float_e2m1_t;
+  using ElementB = cutlass::float_e2m1_t;
+  using ElementC = float;
+  using ElementD = float;
+  using ElementAccumulator = float;
+  using ElementCompute = float;
+  using ElementSF = cutlass::float_ue8m0_t;
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::ColumnMajor;
+  using LayoutD = cutlass::layout::ColumnMajor;
+
+  using ElementPairA = cutlass::mx_float4_t<ElementA>;
+  using ElementPairB = cutlass::mx_float4_t<ElementB>;
+
+  static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits<ElementA>::value;  // Align to 16 bytes.
+  static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits<ElementB>::value;  // Align to 16 bytes.
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  using TileShape = Shape<_128,_64,_256>;
+  using ClusterShape = Shape<_1,_1,_1>;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
+      TileShape, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementCompute,
+      ElementC, LayoutC, AlignmentC,
+      ElementD, LayoutD, AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto
+    >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
+      ElementPairA, LayoutA, AlignmentA,
+      ElementPairB, LayoutB, AlignmentB,
+      ElementAccumulator,
+      TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::KernelTmaWarpSpecializedCooperative
+    >::CollectiveOp;
+
+  template <typename T = void>
+  struct dummy {
+    using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+        Shape<int,int,int,int>,
+        CollectiveMainloop,
+        CollectiveEpilogue
+      >;
+    using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  };
+  using GemmKernel = typename dummy<>::GemmKernel;
+  using Gemm = typename dummy<>::Gemm;
+
+} // kernel_2
+
+TEST(SM120_Device_Blockscaled_Gemm_mxf4t_mxf4n_f32n_tensor_op_f32, 128x64x256) {
+  bool result = test::gemm::device::TestSmall<kernel_2::Gemm>(1.0, 0.5);
+  EXPECT_TRUE(result);
+}
+
+namespace kernel_3 {
+  using ElementA = cutlass::float_e2m1_t;
+  using ElementB = cutlass::float_e2m1_t;
+  using ElementC = float;
+  using ElementD = float;
+  using ElementAccumulator = float;
+  using ElementCompute = float;
+  using ElementSF = cutlass::float_ue8m0_t;
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::ColumnMajor;
+  using LayoutD = cutlass::layout::ColumnMajor;
+
+  using ElementPairA = cutlass::mx_float4_t<ElementA>;
+  using ElementPairB = cutlass::mx_float4_t<ElementB>;
+
+  static constexpr int AlignmentA = 16 * 8 / cutlass::sizeof_bits<ElementA>::value;  // Align to 16 bytes.
+  static constexpr int AlignmentB = 16 * 8 / cutlass::sizeof_bits<ElementB>::value;  // Align to 16 bytes.
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  using TileShape = Shape<_128,_32,_256>;
+  using ClusterShape = Shape<_1,_1,_1>;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp,
+      TileShape, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementCompute,
+      ElementC, LayoutC, AlignmentC,
+      ElementD, LayoutD, AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto
+    >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
+      ElementPairA, LayoutA, AlignmentA,
+      ElementPairB, LayoutB, AlignmentB,
+      ElementAccumulator,
+      TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::KernelTmaWarpSpecializedCooperative
+    >::CollectiveOp;
+
+  template <typename T = void>
+  struct dummy {
+    using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+        Shape<int,int,int,int>,
+        CollectiveMainloop,
+        CollectiveEpilogue
+      >;
+    using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  };
+  using GemmKernel = typename dummy<>::GemmKernel;
+  using Gemm = typename dummy<>::Gemm;
+
+} // kernel_3
+
+TEST(SM120_Device_Blockscaled_Gemm_mxf4t_mxf4n_f32n_tensor_op_f32, 128x32x256) {
+  bool result = test::gemm::device::TestSmall<kernel_3::Gemm>(1.0, 0.5);
+  EXPECT_TRUE(result);
+}
+
 #endif // (defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED))
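The last test below covers the grouped (ptr-array) pingpong path with an N=32 tile and the VS16 block-scale-factor epilogue fusion. A rough model of what that fusion emits alongside `D = alpha * acc + beta * C` (illustrative arithmetic only, assuming one scale factor per run of `OutputSFVectorSize` contiguous D elements in the row-major SF layout):

```cpp
// SFD bookkeeping for the VS16 fusion: a 128x32 tile with a 16-wide SF
// vector produces 128 * (32/16) scale factors per output tile.
#include <cstdio>

int main() {
  constexpr int SFVec = 16, TileM = 128, TileN = 32;
  constexpr int sfd_per_tile = TileM * (TileN / SFVec);
  printf("SFD values per %dx%d tile: %d\n", TileM, TileN, sfd_per_tile);
}
```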
diff --git a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_nvf4_group_gemm_fusion.cu b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_nvf4_group_gemm_fusion.cu
index 25340d189..8c4285126 100644
--- a/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_nvf4_group_gemm_fusion.cu
+++ b/test/unit/gemm/device/sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_nvf4_nvf4_f32_nvf4_group_gemm_fusion.cu
@@ -132,6 +132,77 @@ TEST(SM120_Device_Gemm_e2m1t_e2m1n_e2m1t_tensorop_f32_epilogue_VS16_group_pingpo
   EXPECT_TRUE(pass);
 }

+TEST(SM120_Device_Gemm_e2m1t_e2m1n_e2m1t_tensorop_f32_epilogue_VS16_group_pingpong, row_sf_128x32x128) {
+  using ElementInput = float_e2m1_t;
+  using ElementA = cutlass::nv_float4_t<ElementInput>;
+  using ElementB = cutlass::nv_float4_t<ElementInput>;
+  using ElementC = cutlass::half_t;
+  using ElementD = cutlass::float_e2m1_t;
+  using ElementCompute = float;
+  using ElementAccumulator = float;
+  using ElementSF = cutlass::float_ue4m3_t;
+  using ElementSFD = ElementSF;
+  using GmemLayoutA = cutlass::layout::RowMajor;
+  using GmemLayoutB = cutlass::layout::ColumnMajor;
+  using GmemLayoutC = cutlass::layout::RowMajor;
+  constexpr int SFVectorSize = 16;
+  using TileShape_MNK = Shape<_128,_32,_128>;
+  using ClusterShape_MNK = Shape<_1,_1,_1>;
+
+  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInput>::value;
+  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementInput>::value;
+  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  //
+  // Construct CollectiveEpilogue
+  //
+
+  constexpr int OutputSFVectorSize = SFVectorSize;
+  // D = alpha * acc + beta * C
+  // With Row-major BlockScaleFactor generation.
+  using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
+      OutputSFVectorSize,
+      ElementD,
+      ElementCompute,
+      ElementSFD, GmemLayoutC,
+      ElementC>;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
+      TileShape_MNK, ClusterShape_MNK,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementCompute,
+      ElementC, GmemLayoutC *, AlignmentC,
+      ElementD, GmemLayoutC *, AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto,
+      FusionOperation
+    >::CollectiveOp;
+
+  //
+  // Construct CollectiveMainloop
+  //
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm120, cutlass::arch::OpClassBlockScaledTensorOp,
+      ElementA, GmemLayoutA *, AlignmentA,
+      ElementB, GmemLayoutB *, AlignmentB,
+      ElementAccumulator,
+      TileShape_MNK, ClusterShape_MNK,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+    >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  auto pass = test::gemm::device::TestSmallFusion<Gemm>(1.0, 0.5);
+  EXPECT_TRUE(pass);
+}
+
 TEST(SM120_Device_Gemm_e2m1t_e2m1n_e2m1t_tensorop_f32_epilogue_VS16_group_pingpong, silu_row_sf) {