diff --git a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp index 2161910659..da8b84107e 100644 --- a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp +++ b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp @@ -124,11 +124,11 @@ struct GemmMXKernel using BlockGemm = remove_cvref_t; using MThreadPerXdl = BlockGemm::WarpGemm::kM; using NThreadPerXdl = BlockGemm::WarpGemm::kN; - using KThreadPerXdl = 64 / MThreadPerXdl; // 64 is the number of threads in a wave + using KThreadPerXdl = get_warp_size() / MThreadPerXdl; // 64 is the number of threads in a wave - static constexpr auto MXdlPack = 2; - static constexpr auto NXdlPack = 2; - static constexpr auto KXdlPack = 2; + using MXdlPack = remove_cvref_t; + using NXdlPack = remove_cvref_t; + using KXdlPack = remove_cvref_t; using mx_scale_t = ck_tile::e8m0_bexp_t; static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); @@ -435,16 +435,6 @@ struct GemmMXKernel return make_tensor_view(a_scale_ptr, a_m_k_desc); }(); - // const auto& aq_tensor_view = [&]() { - // static_assert(std::is_same_v); - // return make_naive_tensor_view( - // aq_ptr, - // make_tuple(kargs.M, kargs.QK), - // make_tuple(kargs.stride_AQ, 1), - // number{}, - // number<1>{}); - // }(); - const auto& b_tensor_view = [&]() { if constexpr(std::is_same_v) { diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp index a3dd430faa..316e52b860 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp @@ -19,34 +19,60 @@ struct GemmMXPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase; + using AScaleLayout = remove_cvref_t; + using BScaleLayout = remove_cvref_t; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize; - static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize; + static constexpr index_t BlockScaleSize = Problem::kBlockScaleSize; - static_assert(KPerBlock % QuantGroupSize == 0, - "KPerBlock must be a multiple of QuantGroupSize"); + static_assert(KPerBlock * % BlockScaleSize == 0, + "KPerBlock must be a multiple of BlockScaleSize"); - // Create DRAM tile window for AQ - template + static constexpr auto MXdlPack = 2; + static constexpr auto NXdlPack = 2; + static constexpr auto KXdlPack = 2; + + // Create DRAM tile window for A scale + template CK_TILE_DEVICE constexpr auto - GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const + GetAScaleDramLoadWindow(const AScaleDramBlockWindowTmp& a_scale_dram_block_window_tmp) const { - static_assert(std::is_same_v); + static_assert( + std::is_same_v); + using YPerTile = number; + using XPerTile = number; - using YPerTile = number; - using XPerTile = number; - - auto aq_copy_dram_window = - make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(), + auto a_copy_draw_window = + make_tile_window(a_scale_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(YPerTile(), XPerTile()), - aq_dram_block_window_tmp.get_window_origin(), - Policy::template MakeAQDramTileDistribution()); - return aq_copy_dram_window; + a_scale_dram_block_window_tmp.get_window_origin(), + Policy::template MakeAScaleDramTileDistribution()); + return a_copy_draw_window; + } + + // Create DRAM tile window for B scale + template + CK_TILE_DEVICE constexpr auto + GetBScaleDramLoadWindow(const BScaleDramBlockWindowTmp& b_scale_dram_block_window_tmp) const + { + static_assert( + std::is_same_v); + using YPerTile = number; + using XPerTile = number; + + auto b_copy_draw_window = + make_tile_window(b_scale_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile(), XPerTile()), + b_scale_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBScaleDramTileDistribution()); } }; diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp index c63e706203..36cb705481 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp @@ -75,6 +75,8 @@ struct GemmMXPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPol return TileEncodingPattern::Make2DStaticTileDistribution(); } + // A Scale DRAM tile distribution + // This is used to load the A scale data from DRAM into shared memory. template CK_TILE_HOST_DEVICE static constexpr auto MakeAScaleDramTileDistribution() { @@ -105,12 +107,17 @@ struct GemmMXPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPol MPerBlock, KPerBlockScale, MXdlPack, - NXdlPack, - KXdlPack, - VecLoadSize>; + KXdlPack>; return TileEncodingPattern::Make2DStaticTileDistribution(); } + // B Scale DRAM tile distribution + // This is used to load the B scale data from DRAM into shared memory. + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBScaleDramTileDistribution() + { + } + template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp index a84c279792..1952f867b1 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp @@ -97,15 +97,15 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3; using I2 = number<2>; - static constexpr index_t APackedSize = - ck_tile::numeric_traits>::PackedSize; - static constexpr index_t BPackedSize = - ck_tile::numeric_traits>::PackedSize; + // static constexpr index_t APackedSize = + // ck_tile::numeric_traits>::PackedSize; + // static constexpr index_t BPackedSize = + // ck_tile::numeric_traits>::PackedSize; - static constexpr index_t AScalePackedSize = - ck_tile::numeric_traits>::PackedSize; - static constexpr index_t BScalePackedSize = - ck_tile::numeric_traits>::PackedSize; + // static constexpr index_t AScalePackedSize = + // ck_tile::numeric_traits>::PackedSize; + // static constexpr index_t BScalePackedSize = + // ck_tile::numeric_traits>::PackedSize; using ALayout = remove_cvref_t; using AScaleLayout = remove_cvref_t; @@ -122,6 +122,13 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } @@ -312,7 +319,10 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3 + index_t VecSize = 1> struct TileDistributionEncodingPatternAScale : public TileDistributionEncodingPattern { -} + static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); + static constexpr index_t warp_size = get_warp_size(); + static constexpr index_t num_warps = BlockSize / warp_size; + + static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{}); + static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{}); + static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{}); + + static constexpr index_t MThreadPerXdl = WarpGemm::kM; + static constexpr index_t KThreadPerXdl = warp_size / MThreadPerXdl; + + static_assert(num_warps == MWarps * NWarps * KWarps, "Block warps do not match block size"); + static_assert(KWarps == 1, "KWarps > 1 is not supported"); + + // Y dimension (M) decomposition + static constexpr index_t MXdlPack = 2; // MXdlPack is always 2 + static constexpr index_t Y1 = MWarps; + static constexpr index_t Y2 = MThreadPerXdl; + static constexpr index_t Y0 = YPerTile / MXdlPack / (MWarps * MThreadPerXdl); + + // X dimension (K) decomposition + static constexpr index_t X0 = KThreadPerXdl; + static constexpr index_t X1 = VecSize; + + static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile, "Y dimensions must cover the YPerTile"); + static_assert(X0 * X1 == XPerTile, "X dimensions must cover the XPerTile"); + + CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } +}; } // namespace ck_tile