mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-21 23:38:57 +00:00
1269 lines
59 KiB
C++
1269 lines
59 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
#include "cutlass/cutlass.h"
|
|
#include "cutlass/detail/collective.hpp"
|
|
#include "cutlass/detail/cluster.hpp"
|
|
#include "cutlass/gemm/dispatch_policy.hpp"
|
|
#include "cutlass/numeric_types.h"
|
|
#include "cutlass/pipeline/pipeline.hpp"
|
|
#include "cutlass/gemm/gemm.h"
|
|
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
|
|
#include "cutlass/trace.h"
|
|
#include "cutlass/kernel_hardware_info.hpp"
|
|
#include "cutlass/detail/collective.hpp"
|
|
#include "cutlass/detail/sm100_tmem_helper.hpp"
|
|
|
|
#include "cute/algorithm/functional.hpp"
|
|
#include "cute/arch/cluster_sm90.hpp"
|
|
#include "cute/atom/mma_atom.hpp"
|
|
#include "cute/algorithm/gemm.hpp"
|
|
#include "cute/tensor_predicate.hpp"
|
|
#include "cute/numeric/arithmetic_tuple.hpp"
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace cutlass::gemm::collective {
|
|
using namespace cute;
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// WarpSpecialized Mainloop
|
|
// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one
|
|
template <
|
|
int Stages,
|
|
int SchedulerPipelineStageCount,
|
|
int AccumulatorPipelineStageCount,
|
|
class ClusterShape, // Static cluster shape or dynamic (int, int, _1)
|
|
class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
|
|
class ElementPairA_,
|
|
class StridePairA_,
|
|
class ElementPairB_,
|
|
class StridePairB_,
|
|
class TiledMma_,
|
|
class GmemTiledCopyPairA_,
|
|
class SmemLayoutAtomPairA_,
|
|
class SmemCopyAtomA_,
|
|
class TransformA_,
|
|
class GmemTiledCopyPairB_,
|
|
class SmemLayoutAtomPairB_,
|
|
class SmemCopyAtomB_,
|
|
class TransformB_>
|
|
struct CollectiveMma<
|
|
MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled<
|
|
Stages,
|
|
SchedulerPipelineStageCount,
|
|
AccumulatorPipelineStageCount,
|
|
ClusterShape>,
|
|
TileShape_,
|
|
ElementPairA_,
|
|
StridePairA_,
|
|
ElementPairB_,
|
|
StridePairB_,
|
|
TiledMma_,
|
|
GmemTiledCopyPairA_,
|
|
SmemLayoutAtomPairA_,
|
|
SmemCopyAtomA_,
|
|
TransformA_,
|
|
GmemTiledCopyPairB_,
|
|
SmemLayoutAtomPairB_,
|
|
SmemCopyAtomB_,
|
|
TransformB_>
|
|
{
|
|
//
// Type Aliases
//
using TiledMma = TiledMma_;
// (V,1,1) shape of peer CTAs cooperating on one MMA atom (V == 2 for 2SM UMMA).
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;

using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled<
    Stages,
    SchedulerPipelineStageCount,
    AccumulatorPipelineStageCount,
    ClusterShape>;
using TileShape = TileShape_;
// Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface.
// Tiled MMA used only for partitioning scale factors (single-thread layout, no permutation).
using TiledMMA_SF = TiledMMA<MMA_Atom<typename TiledMma::MMA_ScaleFactor>,
                             Layout<Shape<_1,_1,_1>>,
                             Tile<Underscore,Underscore,Underscore>>;

static constexpr bool IsDynamicCluster = not cute::is_static_v<ClusterShape>;
// Scale-factor vector size (elements sharing one scale factor), taken from the MMA atom.
static constexpr int SFVecSize = TiledMma::SFVecSize;
static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum;

CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})),
    "Static cluster shape used: TileShape should be evenly divided by TiledMma");

// Per-CTA tile: the MMA-atom tile divided across the peer CTAs of the atom.
using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{}));
static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256,
    "Cta N should be one of 128/192/256");

using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{})));
using Sm100BlkScaledConfig = cutlass::detail::Sm100BlockScaledConfig<SFVecSize>;
using Blk_MN = typename Sm100BlkScaledConfig::Blk_MN;
// NOTE(review): boolean flag declared as int; used only in `if constexpr`, so behavior is unaffected.
static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192;
// CTA-N rounded up to a whole scale-factor block (Blk_MN) for SFB partitioning.
static int constexpr CTA_N_SF = cutlass::ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{};
// Tile shape used for partitioning Scale Factor B.
// The M-dim does not affect the SFB, so just set it as the original TileShape;
using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}),
    Int<CTA_N_SF>{} * shape<2>(typename TiledMma::ThrLayoutVMNK()),
    get<2>(TileShape{})));

// Define A and B block shapes for reduced size TMA_LOADs
using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{}))));
using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{}))));

// Each "pair" bundles the matrix quantity (element 0) with its scale-factor quantity (element 1).
using ElementPairA = ElementPairA_;
using ElementPairB = ElementPairB_;
using ElementAMma = typename TiledMma::ValTypeA;
using ElementBMma = typename TiledMma::ValTypeB;
using StridePairA = StridePairA_;
using StridePairB = StridePairB_;
using SmemLayoutAtomPairA = SmemLayoutAtomPairA_;
using SmemLayoutAtomPairB = SmemLayoutAtomPairB_;
static_assert(cute::is_same_v<remove_cvref_t<decltype(get<1>(ElementPairA{}))>,
    remove_cvref_t<decltype(get<1>(ElementPairB{}))>>, "SFA and SFB data types should be the same");

// A and B matrices
using ElementA = remove_cvref_t<decltype(get<0>(ElementPairA{}))>;
using StrideA = remove_cvref_t<decltype(get<0>(StridePairA{}))>;
// Grouped GEMM passes strides as pointers (one per group); this strips the pointer.
using InternalStrideA = cute::remove_pointer_t<StrideA>;

using ElementB = remove_cvref_t<decltype(get<0>(ElementPairB{}))>;
using StrideB = remove_cvref_t<decltype(get<0>(StridePairB{}))>;
using InternalStrideB = cute::remove_pointer_t<StrideB>;

static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();

static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();

static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
    (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB),
    "ElementA and ElementB should be both runtime or both static.");

static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;

// SFA and SFB
using ElementSF = remove_cvref_t<decltype(get<1>(ElementPairA{}))>;
using LayoutSFA = remove_cvref_t<decltype(get<1>(StridePairA{}))>;
using InternalLayoutSFA = cute::remove_pointer_t<LayoutSFA>;
using LayoutSFB = remove_cvref_t<decltype(get<1>(StridePairB{}))>;
using InternalLayoutSFB = cute::remove_pointer_t<LayoutSFB>;

using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyPairA = GmemTiledCopyPairA_;
using GmemTiledCopyPairB = GmemTiledCopyPairB_;
using GmemTiledCopyA = remove_cvref_t<decltype(get<0>(GmemTiledCopyPairA{}))>;
using GmemTiledCopySFA = remove_cvref_t<decltype(get<1>(GmemTiledCopyPairA{}))>;
using GmemTiledCopyB = remove_cvref_t<decltype(get<0>(GmemTiledCopyPairB{}))>;
using GmemTiledCopySFB = remove_cvref_t<decltype(get<1>(GmemTiledCopyPairB{}))>;

