mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
fix settings for example, fix some things in pipeline
This commit is contained in:
@@ -95,6 +95,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
static constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
/// @brief The e8m0 scales are packed into int32/float32 such that
|
||||
/// one element contains a 2x2 block of scales (two rows, two elements in K dim)
|
||||
static constexpr auto MXdlPack = MXGemmPipeline::MXdlPack;
|
||||
static constexpr auto NXdlPack = MXGemmPipeline::NXdlPack;
|
||||
static constexpr auto KXdlPack = MXGemmPipeline::KXdlPack;
|
||||
@@ -195,7 +197,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
auto scale_b = kargs.scale_n_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
|
||||
static_assert(ScaleM::GranularityK == ScaleN::GranularityK, "M and N scales must have same K granularity!");
|
||||
static constexpr int BlockScaleSize = ScaleM::GranularityK;
|
||||
const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
|
||||
const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
|
||||
const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
|
||||
@@ -218,10 +221,10 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|
||||
// B scale tensor view
|
||||
const auto& scale_b_tensor_view = [&]() {
|
||||
const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
|
||||
const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
|
||||
const auto scale_b_desc = transform_tensor_descriptor(
|
||||
scale_b_navie_desc,
|
||||
scale_b_naive_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
@@ -251,12 +254,14 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|
||||
static constexpr int BlockScaleSize = 32;
|
||||
|
||||
// We are packing 2x2 (MXdlPack x KXdlPack) scales (e8m0) into one int32 element
|
||||
auto scale_a_block_window = make_tile_window(
|
||||
views.at(I4),
|
||||
make_tuple(number<TilePartitioner::MPerBlock / MXdlPack>{},
|
||||
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
|
||||
{i_m / MXdlPack, 0});
|
||||
|
||||
// We are packing 2x2 (NXdlPack x KXdlPack) scales (e8m0) into one int32 element
|
||||
auto scale_b_block_window = make_tile_window(
|
||||
views.at(I5),
|
||||
make_tuple(number<TilePartitioner::NPerBlock / NXdlPack>{},
|
||||
@@ -295,7 +300,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& scale_a_block_window = gemm_tile_windows.at(I4);
|
||||
const auto& scale_b_block_window = gemm_tile_windows.at(I5);
|
||||
@@ -304,12 +309,9 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|| ScaleM::GranularityMN == -1 // or ScaleA is disabled
|
||||
|| ScaleN::GranularityMN == -1, // or ScaleB is disabled
|
||||
"ScaleM and ScaleN should have the same GranularityK");
|
||||
constexpr bool DoEpiScale =
|
||||
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
|
||||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
|
||||
|
||||
const auto& c_block_tile = MXGemmPipeline{}(a_block_window,
|
||||
b_flat_block_window,
|
||||
b_block_window,
|
||||
scale_a_block_window,
|
||||
scale_b_block_window,
|
||||
num_loop,
|
||||
@@ -317,54 +319,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(DoEpiScale)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
auto scale_m_ptr_offset = kargs.scale_m_ptr + block_idx_m;
|
||||
auto scale_n_ptr_offset = kargs.scale_n_ptr + block_idx_n;
|
||||
|
||||
auto scale_m_view = [&]() {
|
||||
if constexpr (ScaleM::GranularityMN != -1) {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
scale_m_ptr_offset.ptr,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(number<1>{}, number<0>{}),
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
} else {
|
||||
return typename EpiloguePipeline::EmptyScale{};
|
||||
}
|
||||
}();
|
||||
|
||||
auto scale_n_view = [&]() {
|
||||
if constexpr (ScaleN::GranularityMN != -1) {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
scale_n_ptr_offset.ptr,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(number<0>{}, number<1>{}),
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
} else {
|
||||
return typename EpiloguePipeline::EmptyScale{};
|
||||
}
|
||||
}();
|
||||
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
scale_m_view,
|
||||
scale_n_view);
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
|
||||
|
||||
Reference in New Issue
Block a user