mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Dev/a8w4 and a8w8splitk (#3447)
* Ck moe bs splitk pr (#3440) * splitk kick-off. Compilation fail * splitk hack pass * fix scale offset calc. * clang-format for a8w8_moe_blk_gemm1 splitk change * fix testcase error --------- Co-authored-by: oscar <huaiguxu@amd.com> Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com> * Zan/moe a8w4 (#3441) * update * update * update ck moe a8w4 * update * update * update * compile pass * update * update * python3 op_tests/test_moe_2stage.py -t 16 -e 1 -k 1 -dim 256,256 ready * support new a8w4 kernel * update * update ck_tile * re format * update * update * fix conflict * fix build * update ck_tile moe * fix clang format * fix the problem * fix accruacy issue * fix --------- Co-authored-by: oscar <huaiguxu@amd.com> Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com> Co-authored-by: Zzz9990 <zanzhang@amd.com> Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
@@ -217,6 +217,7 @@ struct MoeFlatmmKernel
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
static constexpr auto I4 = number<4>();
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
@@ -241,12 +242,24 @@ struct MoeFlatmmKernel
|
||||
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
|
||||
|
||||
// MXF4_Pipeline only has the of scale B and granularityK is 32
|
||||
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
|
||||
static constexpr int MXFP4N_Pack = 2;
|
||||
static constexpr int MXFP4K_Pack = 2;
|
||||
static constexpr bool AQUANT_Pipeline = std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<ADataType, fp8_t> ||
|
||||
std::is_same_v<ADataType, pk_fp4_t>;
|
||||
static constexpr bool BMXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
|
||||
|
||||
static constexpr int N_Pack = MXFP4_Pipeline ? MXFP4N_Pack : 1;
|
||||
static constexpr int K_Pack = MXFP4_Pipeline ? MXFP4K_Pack : 1;
|
||||
static constexpr bool MXF8F6F4MFMA =
|
||||
#ifdef __gfx950__
|
||||
AQUANT_Pipeline && BMXFP4_Pipeline;
|
||||
#else
|
||||
false;
|
||||
#endif
|
||||
static constexpr int MXFP4M_Pack = 2;
|
||||
static constexpr int MXFP4N_Pack = 2;
|
||||
static constexpr int MXFP4K_Pack = 2;
|
||||
|
||||
static constexpr int M_Pack = AQUANT_Pipeline ? MXFP4M_Pack : 1;
|
||||
static constexpr int N_Pack = BMXFP4_Pipeline ? MXFP4N_Pack : 1;
|
||||
static constexpr int K_Pack = BMXFP4_Pipeline ? MXFP4K_Pack : 1;
|
||||
|
||||
static constexpr int WeightPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
@@ -659,23 +672,95 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto scale_n = kargs.scale_n;
|
||||
constexpr int GranularityK = decltype(scale_n)::GranularityK;
|
||||
const auto& scale_a_tensor_view = [&]() {
|
||||
auto scale_m_desc = kargs.scale_m;
|
||||
if constexpr(AQUANT_Pipeline)
|
||||
{
|
||||
constexpr int AGranularityK = decltype(scale_m_desc)::GranularityK == 0
|
||||
? 1
|
||||
: decltype(scale_m_desc)::GranularityK;
|
||||
|
||||
index_t scale_k = GranularityK == 0 ? 1 : (kargs.K + GranularityK - 1) / GranularityK;
|
||||
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
|
||||
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
|
||||
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
|
||||
// Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
|
||||
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl));
|
||||
const auto scale_a_desc = transform_tensor_descriptor(
|
||||
scale_a_naive_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_m_desc.ptr), scale_a_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr int AGranularityK = 32;
|
||||
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
|
||||
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
|
||||
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_m_desc.ptr),
|
||||
make_tuple(scale_m_packs * MThreadPerXdl, scale_k_packs * KThreadPerXdl),
|
||||
make_tuple(scale_k_packs * KThreadPerXdl, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
|
||||
const auto scale_b_flat_view = [&]() {
|
||||
auto scale_n = kargs.scale_n;
|
||||
constexpr int BGranularityK =
|
||||
decltype(scale_n)::GranularityK == 0 ? 1 : decltype(scale_n)::GranularityK;
|
||||
if constexpr(AQUANT_Pipeline)
|
||||
{
|
||||
index_t scale_k =
|
||||
BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
|
||||
constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I1);
|
||||
index_t scale_n_packs = kargs.N / (MXFP4N_Pack * NThreadPerXdl);
|
||||
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * BGranularityK * KThreadPerXdl);
|
||||
const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_n_packs, scale_k_packs, KThreadPerXdl, NThreadPerXdl));
|
||||
const auto scale_b_desc = transform_tensor_descriptor(
|
||||
scale_b_navie_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_n_packs, NThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
|
||||
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_n.ptr) +
|
||||
expert_id * kargs.N * scale_k / 4,
|
||||
scale_b_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t scale_k =
|
||||
BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
|
||||
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view, scale_b_flat_view);
|
||||
using ScaleType = e8m0_t;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
|
||||
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view,
|
||||
b_flat_tensor_view,
|
||||
c_tensor_view,
|
||||
scale_a_tensor_view,
|
||||
scale_b_flat_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
@@ -718,7 +803,7 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3));
|
||||
return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3), views.at(I4));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
@@ -747,7 +832,7 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline;
|
||||
constexpr bool isNonInterleaveGateUp = !IsGateUp || BMXFP4_Pipeline;
|
||||
|
||||
const auto& b_flat_block_window =
|
||||
make_tile_window(b_flat_pad_view,
|
||||
@@ -766,17 +851,40 @@ struct MoeFlatmmKernel
|
||||
output_N_offset});
|
||||
|
||||
constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
|
||||
auto a_scale_block_window = make_tile_window(
|
||||
views.at(I3),
|
||||
make_tuple(number<TilePartitioner::MPerBlock / M_Pack>{},
|
||||
number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
|
||||
{coord_m / M_Pack, 0});
|
||||
|
||||
constexpr int XDLPerLoadScaleB =
|
||||
MXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
|
||||
BMXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
|
||||
|
||||
auto scale_block_window =
|
||||
make_tile_window(views.at(I3),
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * K_Pack *
|
||||
XDLPerLoadScaleB / GranularityK>{}),
|
||||
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
auto b_scale_block_window = [&]() {
|
||||
if constexpr(MXF8F6F4MFMA)
|
||||
{
|
||||
return make_tile_window(
|
||||
views.at(I4),
|
||||
make_tuple(number<TilePartitioner::NPerBlock / N_Pack>{},
|
||||
number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
|
||||
{coord_n / N_Pack, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
views.at(I4),
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * K_Pack *
|
||||
XDLPerLoadScaleB / GranularityK>{}),
|
||||
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
|
||||
return make_tuple(a_block_window,
|
||||
b_flat_block_window,
|
||||
c_block_window,
|
||||
a_scale_block_window,
|
||||
b_scale_block_window);
|
||||
}
|
||||
|
||||
template <class MoeFlatmmKernelArgs>
|
||||
@@ -831,7 +939,6 @@ struct MoeFlatmmKernel
|
||||
|
||||
if(coord_m >= max_token_id)
|
||||
return;
|
||||
|
||||
static_for<0, DramMRepeat, 1>{}([&](auto m0) {
|
||||
const auto row_idx =
|
||||
coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0];
|
||||
@@ -864,9 +971,10 @@ struct MoeFlatmmKernel
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& scale_block_window = gemm_tile_windows.at(I3);
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& a_scale_block_window = gemm_tile_windows.at(I3);
|
||||
const auto& b_scale_block_window = gemm_tile_windows.at(I4);
|
||||
|
||||
auto a_gather_block_tile =
|
||||
ck_tile::make_tile_scatter_gather(a_block_window.get_bottom_tensor_view(),
|
||||
@@ -876,17 +984,32 @@ struct MoeFlatmmKernel
|
||||
a_offsets); // K DRAM tile window for
|
||||
|
||||
auto c_block_tile = [&] {
|
||||
if constexpr(MXFP4_Pipeline)
|
||||
if constexpr(BMXFP4_Pipeline)
|
||||
{
|
||||
// MXFP4_Pipeline uses gate-up interleave 16 layout for weight
|
||||
// BMXFP4_Pipeline uses gate-up interleave 16 layout for weight
|
||||
// so don't need extra processing
|
||||
return FlatmmPipeline{}(a_gather_block_tile,
|
||||
b_block_window,
|
||||
scale_block_window, // weight scale with granularityK = 32
|
||||
num_loop,
|
||||
kargs.k_padded_zeros,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
if constexpr(AQUANT_Pipeline)
|
||||
{
|
||||
return FlatmmPipeline{}(
|
||||
a_gather_block_tile,
|
||||
b_block_window,
|
||||
a_scale_block_window, // weight scale with granularityK = 32
|
||||
b_scale_block_window, // weight scale with granularityK = 32
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FlatmmPipeline{}(
|
||||
a_gather_block_tile,
|
||||
b_block_window,
|
||||
b_scale_block_window, // weight scale with granularityK = 32
|
||||
num_loop,
|
||||
kargs.k_padded_zeros,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -964,7 +1087,7 @@ struct MoeFlatmmKernel
|
||||
constexpr index_t ScaleMRepeat = MRepeat * kM0 * kM2;
|
||||
statically_indexed_array<index_t, ScaleMRepeat> scale_m_offsets;
|
||||
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
if constexpr(!BMXFP4_Pipeline)
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
@@ -1059,7 +1182,7 @@ struct MoeFlatmmKernel
|
||||
number<1>{});
|
||||
|
||||
auto exp_bias_window = make_tile_window(
|
||||
permute_tensor_view(exp_bias_view, number<(MXFP4_Pipeline && !IsInputGemm)>{}),
|
||||
permute_tensor_view(exp_bias_view, number<(BMXFP4_Pipeline && !IsInputGemm)>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number < IsGateUp ? TilePartitioner::NPerBlock / 2
|
||||
: TilePartitioner::NPerBlock > {}),
|
||||
@@ -1101,7 +1224,7 @@ struct MoeFlatmmKernel
|
||||
ExpBiasBuffer exp_bias_buffer, exp_bias_up_buffer;
|
||||
ExpWeightBuffer exp_weight_buffer;
|
||||
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
if constexpr(!BMXFP4_Pipeline)
|
||||
{
|
||||
scale_m_window.load(scale_m_buffer);
|
||||
scale_n_buffer = load_tile(scale_n_window);
|
||||
@@ -1233,7 +1356,7 @@ struct MoeFlatmmKernel
|
||||
auto epi_exp_bias_up = epi_tile_idx_slice(exp_bias_up_buffer, epi_m, epi_n);
|
||||
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
if constexpr(!BMXFP4_Pipeline)
|
||||
{
|
||||
gate_tensor.get_thread_buffer()[idx] *=
|
||||
epi_scale_m[idx] * epi_scale_n[idx];
|
||||
@@ -1260,7 +1383,7 @@ struct MoeFlatmmKernel
|
||||
auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
|
||||
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
if constexpr(!BMXFP4_Pipeline)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] *=
|
||||
epi_scale_m[idx] * epi_scale_n[idx];
|
||||
if constexpr(EnableBias)
|
||||
|
||||
Reference in New Issue
Block a user