using SmemLayoutAtomA = remove_cvref_t<decltype(get<0>(SmemLayoutAtomPairA{}))>;
using SmemLayoutAtomSFA = remove_cvref_t<decltype(get<1>(SmemLayoutAtomPairA{}))>;
using SmemLayoutAtomB = remove_cvref_t<decltype(get<0>(SmemLayoutAtomPairB{}))>;
using SmemLayoutAtomSFB = remove_cvref_t<decltype(get<1>(SmemLayoutAtomPairB{}))>;

using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;

// Producer/consumer pipeline between the TMA load warp and the MMA warp.
using MainloopPipeline = cutlass::PipelineTmaUmmaAsync<
    DispatchPolicy::Stages,
    ClusterShape,
    AtomThrShapeMNK>;
using MainloopPipelineState = typename MainloopPipeline::PipelineState;

static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)");
static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0,
    "SmemLayoutAtom must evenly divide tile shape.");
static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0,
    "SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
    "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");

static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)");
static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0,
    "SmemLayoutAtom must evenly divide tile shape.");
static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0,
    "SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
    "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");

// Tile along K mode first before tiling over MN. PIPE mode last as usual.
// This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs.
// (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE)
using SmemLayoutA = decltype(UMMA::tile_to_mma_shape(
    SmemLayoutAtomA{},
    append(MmaShapeA_MK{}, Int<DispatchPolicy::Stages>{}),
    cute::conditional_t<cutlass::gemm::detail::is_mn_major<InternalStrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
// (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE)
using SmemLayoutB = decltype(UMMA::tile_to_mma_shape(
    SmemLayoutAtomB{},
    append(MmaShapeB_NK{}, Int<DispatchPolicy::Stages>{}),
    cute::conditional_t<cutlass::gemm::detail::is_mn_major<InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));

// SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here.
// The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop
using SmemLayoutSFA = decltype(make_layout(
    append(shape(SmemLayoutAtomSFA{}), Int<DispatchPolicy::Stages>{}),
    append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{})))
));
using SmemLayoutSFB = decltype(make_layout(
    append(shape(SmemLayoutAtomSFB{}), Int<DispatchPolicy::Stages>{}),
    append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{})))
));

static_assert(cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
    cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
    "MMA atom must source both A and B operand from smem_desc for this mainloop.");
// 1SM atoms must use SM90-style TMA loads; 2SM atoms must use the 2SM variants.
static_assert(
    (size(AtomThrShapeMNK{}) == 1 &&
     (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)) ||
    (size(AtomThrShapeMNK{}) == 2 &&
     (cute::is_same_v<GmemTiledCopyA, SM100_TMA_2SM_LOAD> || cute::is_same_v<GmemTiledCopyA, SM100_TMA_2SM_LOAD_MULTICAST>)),
    "GmemTiledCopy - invalid TMA copy atom specified.");
static_assert(
    (size(AtomThrShapeMNK{}) == 1 &&
     (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)) ||
    (size(AtomThrShapeMNK{}) == 2 &&
     (cute::is_same_v<GmemTiledCopyB, SM100_TMA_2SM_LOAD> || cute::is_same_v<GmemTiledCopyB, SM100_TMA_2SM_LOAD_MULTICAST>)),
    "GmemTiledCopy - invalid TMA copy atom specified.");

static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4<TiledMma, ElementA, ElementB>();
// Grouped GEMM is identified by per-group stride pointers (StrideA is a pointer type).
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<InternalStrideA, StrideA>;

using TmaInternalElementA = cute::conditional_t<IsF8F6F4, ElementAMma, ElementA>;
using TmaInternalElementB = cute::conditional_t<IsF8F6F4, ElementBMma, ElementB>;

// Sub-byte F8F6F4 operands are stored in smem as packed bytes.
using SmemAllocTypeA = cute::conditional_t<IsF8F6F4 && cute::sizeof_bits_v<ElementAMma> < 8, uint8_t, ElementAMma>;
using SmemAllocTypeB = cute::conditional_t<IsF8F6F4 && cute::sizeof_bits_v<ElementBMma> < 8, uint8_t, ElementBMma>;

using BitTypeElementA = uint_bit_t<cute::sizeof_bits_v<ElementA>>;
using BitTypeElementB = uint_bit_t<cute::sizeof_bits_v<ElementB>>;

// Runtime data types are exposed to the host as opaque bit patterns.
using ArrayElementA = cute::conditional_t<IsRuntimeDataTypeA, BitTypeElementA, ElementA>;
using ArrayElementB = cute::conditional_t<IsRuntimeDataTypeB, BitTypeElementB, ElementB>;

using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t<ElementAMma, IsRuntimeDataTypeA>::Type;
using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t<ElementBMma, IsRuntimeDataTypeB>::Type;

struct SharedStorage {
  // Staged smem buffers for operands and their scale factors.
  struct TensorStorage : cute::aligned_struct<128, _0> {
    cute::ArrayEngine<SmemAllocTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
    cute::ArrayEngine<SmemAllocTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
    cute::ArrayEngine<ElementSF, cute::cosize_v<SmemLayoutSFA>> smem_SFA;
    cute::ArrayEngine<ElementSF, cute::cosize_v<SmemLayoutSFB>> smem_SFB;
  } tensors;

  // Smem-resident TMA descriptors, modified on device when switching batches/groups.
  struct TensorMapStorage : cute::aligned_struct<128, _0> {
    cute::TmaDescriptor smem_tensormap_A;
    cute::TmaDescriptor smem_tensormap_B;
    cute::TmaDescriptor smem_tensormap_SFA;
    cute::TmaDescriptor smem_tensormap_SFB;
  } tensormaps;

  using PipelineStorage = typename MainloopPipeline::SharedStorage;
  PipelineStorage pipeline;
};

// Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them.
using TensorStorage = typename SharedStorage::TensorStorage;
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;

// TMA transaction byte counts for one pipeline stage (scale factors, then A/B operands).
static constexpr uint32_t SFTransactionBytes =
    cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v<ElementSF>) +
    cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v<ElementSF>);
// Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly
static constexpr uint32_t ABTmaTransactionBytes =
    cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v<ElementA>) +
    cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v<ElementB>);
