mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
update A, B scale distribution patterns
This commit is contained in:
@@ -124,11 +124,11 @@ struct GemmMXKernel
|
||||
using BlockGemm = remove_cvref_t<typename GemmPipeline::BlockGemm>;
|
||||
using MThreadPerXdl = BlockGemm::WarpGemm::kM;
|
||||
using NThreadPerXdl = BlockGemm::WarpGemm::kN;
|
||||
using KThreadPerXdl = 64 / MThreadPerXdl; // 64 is the number of threads in a wave
|
||||
using KThreadPerXdl = get_warp_size() / MThreadPerXdl; // 64 is the number of threads in a wave
|
||||
|
||||
static constexpr auto MXdlPack = 2;
|
||||
static constexpr auto NXdlPack = 2;
|
||||
static constexpr auto KXdlPack = 2;
|
||||
using MXdlPack = remove_cvref_t<typename GemmPipeline::MXdlPack>;
|
||||
using NXdlPack = remove_cvref_t<typename GemmPipeline::NXdlPack>;
|
||||
using KXdlPack = remove_cvref_t<typename GemmPipeline::KXdlPack>;
|
||||
|
||||
using mx_scale_t = ck_tile::e8m0_bexp_t;
|
||||
static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
|
||||
@@ -435,16 +435,6 @@ struct GemmMXKernel
|
||||
return make_tensor_view<address_space_enum::global>(a_scale_ptr, a_m_k_desc);
|
||||
}();
|
||||
|
||||
// const auto& aq_tensor_view = [&]() {
|
||||
// static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
// return make_naive_tensor_view<address_space_enum::global>(
|
||||
// aq_ptr,
|
||||
// make_tuple(kargs.M, kargs.QK),
|
||||
// make_tuple(kargs.stride_AQ, 1),
|
||||
// number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
// number<1>{});
|
||||
// }();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
|
||||
@@ -19,34 +19,60 @@ struct GemmMXPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem,
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
|
||||
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
|
||||
using AScaleLayout = remove_cvref_t<typename Problem::AScaleLayout>;
|
||||
using BScaleLayout = remove_cvref_t<typename Problem::BScaleLayout>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize;
|
||||
static constexpr index_t BlockScaleSize = Problem::kBlockScaleSize;
|
||||
|
||||
static_assert(KPerBlock % QuantGroupSize == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize");
|
||||
// The K extent of the block tile must split into whole scale blocks; the scale tile
// windows in this struct size their X dimension from KPerBlock / BlockScaleSize.
// (The stray `*` before `%` in the original expression did not parse.)
static_assert(KPerBlock % BlockScaleSize == 0,
              "KPerBlock must be a multiple of BlockScaleSize");
|
||||
|
||||
// Create DRAM tile window for AQ
|
||||
template <typename AQDramBlockWindowTmp>
|
||||
static constexpr auto MXdlPack = 2;
|
||||
static constexpr auto NXdlPack = 2;
|
||||
static constexpr auto KXdlPack = 2;
|
||||
|
||||
// Create DRAM tile window for A scale
|
||||
template <typename AScaleDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const
|
||||
GetAScaleDramLoadWindow(const AScaleDramBlockWindowTmp& a_scale_dram_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
static_assert(
|
||||
std::is_same_v<typename Problem::AScaleLayout, tensor_layout::gemm::RowMajor>);
|
||||
using YPerTile = number<MPerBlock / MXdlPack>;
|
||||
using XPerTile = number<KPerBlock / BlockScaleSize / KXdlPack>;
|
||||
|
||||
using YPerTile = number<MPerBlock>;
|
||||
using XPerTile = number<KPerBlockAQ>;
|
||||
|
||||
auto aq_copy_dram_window =
|
||||
make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
auto a_copy_draw_window =
|
||||
make_tile_window(a_scale_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile(), XPerTile()),
|
||||
aq_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeAQDramTileDistribution<Problem>());
|
||||
return aq_copy_dram_window;
|
||||
a_scale_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeAScaleDramTileDistribution<Problem>());
|
||||
return a_copy_draw_window;
|
||||
}
|
||||
|
||||
// Create DRAM tile window for B scale
|
||||
template <typename BScaleDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetBScaleDramLoadWindow(const BScaleDramBlockWindowTmp& b_scale_dram_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<typename Problem::BScaleLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using YPerTile = number<NPerBlock / NXdlPack>;
|
||||
using XPerTile = number<KPerBlock / BlockScaleSize / KXdlPack>;
|
||||
|
||||
auto b_copy_draw_window =
|
||||
make_tile_window(b_scale_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile(), XPerTile()),
|
||||
b_scale_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBScaleDramTileDistribution<Problem>());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -75,6 +75,8 @@ struct GemmMXPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPol
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
// A Scale DRAM tile distribution
|
||||
// This is used to load the A scale data from DRAM into shared memory.
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeAScaleDramTileDistribution()
|
||||
{
|
||||
@@ -105,12 +107,17 @@ struct GemmMXPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPol
|
||||
MPerBlock,
|
||||
KPerBlockScale,
|
||||
MXdlPack,
|
||||
NXdlPack,
|
||||
KXdlPack,
|
||||
VecLoadSize>;
|
||||
KXdlPack>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
// B Scale DRAM tile distribution
|
||||
// This is used to load the B scale data from DRAM into shared memory.
|
||||
// NOTE(review): the body of this function is empty, so its deduced `auto` return type
// is void. GetBScaleDramLoadWindow passes the result of this call to make_tile_window
// as the tile distribution argument, so a distribution object presumably still has to
// be built and returned here — TODO confirm against the finished change.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBScaleDramTileDistribution()
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
|
||||
@@ -97,15 +97,15 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
// static constexpr index_t APackedSize =
|
||||
// ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
// static constexpr index_t BPackedSize =
|
||||
// ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t AScalePackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<AScaleDataType>>::PackedSize;
|
||||
static constexpr index_t BScalePackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BScaleDataType>>::PackedSize;
|
||||
// static constexpr index_t AScalePackedSize =
|
||||
// ck_tile::numeric_traits<remove_cvref_t<AScaleDataType>>::PackedSize;
|
||||
// static constexpr index_t BScalePackedSize =
|
||||
// ck_tile::numeric_traits<remove_cvref_t<BScaleDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using AScaleLayout = remove_cvref_t<typename Problem::AScaleLayout>;
|
||||
@@ -122,6 +122,13 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
|
||||
static constexpr index_t BlockScaleSize = Problem::kBlockScaleSize;
|
||||
static constexpr index_t KPerBlockScale = BlockGemmShape::kK / BlockScaleSize;
|
||||
|
||||
using MXdlPack = PipelineImplBase::MXdlPack;
|
||||
using NXdlPack = PipelineImplBase::NXdlPack;
|
||||
using KXdlPack = PipelineImplBase::KXdlPack;
|
||||
|
||||
using APackedSize = PipelineImplBase::APackedSize;
|
||||
using BPackedSize = PipelineImplBase::BPackedSize;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
@@ -312,7 +319,10 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp);
|
||||
auto a_scale_copy_dram_window =
|
||||
Base::GetAScaleDramLoadWindow(a_scale_dram_block_window_tmp);
|
||||
auto b_scale_copy_dram_window =
|
||||
Base::GetBScaleDramLoadWindow(b_scale_dram_block_window_tmp);
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
@@ -102,9 +102,46 @@ template <typename BlockGemmShape,
|
||||
index_t MXdlPack,
|
||||
index_t NXdlPack,
|
||||
index_t KXdlPack,
|
||||
index_t VecSize>
|
||||
index_t VecSize = 1>
|
||||
struct TileDistributionEncodingPatternAScale : public TileDistributionEncodingPattern
|
||||
{
|
||||
}
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / warp_size;
|
||||
|
||||
static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{});
|
||||
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{});
|
||||
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{});
|
||||
|
||||
static constexpr index_t MThreadPerXdl = WarpGemm::kM;
|
||||
static constexpr index_t KThreadPerXdl = warp_size / MThreadPerXdl;
|
||||
|
||||
static_assert(num_warps == MWarps * NWarps * KWarps, "Block warps do not match block size");
|
||||
static_assert(KWarps == 1, "KWarps > 1 is not supported");
|
||||
|
||||
// Y dimension (M) decomposition
|
||||
static constexpr index_t MXdlPack = 2; // MXdlPack is always 2
|
||||
static constexpr index_t Y1 = MWarps;
|
||||
static constexpr index_t Y2 = MThreadPerXdl;
|
||||
static constexpr index_t Y0 = YPerTile / MXdlPack / (MWarps * MThreadPerXdl);
|
||||
|
||||
// X dimension (K) decomposition
|
||||
static constexpr index_t X0 = KThreadPerXdl;
|
||||
static constexpr index_t X1 = VecSize;
|
||||
|
||||
// NOTE(review): only Y0, Y1 and Y2 are declared in the visible lines of this struct;
// Y3 has no visible declaration. Y0 is also computed with an extra division by
// MXdlPack, so Y0 * Y1 * Y2 alone would not equal YPerTile — confirm the intended
// factorisation of the Y dimension before relying on this check.
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile, "Y dimensions must cover the YPerTile");
|
||||
static_assert(X0 * X1 == XPerTile, "X dimensions must cover the XPerTile");
|
||||
|
||||
/// Returns the static tile distribution described by this encoding pattern.
///
/// The encoding type is spelled out once as an alias and then default-constructed:
/// its first argument is sequence<NWarps>, the tile dimensions are given as
/// (Y0, Y1, Y2) and (X0, X1), and the remaining tuple/sequence arguments are passed
/// to tile_distribution_encoding exactly as listed.
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
    using Encoding = tile_distribution_encoding<sequence<NWarps>,
                                                tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                                tuple<sequence<1>, sequence<1, 2>>,
                                                tuple<sequence<1>, sequence<2, 0>>,
                                                sequence<1, 2>,
                                                sequence<0, 1>>;

    return make_static_tile_distribution(Encoding{});
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user