update A, B scale distribution patterns

This commit is contained in:
mtgu0705
2025-08-13 00:59:53 -05:00
parent 093f9c4d20
commit 5e5b9bdbd0
5 changed files with 115 additions and 45 deletions

View File

@@ -124,11 +124,11 @@ struct GemmMXKernel
using BlockGemm = remove_cvref_t<typename GemmPipeline::BlockGemm>;
// Per-XDL thread extents derived from the block GEMM's warp GEMM.
// NOTE(review): a `using` alias needs a type on its right-hand side, but `64 / MThreadPerXdl`
// and `get_warp_size() / MThreadPerXdl` are arithmetic expressions, and kM / kN look like
// integer constants — confirm whether these should be `static constexpr index_t` members.
using MThreadPerXdl = BlockGemm::WarpGemm::kM;
using NThreadPerXdl = BlockGemm::WarpGemm::kN;
using KThreadPerXdl = 64 / MThreadPerXdl; // 64 is the number of threads in a wave
using KThreadPerXdl = get_warp_size() / MThreadPerXdl; // get_warp_size() replaces the hard-coded 64 threads per wave
static constexpr auto MXdlPack = 2;
static constexpr auto NXdlPack = 2;
static constexpr auto KXdlPack = 2;
using MXdlPack = remove_cvref_t<typename GemmPipeline::MXdlPack>;
using NXdlPack = remove_cvref_t<typename GemmPipeline::NXdlPack>;
using KXdlPack = remove_cvref_t<typename GemmPipeline::KXdlPack>;
using mx_scale_t = ck_tile::e8m0_bexp_t;
static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
@@ -435,16 +435,6 @@ struct GemmMXKernel
return make_tensor_view<address_space_enum::global>(a_scale_ptr, a_m_k_desc);
}();
// const auto& aq_tensor_view = [&]() {
// static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
// return make_naive_tensor_view<address_space_enum::global>(
// aq_ptr,
// make_tuple(kargs.M, kargs.QK),
// make_tuple(kargs.stride_AQ, 1),
// number<GemmPipeline::GetVectorSizeAQ()>{},
// number<1>{});
// }();
const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{

View File

@@ -19,34 +19,60 @@ struct GemmMXPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem,
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
using AScaleLayout = remove_cvref_t<typename Problem::AScaleLayout>;
using BScaleLayout = remove_cvref_t<typename Problem::BScaleLayout>;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize;
static constexpr index_t BlockScaleSize = Problem::kBlockScaleSize;
static_assert(KPerBlock % QuantGroupSize == 0,
"KPerBlock must be a multiple of QuantGroupSize");
// KPerBlock must divide evenly by BlockScaleSize (same shape as the QuantGroupSize check).
static_assert(KPerBlock % BlockScaleSize == 0,
              "KPerBlock must be a multiple of BlockScaleSize");
// Create DRAM tile window for AQ
template <typename AQDramBlockWindowTmp>
static constexpr auto MXdlPack = 2;
static constexpr auto NXdlPack = 2;
static constexpr auto KXdlPack = 2;
// Create DRAM tile window for A scale
template <typename AScaleDramBlockWindowTmp>
CK_TILE_DEVICE constexpr auto
GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const
GetAScaleDramLoadWindow(const AScaleDramBlockWindowTmp& a_scale_dram_block_window_tmp) const
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
static_assert(
std::is_same_v<typename Problem::AScaleLayout, tensor_layout::gemm::RowMajor>);
using YPerTile = number<MPerBlock / MXdlPack>;
using XPerTile = number<KPerBlock / BlockScaleSize / KXdlPack>;
using YPerTile = number<MPerBlock>;
using XPerTile = number<KPerBlockAQ>;
auto aq_copy_dram_window =
make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
auto a_copy_draw_window =
make_tile_window(a_scale_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile(), XPerTile()),
aq_dram_block_window_tmp.get_window_origin(),
Policy::template MakeAQDramTileDistribution<Problem>());
return aq_copy_dram_window;
a_scale_dram_block_window_tmp.get_window_origin(),
Policy::template MakeAScaleDramTileDistribution<Problem>());
return a_copy_draw_window;
}
// Create DRAM tile window for B scale.
//
// Re-wraps the bottom tensor view of `b_scale_dram_block_window_tmp` in a new tile window with
// lengths (NPerBlock / NXdlPack) x (KPerBlock / BlockScaleSize / KXdlPack), placed at the same
// window origin and using Policy::MakeBScaleDramTileDistribution<Problem>() as distribution.
// Only a ColumnMajor B-scale layout is accepted; this is enforced at compile time.
//
// Returns the newly created tile window.
template <typename BScaleDramBlockWindowTmp>
CK_TILE_DEVICE constexpr auto
GetBScaleDramLoadWindow(const BScaleDramBlockWindowTmp& b_scale_dram_block_window_tmp) const
{
    static_assert(
        std::is_same_v<typename Problem::BScaleLayout, tensor_layout::gemm::ColumnMajor>);

    using YPerTile = number<NPerBlock / NXdlPack>;
    using XPerTile = number<KPerBlock / BlockScaleSize / KXdlPack>;

    auto b_scale_copy_dram_window =
        make_tile_window(b_scale_dram_block_window_tmp.get_bottom_tensor_view(),
                         make_tuple(YPerTile(), XPerTile()),
                         b_scale_dram_block_window_tmp.get_window_origin(),
                         Policy::template MakeBScaleDramTileDistribution<Problem>());

    // Without this return the deduced return type would be void and callers could not
    // bind the result to a variable.
    return b_scale_copy_dram_window;
}
};