static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes;
|
|
|
|
// Host side kernel arguments
|
|
// Host side kernel arguments
struct Arguments {
  // Array of per-batch/per-group pointers to operand A (and its strides below).
  ArrayElementA const** ptr_A{nullptr};
  StrideA dA{};
  // Array of per-batch/per-group pointers to operand B.
  ArrayElementB const** ptr_B{nullptr};
  StrideB dB{};
  // Per-batch/per-group scale factor pointers and their layouts.
  ElementSF const** ptr_SFA{nullptr};
  LayoutSFA layout_SFA{};
  ElementSF const** ptr_SFB{nullptr};
  LayoutSFB layout_SFB{};
  // Runtime operand data-type selectors (used only when IsRuntimeDataType).
  RuntimeDataTypeA runtime_data_type_a{};
  RuntimeDataTypeB runtime_data_type_b{};
};
|
|
|
|
// Device side kernel params
|
|
// Device side kernel params
struct Params {
  // Cluster layout divided by the MMA atom's CTA group, (V,M,N,K).
  // For dynamic clusters the shape is a runtime (0,0,1) placeholder.
  using ClusterLayout_VMNK =
    decltype(tiled_divide(make_layout(conditional_return<IsDynamicCluster>(make_shape(uint32_t(0), uint32_t(0), Int<1>{}),
                                                                           ClusterShape{})), make_tile(typename TiledMma::AtomThrID{})));
  // Same division, but using the scale-factor MMA's atom group (used for SFB).
  using ClusterLayoutSfb_VMNK =
    decltype(tiled_divide(make_layout(conditional_return<IsDynamicCluster>(make_shape(uint32_t(0), uint32_t(0), Int<1>{}),
                                                                           ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{})));

  // Prototype TMA atom types, deduced from nullptr-based tensors; real descriptors
  // are constructed in to_underlying_arguments().
  using TMA_A = decltype(make_tma_atom_A_sm100<TmaInternalElementA>(
      GmemTiledCopyA{},
      make_tensor(recast_ptr<TmaInternalElementA>(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}),
      SmemLayoutA{}(_,_,_,cute::Int<0>{}),
      TileShape{},
      TiledMma{},
      ClusterLayout_VMNK{})
    );

  using TMA_B = decltype(make_tma_atom_B_sm100<TmaInternalElementB>(
      GmemTiledCopyB{},
      make_tensor(recast_ptr<TmaInternalElementB>(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}),
      SmemLayoutB{}(_,_,_,cute::Int<0>{}),
      TileShape{},
      TiledMma{},
      ClusterLayout_VMNK{})
    );

  // Scale factors are transferred as uint16_t payloads by TMA.
  using TMA_SFA = decltype(make_tma_atom_A_sm100<uint16_t>(
      GmemTiledCopySFA{},
      make_tensor(static_cast<ElementSF const*>(nullptr), InternalLayoutSFA{}),
      SmemLayoutSFA{}(_,_,_,cute::Int<0>{}),
      TileShape{},
      TiledMma{},
      ClusterLayout_VMNK{})
    );

  using TMA_SFB = decltype(make_tma_atom_B_sm100<uint16_t>(
      GmemTiledCopySFB{},
      make_tensor(static_cast<ElementSF const*>(nullptr), InternalLayoutSFB{}),
      SmemLayoutSFB{}(_,_,_,cute::Int<0>{}),
      TileShape_SF{},
      TiledMMA_SF{},
      ClusterLayoutSfb_VMNK{})
    );

  // Primary TMA descriptors plus fallback-cluster variants (selected at runtime
  // when the preferred dynamic cluster shape is unavailable).
  TMA_A tma_load_a;
  TMA_B tma_load_b;
  TMA_SFA tma_load_sfa;
  TMA_SFB tma_load_sfb;
  TMA_A tma_load_a_fallback;
  TMA_B tma_load_b_fallback;
  TMA_SFA tma_load_sfa_fallback;
  TMA_SFB tma_load_sfb_fallback;
  dim3 cluster_shape_fallback;
  RuntimeDataTypeA runtime_data_type_a;
  RuntimeDataTypeB runtime_data_type_b;
  // Gmem workspace of per-SM tensormap copies (see get_workspace_size()).
  cute::TmaDescriptor* tensormaps;
  ArrayElementA const** ptr_A;
  StrideA dA;
  ArrayElementB const** ptr_B;
  StrideB dB;
  ElementSF const** ptr_SFA;
  LayoutSFA layout_SFA;
  ElementSF const** ptr_SFB;
  LayoutSFB layout_SFB;
};
|
|
|
|
/// Device-side constructor.
/// Selects which set of TMA descriptors this CTA observes: for dynamic clusters,
/// the fallback descriptors are used when the launched cluster shape matches the
/// fallback shape; otherwise (and for static clusters) the primary descriptors are used.
/// Fix: the original text contained mojibake `¶ms.` (HTML-entity corruption of
/// `&params.`), which did not compile; restored the address-of expressions.
CUTLASS_DEVICE
CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster)
  : cluster_shape_(cluster_shape)
  , block_rank_in_cluster_(block_rank_in_cluster) {
  if constexpr (IsDynamicCluster) {
    // Compare the runtime cluster shape against the fallback shape chosen at launch.
    const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x &&
                                      cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y);
    observed_tma_load_a_   = is_fallback_cluster ? &params.tma_load_a_fallback   : &params.tma_load_a;
    observed_tma_load_b_   = is_fallback_cluster ? &params.tma_load_b_fallback   : &params.tma_load_b;
    observed_tma_load_sfa_ = is_fallback_cluster ? &params.tma_load_sfa_fallback : &params.tma_load_sfa;
    observed_tma_load_sfb_ = is_fallback_cluster ? &params.tma_load_sfb_fallback : &params.tma_load_sfb;
  }
  else {
    observed_tma_load_a_   = &params.tma_load_a;
    observed_tma_load_b_   = &params.tma_load_b;
    observed_tma_load_sfa_ = &params.tma_load_sfa;
    observed_tma_load_sfb_ = &params.tma_load_sfb;
  }
}
|
|
|
|
template <class ProblemShape>
|
|
static constexpr Params
|
|
to_underlying_arguments(
|
|
ProblemShape problem_shapes,
|
|
Arguments const& args,
|
|
void* workspace,
|
|
cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) {
|
|
// These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc.
|
|
// These will be replaced with correct values before the initial tma load.
|
|
constexpr int tma_alignment_bits = 128;
|
|
auto init_M = tma_alignment_bits;
|
|
auto init_N = tma_alignment_bits;
|
|
auto init_K = tma_alignment_bits;
|
|
auto init_L = 1;
|
|
|
|
// Tensor pointers will be fixed before the first access
|
|
TmaInternalElementA const* ptr_A_first_batch = nullptr;
|
|
TmaInternalElementB const* ptr_B_first_batch = nullptr;
|
|
|
|
InternalStrideA stride_a;
|
|
InternalStrideB stride_b;
|
|
InternalLayoutSFA layout_SFA;
|
|
InternalLayoutSFB layout_SFB;
|
|
|
|
if constexpr (IsGroupedGemmKernel) {
|
|
// Strides for Grouped Gemm will be replaced prior to the first access regardless.
|
|
stride_a = InternalStrideA{};
|
|
stride_b = InternalStrideB{};
|
|
layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1));
|
|
layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1));
|
|
}
|
|
else {
|
|
// Tensor shapes for Ptr-Array are initialized correctly only here.
|
|
auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0);
|
|
init_M = get<0>(problem_shape_MNK);
|
|
init_N = get<1>(problem_shape_MNK);
|
|
init_K = get<2>(problem_shape_MNK);
|
|
|
|
stride_a = args.dA;
|
|
stride_b = args.dB;
|
|
layout_SFA = args.layout_SFA;
|
|
layout_SFB = args.layout_SFB;
|
|
}
|
|
|
|
// Batches/Groups are managed by using appropriate pointers to input matrices.
|
|
Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a));
|
|
Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b));
|
|
|
|
auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape);
|
|
// Cluster layout for TMA construction
|
|
auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{}));
|
|
auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback);
|
|
auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{}));
|
|
|
|
// Tensor pointers will be fixed before the first access
|
|
ElementSF const* ptr_SFA_first_batch = nullptr;
|
|
ElementSF const* ptr_SFB_first_batch = nullptr;
|
|
|
|
Tensor tensor_sfa = make_tensor(ptr_SFA_first_batch, layout_SFA);
|
|
Tensor tensor_sfb = make_tensor(ptr_SFB_first_batch, layout_SFB);
|
|
|
|
// Cluster layout for TMA construction of SFB
|
|
auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMMA_SF::AtomThrID{}));
|
|
auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{}));
|
|
|
|
typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100<TmaInternalElementA>(
|
|
GmemTiledCopyA{},
|
|
tensor_a,
|
|
SmemLayoutA{}(_,_,_,cute::Int<0>{}),
|
|
TileShape{},
|
|
TiledMma{},
|
|
cluster_layout_vmnk);
|
|
|
|
typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100<TmaInternalElementB>(
|
|
GmemTiledCopyB{},
|
|
tensor_b,
|
|
SmemLayoutB{}(_,_,_,cute::Int<0>{}),
|
|
TileShape{},
|
|
TiledMma{},
|
|
cluster_layout_vmnk);
|
|
|
|
typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100<TmaInternalElementA>(
|
|
GmemTiledCopyA{},
|
|
tensor_a,
|
|
SmemLayoutA{}(_,_,_,cute::Int<0>{}),
|
|
TileShape{},
|
|
TiledMma{},
|
|
cluster_layout_vmnk_fallback);
|
|
|
|
typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100<TmaInternalElementB>(
|
|
GmemTiledCopyB{},
|
|
tensor_b,
|
|
SmemLayoutB{}(_,_,_,cute::Int<0>{}),
|
|
TileShape{},
|
|
TiledMma{},
|
|
cluster_layout_vmnk_fallback);
|
|
|
|
typename Params::TMA_SFA tma_load_sfa = make_tma_atom_A_sm100<uint16_t>(
|
|
GmemTiledCopySFA{},
|
|
tensor_sfa,
|
|
SmemLayoutSFA{}(_,_,_,cute::Int<0>{}),
|
|
TileShape{},
|
|
TiledMma{},
|
|
cluster_layout_vmnk);
|
|
|
|
typename Params::TMA_SFB tma_load_sfb = make_tma_atom_B_sm100<uint16_t>(
|
|
GmemTiledCopySFB{},
|
|
tensor_sfb,
|
|
SmemLayoutSFB{}(_,_,_,cute::Int<0>{}),
|
|
TileShape_SF{},
|
|
TiledMMA_SF{},
|
|
cluster_layout_sfb_vmnk);
|
|
|
|
typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom_A_sm100<uint16_t>(
|
|
GmemTiledCopySFA{},
|
|
tensor_sfa,
|
|
SmemLayoutSFA{}(_,_,_,cute::Int<0>{}),
|
|
TileShape{},
|
|
TiledMma{},
|
|
cluster_layout_vmnk_fallback);
|
|
|
|
typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom_B_sm100<uint16_t>(
|
|
GmemTiledCopySFB{},
|
|
tensor_sfb,
|
|
SmemLayoutSFB{}(_,_,_,cute::Int<0>{}),
|
|
TileShape_SF{},
|
|
TiledMMA_SF{},
|
|
cluster_layout_sfb_vmnk_fallback);
|
|
|
|
return {
|
|
tma_load_a,
|
|
tma_load_b,
|
|
tma_load_sfa,
|
|
tma_load_sfb,
|
|
tma_load_a_fallback,
|
|
tma_load_b_fallback,
|
|
tma_load_sfa_fallback,
|
|
tma_load_sfb_fallback,
|
|
hw_info.cluster_shape_fallback,
|
|
args.runtime_data_type_a,
|
|
args.runtime_data_type_b,
|
|
reinterpret_cast<cute::TmaDescriptor*>(workspace),
|
|
reinterpret_cast<ArrayElementA const**>(args.ptr_A),
|
|
args.dA,
|
|
reinterpret_cast<ArrayElementB const**>(args.ptr_B),
|
|
args.dB,
|
|
reinterpret_cast<ElementSF const**>(args.ptr_SFA),
|
|
args.layout_SFA,
|
|
reinterpret_cast<ElementSF const**>(args.ptr_SFB),
|
|
args.layout_SFB,
|
|
};
|
|
}
|
|
|
|
/// Returns the gmem workspace size in bytes: one tensormap (TMA descriptor) per
/// input tensor (A, B, SFA, SFB) per SM.
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) {
  // Four input tensormaps per SM: A, B, SFA, SFB (matches TensorMapStorage).
  constexpr uint32_t NumInputTensors = 4;
  constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor);
  // Allocate gmem space for input tensormaps per each SM: A copies, then B, then SFA, then SFB
  return (NumInputTensors * SizeOfCuTensorMap * sm_count);
}
|
|
|
|
/// No host-side workspace initialization is required: the per-SM tensormap copies
/// are written on device before first use. Always returns kSuccess.
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) {
  return cutlass::Status::kSuccess;
}
|
|
|
|
template <class ProblemShape>
|
|
static bool
|
|
can_implement(
|
|
ProblemShape problem_shapes,
|
|
[[maybe_unused]] Arguments const& args) {
|
|
constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits<ElementA, IsF8F6F4>();
|
|
constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits<ElementB, IsF8F6F4>();
|
|
constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits<ElementA>::value;
|
|
constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits<ElementB>::value;
|
|
|
|
bool implementable = true;
|
|
if (problem_shapes.is_host_problem_shape_available()) {
|
|
// Check alignment for all problem sizes
|
|
for (int i = 0; i < problem_shapes.groups(); i++) {
|
|
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
|
|
auto [M,N,K,L] = problem_shape_MNKL;
|
|
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), InternalStrideA{});
|
|
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), InternalStrideB{});
|
|
}
|
|
}
|
|
|
|
if (!implementable) {
|
|
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
|
}
|
|
return implementable;
|
|
}
|
|
|
|
/// Construct A Single Stage's Accumulator Shape
|
|
CUTLASS_DEVICE auto
|
|
partition_accumulator_shape() {
|
|
auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N)
|
|
|
|
return acc_shape;
|
|
}
|
|
|
|
|
|
template <class FrgEngine, class FrgLayout>
|
|
CUTLASS_DEVICE auto
|
|
slice_accumulator(cute::Tensor<FrgEngine, FrgLayout> const& accumulators, int stage) {
|
|
return accumulators(_,_,_,stage);
|
|
}
|
|
|
|
/// Set up the data needed by this collective for load.
|
|
/// Return tuple element contain
|
|
/// gA_mkl - The tiled tma tensor for input A
|
|
/// gB_nkl - The tiled tma tensor for input B
|
|
/// tAgA_mkl - partitioned gmem tensor for A
|
|
/// tBgB_nkl - partitioned gmem tensor for B
|
|
/// tAsA - partitioned smem tensor for A
|
|
/// tBsB - partitioned smem tensor for B
|
|
/// tAgSFA_mkl - partitioned gmem tensor for SFA
|
|
/// tBgSFB_nkl - partitioned gmem tensor for SFB
|
|
/// tAsSFA - partitioned tmem tensor for SFA
|
|
/// tAsSFB - partitioned tmem tensor for SFB
|
|
/// mcast_mask_a - tma multicast mask for A
|
|
/// mcast_mask_b - tma multicast mask for B
|
|
/// mcast_mask_sfa - tma multicast mask for SFA
|
|
/// mcast_mask_sfb - tma multicast mask for SFB
|
|
template <class ProblemShape_MNKL>
|
|
CUTLASS_DEVICE auto
|
|
load_init(
|
|
ProblemShape_MNKL const& problem_shape_MNKL,
|
|
Params const& params,
|
|
TensorStorage& shared_tensors,
|
|
TensorMapStorage& shared_tensormaps,
|
|
int32_t const sm_count, int32_t const sm_idx,
|
|
int32_t init_group) const {
|
|
using X = Underscore;
|
|
|
|
// Separate out problem shape for convenience
|
|
auto [M,N,K,L] = problem_shape_MNKL;
|
|
// Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads
|
|
// we are managing TMA descriptors to change batches, we need to neglect the L mode
|
|
const int32_t mock_L = 1;
|
|
|
|
// Represent the full tensors -- get these from TMA
|
|
Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,mock_L));
|
|
Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,mock_L));
|
|
|
|
// Tile the tensors and defer the slice
|
|
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l)
|
|
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l)
|
|
|
|
// Represent the full tensor of Scale factors
|
|
InternalLayoutSFA layout_SFA{};
|
|
InternalLayoutSFB layout_SFB{};
|
|
if constexpr (IsGroupedGemmKernel) {
|
|
layout_SFA = params.layout_SFA[init_group];
|
|
layout_SFB = params.layout_SFB[init_group];
|
|
}
|
|
else {
|
|
layout_SFA = params.layout_SFA;
|
|
layout_SFB = params.layout_SFB;
|
|
}
|
|
Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA));
|
|
auto mSFB_nkl = [=](){
|
|
if constexpr (IsCtaN192) {
|
|
Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB));
|
|
auto x = stride<0,1>(mSFB_tmp);
|
|
auto y = ceil_div(shape<0,1>(mSFB_tmp), 4);
|
|
auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp),
|
|
make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp));
|
|
auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp),
|
|
make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp));
|
|
return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride));
|
|
}
|
|
else {
|
|
return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB));
|
|
}
|
|
}();
|
|
|
|
Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l)
|
|
Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l)
|
|
|
|
|
|
// Partition for this CTA
|
|
ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{}));
|
|
|
|
Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l)
|
|
Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l)
|
|
|
|
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE)
|
|
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE)
|
|
|
|
ThrMMA cta_mma_sfb = TiledMMA_SF{}.get_slice(blockIdx.x % size(typename TiledMMA_SF::AtomThrID{}));
|
|
Tensor tCgSFA_mkl = cta_mma.partition_A(gSFA_mkl); // (MMA, MMA_M, MMA_K, m, k, l)
|
|
Tensor tCgSFB_nkl = cta_mma_sfb.partition_B(gSFB_nkl); // (MMA, MMA_N, MMA_K, n, k, l)
|
|
|
|
Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{});
|
|
Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{});
|
|
|
|
// Define the CTA-in-Cluster Layout and Coord
|
|
Layout cta_layout_mnk = make_layout(cluster_shape_);
|
|
Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{}));
|
|
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_);
|
|
|
|
Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{}));
|
|
auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster_);
|
|
|
|
// Project the cta_layout for tma_a along the n-modes
|
|
auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_,
|
|
get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
|
|
group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl));
|
|
|
|
// Project the cta_layout for tma_b along the m-modes
|
|
auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_,
|
|
get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
|
|
group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl));
|
|
|
|
// Project the cta_layout for tma_a along the n-modes
|
|
auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_,
|
|
get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
|
|
group_modes<0,3>(sSFA), group_modes<0,3>(tCgSFA_mkl));
|
|
|
|
// Project the cta_layout for tma_b along the m-modes
|
|
auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_,
|
|
get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)),
|
|
group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl));
|
|
|
|
// TMA Multicast Masks
|
|
uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
|
|
uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
|
|
uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
|
|
uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk);
|
|
|
|
// Fetch a copy of tensormaps for the CTA from Params
|
|
auto input_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx);
|
|
|
|
return cute::make_tuple(
|
|
gA_mkl, gB_nkl, // for scheduler
|
|
tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values
|
|
tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values
|
|
mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb, // multicast masks
|
|
input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy)
|
|
}
|
|
|
|
/// Set up the data needed by this collective for mma compute.
///
/// Builds SMEM-backed "fragments"/descriptors for A and B, positions the
/// TMEM-resident scale-factor tensors (SFA/SFB) starting at `tmem_offset`,
/// and constructs the UTCCP (SMEM -> TMEM) copy operations used to stage the
/// scale factors before each MMA. All state needed by `mma()` is returned as
/// a single tuple so it can be passed through the kernel unchanged.
///
/// @param params         Collective mainloop parameters (runtime data types).
/// @param accumulators   Unused here; kept for a uniform collective interface.
/// @param shared_tensors SMEM storage for A/B tiles and scale factors.
/// @param tmem_offset    TMEM column offset at which SFA begins.
/// @return tuple of (tiled_mma, A/B fragments, SFA/SFB TMEM tensors,
///         SFA s2t copy + partitioned src/dst, SFB s2t copy + partitioned src/dst).
template <class FrgEngine, class FrgLayout>
CUTLASS_DEVICE auto
mma_init(
    Params const& params,
    [[maybe_unused]] cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
    TensorStorage& shared_tensors,
    uint32_t const tmem_offset) const {

  // Allocate "fragments/descriptors" for A and B matrices
  Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{});    // (BLK_M,BLK_K,PIPE)
  Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{});    // (BLK_N,BLK_K,PIPE)

  // Allocate "fragments/descriptors" for A and B matrices
  Tensor tCrA = TiledMma::make_fragment_A(sA);                                             // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCrB = TiledMma::make_fragment_B(sB);                                             // (MMA,MMA_N,MMA_K,PIPE)

  // The last SMEM mode must match the pipeline's stage count.
  CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sA));                      // PIPE
  CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sB));                      // PIPE

  //
  // Scale Factor
  //
  Tensor tCtSFA = make_tensor<typename TiledMma::FrgTypeSFA>(shape(SmemLayoutAtomSFA{}));
  // Set tCtSFA and tCtSFB start addresses. Only update the TMEM column address by masking the address with 0x000001FF.
  // TMEM allocations for SFA and SFB will always start at DP 0.
  tCtSFA.data() = tmem_offset;
  Tensor tCtSFB = make_tensor<typename TiledMma::FrgTypeSFB>(shape(SmemLayoutAtomSFB{}));
  // SFB is placed immediately after SFA's columns in TMEM.
  tCtSFB.data() = tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tCtSFA);

  // Setup smem descriptors for UTCCP
  Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{});
  Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{});

  // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions.
  auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout()));
  auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout()));
  auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout()));
  auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout()));

  // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA)
  using AtomThrID = typename TiledMma::AtomThrID;
  using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value),
      SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>;
  auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact);
  auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact);

  auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0);
  auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact);
  // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor
  auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor<UtccpOp>(thr_tCsSFA_compact_s2t_);
  auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact);

  auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0);
  auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact);
  // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor
  auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor<UtccpOp>(thr_tCsSFB_compact_s2t_);
  auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact);

  TiledMma tiled_mma;

  if constexpr (IsRuntimeDataType) {
    // Patch the MMA instruction descriptor with the runtime-selected A/B
    // formats; only the low 3 bits of the enum encode the format.
    tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111;
    tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111;
  }

  return cute::make_tuple(
    tiled_mma,
    tCrA, tCrB, tCtSFA, tCtSFB,
    tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t,
    tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t);
}
|
|
|
|
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
///
/// Issues TMA loads for A, B, SFA, and SFB tiles over `k_tile_count` k-tiles,
/// acquiring one mainloop pipeline stage per iteration. A single elected lane
/// issues the TMA copies; completion is signaled through the stage's barrier.
///
/// @param load_inputs       Tuple produced by load_init: scheduler tensors,
///                          TMA-partitioned gmem/smem tensors, multicast masks,
///                          and the per-CTA tensormap pointers.
/// @param cta_coord_mnkl    This CTA's tile coordinate (m, n, k, l).
/// @param k_tile_iter       Iterator over k-tiles; advanced as tiles are loaded.
/// @param k_tile_count      Number of k-tiles to load.
/// @param did_batch_change  True when the tensormaps were updated for a new
///                          batch/group and must be fence-acquired before use.
/// @return Updated (producer pipeline state, k-tile iterator).
template <
  class GTensorA, class GTensorB,
  class GTensorPartitionedA, class GTensorPartitionedB,
  class STensorA, class STensorB,
  class GTensorPartitionedSFA, class GTensorPartitionedSFB,
  class STensorSFA, class STensorSFB,
  class TensorMapA, class TensorMapB,
  class TensorMapSFA, class TensorMapSFB,
  class TileCoordMNKL,
  class KTileIterator
