mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
fix settings for example, fix some things in pipeline
This commit is contained in:
@@ -119,6 +119,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution()
|
||||
{
|
||||
// TODO: these could be replaced by the standard UniversalGEMM tile distributions??
|
||||
constexpr index_t K2 = AK1; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
|
||||
@@ -95,6 +95,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
static constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
/// @brief The e8m0 scales are packed into int32/float32 such that
|
||||
/// in one element contains a 2x2 block of scales (two rows, two lements in K dim)
|
||||
static constexpr auto MXdlPack = MXGemmPipeline::MXdlPack;
|
||||
static constexpr auto NXdlPack = MXGemmPipeline::NXdlPack;
|
||||
static constexpr auto KXdlPack = MXGemmPipeline::KXdlPack;
|
||||
@@ -195,7 +197,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
auto scale_b = kargs.scale_n_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
|
||||
static_assert(ScaleM::GranularityK == ScaleN::GranularityK, "M and N scales must have same K granularity!");
|
||||
static constexpr int BlockScaleSize = ScaleM::GranularityK;
|
||||
const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
|
||||
const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
|
||||
const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
|
||||
@@ -218,10 +221,10 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|
||||
// B scale tensor view
|
||||
const auto& scale_b_tensor_view = [&]() {
|
||||
const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
|
||||
const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
|
||||
const auto scale_b_desc = transform_tensor_descriptor(
|
||||
scale_b_navie_desc,
|
||||
scale_b_naive_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
@@ -251,12 +254,14 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|
||||
static constexpr int BlockScaleSize = 32;
|
||||
|
||||
// We are packing 2x2 (MXdlPack x KXdlPack) scales (e8m0) into one int32 element
|
||||
auto scale_a_block_window = make_tile_window(
|
||||
views.at(I4),
|
||||
make_tuple(number<TilePartitioner::MPerBlock / MXdlPack>{},
|
||||
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
|
||||
{i_m / MXdlPack, 0});
|
||||
|
||||
// We are packing 2x2 (NXdlPack x KXdlPack) scales (e8m0) into one int32 element
|
||||
auto scale_b_block_window = make_tile_window(
|
||||
views.at(I5),
|
||||
make_tuple(number<TilePartitioner::NPerBlock / NXdlPack>{},
|
||||
@@ -295,7 +300,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& scale_a_block_window = gemm_tile_windows.at(I4);
|
||||
const auto& scale_b_block_window = gemm_tile_windows.at(I5);
|
||||
@@ -304,12 +309,9 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|
||||
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
|
||||
"ScaleM and ScaleN should have the same GranularityK");
|
||||
constexpr bool DoEpiScale =
|
||||
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
|
||||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
|
||||
|
||||
const auto& c_block_tile = MXGemmPipeline{}(a_block_window,
|
||||
b_flat_block_window,
|
||||
b_block_window,
|
||||
scale_a_block_window,
|
||||
scale_b_block_window,
|
||||
num_loop,
|
||||
@@ -317,54 +319,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(DoEpiScale)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
auto scale_m_ptr_offset = kargs.scale_m_ptr + block_idx_m;
|
||||
auto scale_n_ptr_offset = kargs.scale_n_ptr + block_idx_n;
|
||||
|
||||
auto scale_m_view = [&]() {
|
||||
if constexpr (ScaleM::GranularityMN != -1) {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
scale_m_ptr_offset.ptr,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(number<1>{}, number<0>{}),
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
} else {
|
||||
return typename EpiloguePipeline::EmptyScale{};
|
||||
}
|
||||
}();
|
||||
|
||||
auto scale_n_view = [&]() {
|
||||
if constexpr (ScaleN::GranularityMN != -1) {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
scale_n_ptr_offset.ptr,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(number<0>{}, number<1>{}),
|
||||
number<1>{},
|
||||
number<1>{}
|
||||
);
|
||||
} else {
|
||||
return typename EpiloguePipeline::EmptyScale{};
|
||||
}
|
||||
}();
|
||||
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
scale_m_view,
|
||||
scale_n_view);
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
|
||||
|
||||
@@ -242,7 +242,7 @@ struct MXGemmPipelineAgBgCrV1
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
};
|
||||
|
||||
// Helper for Math Loop
|
||||
// Helper for Main Loop
|
||||
auto warp_gemm_loop = [&](auto& a_warp_window, auto& b_warp_window, auto& scale_a, auto& scale_b) {
|
||||
// Define register tiles types for double buffering
|
||||
using AValType = decltype(load_tile_with_offset(a_warp_window, tuple<number<0>, number<0>>{}));
|
||||
|
||||
@@ -227,31 +227,10 @@ struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<K_Thread / AK1, K_Lane, AK1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<2, 2>, // K_Thread/AK1, AK1
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_BDramTileDistribution()
|
||||
{
|
||||
constexpr index_t K2 = BK1; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * BPackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
|
||||
constexpr index_t N2 = WaveSize / K1; // 8
|
||||
constexpr index_t N1 = BlockSize / WaveSize; // 4
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
|
||||
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>, // N0,K0,K2
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
@@ -294,6 +273,29 @@ struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
TensorView::DstInMemOp>{naive_view.buf_, desc};
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_BDramTileDistribution()
|
||||
{
|
||||
// TODO: these could be replaced by the standard UniversalGEMM tile distributions??
|
||||
constexpr index_t K2 = BK1; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * BPackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
|
||||
constexpr index_t N2 = WaveSize / K1; // 8
|
||||
constexpr index_t N1 = BlockSize / WaveSize; // 4
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
|
||||
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>, // N0,K0,K2
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t K2 = BK1; // f4=32; f8=16
|
||||
|
||||
Reference in New Issue
Block a user