View File

@@ -75,6 +75,8 @@ struct GemmMXPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPol
return TileEncodingPattern::Make2DStaticTileDistribution();
}
// A Scale DRAM tile distribution
// This is used to load the A scale data from DRAM into shared memory.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeAScaleDramTileDistribution()
{
@@ -105,12 +107,17 @@ struct GemmMXPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPol
MPerBlock,
KPerBlockScale,
MXdlPack,
NXdlPack,
KXdlPack,
VecLoadSize>;
KXdlPack>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
// B Scale DRAM tile distribution
// This is used to load the B scale data from DRAM into shared memory.
// NOTE(review): the body is empty, so the deduced return type is void and no tile
// distribution is produced — TODO: confirm this is a placeholder that still needs an
// implementation before its result is passed to make_tile_window().
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBScaleDramTileDistribution()
{
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{

View File

@@ -97,15 +97,15 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
// static constexpr index_t APackedSize =
// ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
// static constexpr index_t BPackedSize =
// ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
static constexpr index_t AScalePackedSize =
ck_tile::numeric_traits<remove_cvref_t<AScaleDataType>>::PackedSize;
static constexpr index_t BScalePackedSize =
ck_tile::numeric_traits<remove_cvref_t<BScaleDataType>>::PackedSize;
// static constexpr index_t AScalePackedSize =
// ck_tile::numeric_traits<remove_cvref_t<AScaleDataType>>::PackedSize;
// static constexpr index_t BScalePackedSize =
// ck_tile::numeric_traits<remove_cvref_t<BScaleDataType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using AScaleLayout = remove_cvref_t<typename Problem::AScaleLayout>;
@@ -122,6 +122,13 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
static constexpr index_t BlockScaleSize = Problem::kBlockScaleSize;
static constexpr index_t KPerBlockScale = BlockGemmShape::kK / BlockScaleSize;
// Re-export the XDL pack factors and packed sizes from PipelineImplBase.
// NOTE(review): `using` aliases require the right-hand side to name a type (and a dependent
// type would need `typename`). If PipelineImplBase is GemmMXPipelineAgBgCrImplBase, these
// members are declared there as `static constexpr` values — confirm whether these should be
// `static constexpr` members rather than type aliases.
using MXdlPack = PipelineImplBase::MXdlPack;
using NXdlPack = PipelineImplBase::NXdlPack;
using KXdlPack = PipelineImplBase::KXdlPack;
using APackedSize = PipelineImplBase::APackedSize;
using BPackedSize = PipelineImplBase::BPackedSize;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
@@ -312,7 +319,10 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp);
auto a_scale_copy_dram_window =
Base::GetAScaleDramLoadWindow(a_scale_dram_block_window_tmp);
auto b_scale_copy_dram_window =
Base::GetBScaleDramLoadWindow(b_scale_dram_block_window_tmp);
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

View File

@@ -102,9 +102,46 @@ template <typename BlockGemmShape,
index_t MXdlPack,
index_t NXdlPack,
index_t KXdlPack,
index_t VecSize>
index_t VecSize = 1>
struct TileDistributionEncodingPatternAScale : public TileDistributionEncodingPattern
{
}
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / warp_size;
static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
static constexpr index_t MThreadPerXdl = WarpGemm::kM;
static constexpr index_t KThreadPerXdl = warp_size / MThreadPerXdl;
static_assert(num_warps == MWarps * NWarps * KWarps, "Block warps do not match block size");
static_assert(KWarps == 1, "KWarps > 1 is not supported");
// Y dimension (M) decomposition
static constexpr index_t MXdlPack = 2; // MXdlPack is always 2
static constexpr index_t Y1 = MWarps;
static constexpr index_t Y2 = MThreadPerXdl;
static constexpr index_t Y0 = YPerTile / MXdlPack / (MWarps * MThreadPerXdl);
// X dimension (K) decomposition
static constexpr index_t X0 = KThreadPerXdl;
static constexpr index_t X1 = VecSize;
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile, "Y dimensions must cover the YPerTile");
static_assert(X0 * X1 == XPerTile, "X dimensions must cover the XPerTile");
// Builds the static tile distribution for this pattern from the Y0/Y1/Y2 and X0/X1 factors,
// with sequence<NWarps> as the first (leading) encoding argument.
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
    // Per-dimension length factorizations.
    using YLengths = sequence<Y0, Y1, Y2>;
    using XLengths = sequence<X0, X1>;

    using Encoding = tile_distribution_encoding<sequence<NWarps>,
                                                tuple<YLengths, XLengths>,
                                                tuple<sequence<1>, sequence<1, 2>>,
                                                tuple<sequence<1>, sequence<2, 0>>,
                                                sequence<1, 2>,
                                                sequence<0, 1>>;

    return make_static_tile_distribution(Encoding{});
}
};
} // namespace ck_tile