>
CUTLASS_DEVICE auto
load(
    Params const& params,
    MainloopPipeline mainloop_pipeline,
    MainloopPipelineState mainloop_pipe_producer_state,
    cute::tuple<GTensorA, GTensorB,
                GTensorPartitionedA, GTensorPartitionedB,
                STensorA, STensorB,
                GTensorPartitionedSFA, GTensorPartitionedSFB,
                STensorSFA, STensorSFB,
                uint16_t, uint16_t,
                uint16_t, uint16_t,
                cute::tuple<TensorMapA, TensorMapB, TensorMapSFA, TensorMapSFB>> const& load_inputs,
    TileCoordMNKL const& cta_coord_mnkl,
    KTileIterator k_tile_iter, int k_tile_count,
    bool did_batch_change) {

  auto [unused_gA, unused_gB,
        tAgA_mkl, tBgB_nkl, tAsA, tBsB,
        tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB,
        mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb,
        input_tensormaps] = load_inputs;

  // Check to see if tensormaps have been replaced in gmem
  if (did_batch_change) {
    tensormaps_fence_acquire(input_tensormaps);
  }

  // slice out the work coord from partitioned tensors
  // (the m coordinate is divided by AtomThrID size since a 2-CTA MMA atom
  // covers the m tile with two CTAs)
  Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl));
  Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl));
  Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl));
  Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl));

  // Non-blocking peek at the first stage so the wait inside producer_acquire
  // can be skipped when the stage is already free.
  auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state);

  // Issue the Mainloop loads
  CUTLASS_PRAGMA_NO_UNROLL
  while (k_tile_count > 0) {
    // LOCK mainloop_pipe_producer_state for _writing_
    mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);
    // Note: We don't synchronize the sf_pipeline for "Buffer_Empty". We use mainloop pipeline
    // to do the synchronization at once.

    using BarrierType = typename MainloopPipeline::ProducerBarrierType;
    BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);

    int write_stage = mainloop_pipe_producer_state.index();
    // Advance to the next stage and peek its availability for the next iteration.
    ++mainloop_pipe_producer_state;
    barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state);

    // A single elected lane issues the TMA loads for this stage; all four
    // copies complete on the same stage barrier.
    if (cute::elect_one_sync()) {
      copy(observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage));
      copy(observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage));
      copy(observed_tma_load_sfa_->with(get<2>(input_tensormaps), *tma_barrier, mcast_mask_sfa), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage));
      copy(observed_tma_load_sfb_->with(get<3>(input_tensormaps), *tma_barrier, mcast_mask_sfb), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage));
    }

    --k_tile_count;
    ++k_tile_iter;
  }

  return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter);
}
|
|
|
|
/// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster
///
/// @param mainloop_pipeline            The mainloop producer/consumer pipeline.
/// @param mainloop_pipe_producer_state Producer state after the last issued load.
CUTLASS_DEVICE void
load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) {
  // Issue the epilogue waits
  // This helps avoid early exit of ctas in Cluster
  // Waits for all stages to either be released (all
  // Consumer UNLOCKs), or if the stage was never used
  // then would just be acquired since the phase was
  // still inverted from make_producer_start_state
  mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
}
|
|
|
|
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
///
/// Waits on each mainloop pipeline stage, copies the stage's scale factors
/// from SMEM to TMEM (UTCCP), and issues the block-scaled MMAs for every
/// k-block in the stage. The first k-tile is peeled out of the main loop so
/// that, when `IsOverlappingAccum` is set, the accumulator stage can be
/// acquired after the first stage's data is already in flight.
///
/// @param pipelines       (mainloop pipeline, accumulator pipeline).
/// @param pipeline_states (mainloop consumer state, accumulator producer state).
/// @param accumulators    TMEM-resident, MMA-partitioned accumulator tensor.
/// @param mma_inputs      Tuple produced by mma_init.
/// @param cta_tile_coord  This CTA's tile coordinate; used for the N=192 SFB shift.
/// @param k_tile_count    Number of k-tiles to consume.
/// @return Updated mainloop consumer pipeline state.
template <
  class AccumulatorPipeline,
  class FrgEngine, class FrgLayout,
  class FragmentA, class FragmentB,
  class FragmentSFA, class FragmentSFB,
  class CtaTileCoord,
  class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA,
  class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB
>
CUTLASS_DEVICE auto
mma(cute::tuple<MainloopPipeline,
                AccumulatorPipeline> pipelines,
    cute::tuple<MainloopPipelineState,
                typename AccumulatorPipeline::PipelineState> pipeline_states,
    cute::Tensor<FrgEngine, FrgLayout>& accumulators,
    cute::tuple<TiledMma,
                FragmentA, FragmentB,
                FragmentSFA, FragmentSFB,
                SFATiledCopy, SmemFrgSFA, TmemFrgSFA,
                SFBTiledCopy, SmemFrgSFB, TmemFrgSFB> const& mma_inputs,
    CtaTileCoord cta_tile_coord,
    int k_tile_count
) {
  static_assert(is_tmem<FrgEngine>::value, "Accumulator must be tmem resident.");
  static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)");

  auto [tiled_mma,
        tCrA, tCrB, tCtSFA, tCtSFB,
        tiled_copy_s2t_SFA, thr_tCsSFA_s2t,
        thr_tCtSFA_s2t, tiled_copy_s2t_SFB,
        thr_tCsSFB_s2t, thr_tCtSFB_s2t] = mma_inputs;

  auto [mainloop_pipeline, accumulator_pipeline] = pipelines;
  auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states;

  auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() {
    if constexpr (IsCtaN192) {
      // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB)
      auto tCtSFB_tmp = tCtSFB;
      if (get<1>(cta_tile_coord) % 2 == 1) {
        tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2;
      }
      return tCtSFB_tmp;
    }
    else {
      return tCtSFB;
    }
  }();

  uint32_t skip_wait = k_tile_count <= 0;
  auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);

  //
  // PIPELINED MAIN LOOP
  //
  // First MMA of the tile zero-initializes the accumulator.
  tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;

  if (k_tile_count > 0) { // first iteration (peeled)
    // WAIT on mainloop_pipe_consumer_state until its data are available
    // (phase bit flips from mainloop_pipe_consumer_state.phase() value)
    mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);

    // Compute on k_tile
    int read_stage = mainloop_pipe_consumer_state.index();
    // Save current mainloop pipeline read state
    auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;

    // Advance mainloop_pipe
    ++mainloop_pipe_consumer_state;
    --k_tile_count;
    skip_wait = k_tile_count <= 0;
    // Peek at next iteration
    barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);

    // One elected lane stages this k-tile's scale factors from SMEM to TMEM.
    if (cute::elect_one_sync()) {
      copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t);
      copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t);
    }

    if constexpr (IsOverlappingAccum) {
      // Acquire the accumulator stage only now, after the first stage's
      // SF copies have been issued.
      accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
    }

    // Unroll the K mode manually so we can set scale C to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // (V,M) x (V,N) => (V,M,N)
      cute::gemm(tiled_mma.with(tiled_mma.accumulate_,
                                tCtSFA(_,_,k_block),
                                tCtSFB_mma(_,_,k_block)),
                 tCrA(_,_,k_block,read_stage),
                 tCrB(_,_,k_block,read_stage),
                 accumulators);
      // All subsequent MMAs accumulate onto the existing partial sums.
      tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
    mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
  }

  CUTLASS_PRAGMA_NO_UNROLL
  while (k_tile_count > 0) {
    // WAIT on mainloop_pipe_consumer_state until its data are available
    // (phase bit flips from mainloop_pipe_consumer_state.phase() value)
    mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);

    // Compute on k_tile
    int read_stage = mainloop_pipe_consumer_state.index();
    // Save current mainloop pipeline read state
    auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;

    // Advance mainloop_pipe
    ++mainloop_pipe_consumer_state;
    --k_tile_count;
    skip_wait = k_tile_count <= 0;
    // Peek at next iteration
    barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);

    // One elected lane stages this k-tile's scale factors from SMEM to TMEM.
    if (cute::elect_one_sync()) {
      copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t);
      copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t);
    }

    // Unroll the K mode manually so we can set scale C to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // (V,M) x (V,N) => (V,M,N)
      cute::gemm(tiled_mma.with(tiled_mma.accumulate_,
                                tCtSFA(_,_,k_block),
                                tCtSFB_mma(_,_,k_block)),
                 tCrA(_,_,k_block,read_stage),
                 tCrB(_,_,k_block,read_stage),
                 accumulators);
      tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
    mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
  }

  return mainloop_pipe_consumer_state;
}
|
|
|
|
//
|
|
// Methods to perform different parts of TMA/Tensormap modifications
|
|
//
|
|
|
|
/// Initialize this SM's per-CTA tensormap copies.
///
/// Each SM owns four descriptor slots in the global `tensormaps` array
/// (A, B, SFA, SFB at offsets 0/1/2/3 * sm_count), and one elected lane
/// seeds the SMEM staging copies from the canonical descriptors in Params
/// so they can be modified per batch/group later.
///
/// @param mainloop_params    Mainloop params holding the gmem tensormap array.
/// @param shared_tensormaps  SMEM staging area for descriptor modification.
/// @param sm_count           Total number of SMs (stride between A/B/SFA/SFB regions).
/// @param sm_idx             This SM's index into each region.
/// @return Tuple of gmem descriptor pointers (A, B, SFA, SFB) for this SM.
CUTLASS_DEVICE auto
tensormaps_init(
    Params const& mainloop_params,
    TensorMapStorage& shared_tensormaps,
    int32_t const sm_count,
    int32_t const sm_idx) const {
  cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps;

  cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx];
  cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count];

  cute::TmaDescriptor* tma_desc_sfa = &gmem_tensormap[sm_idx + 2 * sm_count];
  cute::TmaDescriptor* tma_desc_sfb = &gmem_tensormap[sm_idx + 3 * sm_count];

  if (cute::elect_one_sync()) {
    // Bringing tensormaps from params to smem for modification later
    Tensor pA_tensormap = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{});
    Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{});
    Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{});
    Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{});

    Tensor pSFA_tensormap = make_tensor(observed_tma_load_sfa_->get_tma_descriptor(), Int<1>{}, Int<1>{});
    Tensor sSFA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFA), Int<1>{}, Int<1>{});
    Tensor pSFB_tensormap = make_tensor(observed_tma_load_sfb_->get_tma_descriptor(), Int<1>{}, Int<1>{});
    Tensor sSFB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFB), Int<1>{}, Int<1>{});

    // 128-bit vectorized copies of the descriptors into SMEM.
    copy(recast<uint128_t>(pA_tensormap), recast<uint128_t>(sA_tensormap));
    copy(recast<uint128_t>(pB_tensormap), recast<uint128_t>(sB_tensormap));

    copy(recast<uint128_t>(pSFA_tensormap), recast<uint128_t>(sSFA_tensormap));
    copy(recast<uint128_t>(pSFB_tensormap), recast<uint128_t>(sSFB_tensormap));
  }
  // Make the SMEM descriptors visible to the whole warp before returning.
  __syncwarp();

  return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_sfa, tma_desc_sfb);
}
|
|
|
|
// Replace address for the global tensor (to be done by single thread)
//
// Points the SMEM-staged descriptors for A, B, SFA, and SFB at the base
// pointers of batch `next_batch`. Caller is responsible for the subsequent
// fence/release sequence that publishes the modified descriptors.
CUTLASS_DEVICE
void
tensormaps_replace_global_address(
    TensorMapStorage& shared_tensormaps,
    Params const& mainloop_params,
    int32_t next_batch) {
  // Replacing global_address for the next batch
  cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A,
                                                  mainloop_params.ptr_A[next_batch]);
  cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B,
                                                  mainloop_params.ptr_B[next_batch]);

  cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFA,
                                                  mainloop_params.ptr_SFA[next_batch]);
  cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFB,
                                                  mainloop_params.ptr_SFB[next_batch]);
}
|
|
|
|
// Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread)
//
// Recomputes the gmem shapes/strides of A, B, SFA, and SFB for group
// `next_group` and writes them into the SMEM-staged descriptors. The tensors
// built here use null data pointers: only their shapes/strides are consumed
// by fill_tma_gmem_shape_stride.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE
void
tensormaps_replace_global_tensor_properties(
    TensorMapStorage& shared_tensormaps,
    Params const& mainloop_params,
    int32_t next_group,
    ProblemShape_MNKL problem_shape_mnkl) {
  const uint32_t M = get<0>(problem_shape_mnkl);
  const uint32_t N = get<1>(problem_shape_mnkl);
  const uint32_t K = get<2>(problem_shape_mnkl);
  // Replace all dims for consistency
  // (unused trailing dims stay shape=1 / stride=0)
  constexpr int MaxTensorRank = 5;
  cute::array<uint32_t, MaxTensorRank> prob_shape_A  = {1,1,1,1,1};
  cute::array<uint64_t, MaxTensorRank> prob_stride_A = {0,0,0,0,0};
  cute::array<uint32_t, MaxTensorRank> prob_shape_SFA  = {1,1,1,1,1};
  cute::array<uint64_t, MaxTensorRank> prob_stride_SFA = {0,0,0,0,0};
  cute::array<uint32_t, MaxTensorRank> prob_shape_B  = {1,1,1,1,1};
  cute::array<uint64_t, MaxTensorRank> prob_stride_B = {0,0,0,0,0};
  cute::array<uint32_t, MaxTensorRank> prob_shape_SFB  = {1,1,1,1,1};
  cute::array<uint64_t, MaxTensorRank> prob_stride_SFB = {0,0,0,0,0};

  // Shape-only tensor: the data pointer is never dereferenced.
  TmaInternalElementA const* ptr_A = nullptr;
  Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]);

  ElementSF const* ptr_SF = nullptr;
  Tensor tensor_sfa = make_tensor(ptr_SF, mainloop_params.layout_SFA[next_group]);

  TmaInternalElementB const* ptr_B = nullptr;
  Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]);

  Tensor tensor_sfb = make_tensor(ptr_SF, mainloop_params.layout_SFB[next_group]);

  cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a,
                                           prob_shape_A, prob_stride_A);
  cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfa_, tensor_sfa,
                                           prob_shape_SFA, prob_stride_SFA);
  cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b,
                                           prob_shape_B, prob_stride_B);
  cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfb_, tensor_sfb,
                                           prob_shape_SFB, prob_stride_SFB);

  // Convert strides to byte strides
  // (element strides scaled by the element's bit width; sub-byte types are
  // handled by multiplying in bits first, then dividing by 8)
  for (uint64_t& stride : prob_stride_A) {
    stride = (stride * sizeof_bits_v<TmaInternalElementA>) / 8;
  }
  for (uint64_t& stride : prob_stride_SFA) {
    stride = (stride * sizeof_bits_v<ElementSF>) / 8;
  }
  for (uint64_t& stride : prob_stride_B) {
    stride = (stride * sizeof_bits_v<TmaInternalElementB>) / 8;
  }
  for (uint64_t& stride : prob_stride_SFB) {
    stride = (stride * sizeof_bits_v<ElementSF>) / 8;
  }

  cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A,
                                                          prob_shape_A,
                                                          prob_stride_A);
  cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFA,
                                                          prob_shape_SFA,
                                                          prob_stride_SFA);
  cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B,
                                                          prob_shape_B,
                                                          prob_stride_B);
  cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFB,
                                                          prob_shape_SFB,
                                                          prob_stride_SFB);
}
|
|
|
|
// The entire warp must call this function collectively (that is, the instructions are aligned)
//
// Updates the SMEM-staged tensormaps for `next_batch` (address always; dims
// and strides too for grouped GEMM) on one elected lane, then has the whole
// warp copy/fence-release the descriptors back to gmem.
template <class TensorMapA, class TensorMapB, class TensorMapSFA, class TensorMapSFB, class ProblemShape>
CUTLASS_DEVICE
void
tensormaps_perform_update(
    TensorMapStorage& shared_tensormaps,
    Params const& mainloop_params,
    cute::tuple<TensorMapA, TensorMapB, TensorMapSFA, TensorMapSFB> const& input_tensormaps,
    ProblemShape problem_shape,
    int32_t next_batch) {
  if (cute::elect_one_sync()) {
    // Replacing global_address for the next batch
    tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch);

    if constexpr (IsGroupedGemmKernel) {
      // Append L=1 since grouped GEMM problem shapes are rank-3 (MNK).
      auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1);
      // Replacing global dims and strides for the next batch
      tensormaps_replace_global_tensor_properties(shared_tensormaps,
                                                  mainloop_params, next_batch, problem_shape_MNKL);
    }
  }
  // Ensure warp is converged before issuing tensormap fence release
  __syncwarp();
  // Entire warp must do this (ie its aligned)
  tensormaps_cp_fence_release(shared_tensormaps, input_tensormaps);
}
|
|
|
|
// Copy the modified SMEM tensormaps back to their gmem slots and release them.
// One elected lane commits/waits the descriptor update group; the copy-with-
// fence-release itself is a warp-aligned operation.
template <class TensorMapA, class TensorMapB, class TensorMapSFA, class TensorMapSFB>
CUTLASS_DEVICE
void
tensormaps_cp_fence_release (
    TensorMapStorage& shared_tensormaps,
    cute::tuple<TensorMapA, TensorMapB, TensorMapSFA, TensorMapSFB> const& input_tensormaps) {
  if (cute::elect_one_sync()) {
    cute::tma_desc_commit_group();
    cute::tma_desc_wait_group();
  }
  // Entire warp must do this (i.e. it's aligned)
  tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A);
  tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B);

  tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_SFA);
  tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_SFB);
}
|
|
|
|
// The entire warp must call this function collectively (that is, the instructions are aligned)
//
// Acquire-fences the four gmem tensormaps (A, B, SFA, SFB) so that TMA loads
// issued afterwards observe the descriptors published by the releasing warp.
template <class TensorMapA, class TensorMapB, class TensorMapSFA, class TensorMapSFB>
CUTLASS_DEVICE
void
tensormaps_fence_acquire(cute::tuple<TensorMapA, TensorMapB, TensorMapSFA, TensorMapSFB> const& input_tensormaps) {
  cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps));
  cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps));
  cute::tma_descriptor_fence_acquire(get<2>(input_tensormaps));
  cute::tma_descriptor_fence_acquire(get<3>(input_tensormaps));
}
|
|
|
|
private:
  // Non-owning pointers to the canonical TMA load descriptors held in Params;
  // used both to issue loads and as the source when staging per-CTA tensormap
  // copies for modification.
  typename Params::TMA_A const* observed_tma_load_a_{nullptr};
  typename Params::TMA_B const* observed_tma_load_b_{nullptr};
  typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr};
  typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr};

  // Cluster shape and this CTA's linear rank within the cluster; used to build
  // the CTA-in-cluster layouts and multicast masks.
  ClusterShape cluster_shape_;
  uint32_t block_rank_in_cluster_;
};
|
|
};
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace cutlass::gemm::collective
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|