support a16_wint4 moe

This commit is contained in:
yadaish
2025-11-19 04:13:41 +00:00
parent f05d4a2fed
commit 79f2db722e
4 changed files with 179 additions and 7 deletions

View File

@@ -86,7 +86,8 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
FlatmmConfig::NumWaveGroups,
true>; // Preshuffle_
constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;
// TODO(yadai): rename to W4_Pipeline
constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t> | std::is_same_v<BDataType, ck_tile::pk_int4_t>;
if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
{
@@ -444,6 +445,22 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[])
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else if(mixed_prec == "fp16xint4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::half_t,
ck_tile::pk_int4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else if(mixed_prec == "bf16xint4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::bfloat16_t,
ck_tile::pk_int4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");

View File

@@ -151,6 +151,16 @@ CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
}
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x, float scale)
{
auto float_vec2 = pk_int4_t_to_fp32x2_t(x);
float_vec2.x = float_vec2.x * scale;
float_vec2.y = float_vec2.y * scale;
return fp16x2_t{type_convert<fp16_t>(float_vec2.x), type_convert<fp16_t>(float_vec2.y)};
}
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
@@ -166,6 +176,14 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
return res;
}
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x, float scale)
{
auto float_vec2 = pk_int4_t_to_fp32x2_t(x);
float_vec2.x = float_vec2.x * scale;
float_vec2.y = float_vec2.y * scale;
return bf16x2_t{type_convert<bf16_t>(float_vec2.x), type_convert<bf16_t>(float_vec2.y)};
}
CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

View File

@@ -241,7 +241,7 @@ struct MoeFlatmmKernel
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
// MXF4_Pipeline only has the of scale B and granularityK is 32
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t> || std::is_same_v<BDataType, pk_int4_t>;
static constexpr int MXFP4N_Pack = 2;
static constexpr int MXFP4K_Pack = 2;

View File

@@ -187,6 +187,135 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
static constexpr bool DoubleSmemBuffer = false;
struct DequantizeMxFP4 {
CK_TILE_DEVICE auto operator()(statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp>& dequant_B_n,
const auto& quant_weight_tensor,
const auto& scale_tensor,
auto xdl_nIter,
auto xdl_kIter) {
auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
auto scale = scale_tensor.get_thread_buffer()[scale_offset];
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
constexpr int float_mantissa = 23;
uint32_t uscale = uint32_t(scale.data) << float_mantissa;
using ComputeV2Type =
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
#if defined(__gfx950__)
auto pk_mxfp4x4_to_compute_v2 = [](auto pk_mxfp4x4, float fscale, auto byte_idx) {
if constexpr(std::is_same_v<ComputeType, half_t>)
{
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(
pk_mxfp4x4, fscale, int(byte_idx));
}
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
{
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(
pk_mxfp4x4, fscale, int(byte_idx));
}
else
{
static_assert(sizeof(pk_mxfp4x4) == 0, "unsupported compute type");
}
};
static_for<0, PackedCnt, 1>{}([&](auto i) {
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
i,
pk_mxfp4x4_to_compute_v2(
quant_weight_tensor[quant_idx_k], bit_cast<float>(uscale), i));
});
#else
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
if constexpr(std::is_same_v<ComputeType, half_t>)
{
return pk_fp4_to_fp16x2(pk_mxfp4, fscale);
}
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
{
return pk_fp4_to_bf16x2(pk_mxfp4, fscale);
}
else
{
static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type");
}
};
static_for<0, PackedCnt, 1>{}([&](auto i) {
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
i,
pk_mxfp4_to_compute_v2(
bit_cast<thread_buffer<pk_fp4_t, 4>>(quant_weight_tensor[quant_idx_k])
.at(i),
bit_cast<float>(uscale)));
});
#endif
return 0;
}
};
struct DequantizeINT4 {
CK_TILE_DEVICE auto operator()(statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp>& dequant_B_n,
const auto& quant_weight_tensor,
const auto& scale_tensor,
auto xdl_nIter,
auto xdl_kIter) {
auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
auto scale = scale_tensor.get_thread_buffer()[scale_offset];
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
constexpr int float_mantissa = 23;
uint32_t uscale = uint32_t(scale.data) << float_mantissa;
using ComputeV2Type =
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
auto pk_int4_to_compute_v2 = [](auto pk_int4, float fscale) {
if constexpr(std::is_same_v<ComputeType, half_t>)
{
return pk_int4_t_to_halfx2_t(pk_int4, fscale);
}
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
{
return pk_int4_t_to_bfloat16x2_t(pk_int4, fscale);
}
else
{
static_assert(sizeof(pk_int4) == 0, "unsupported compute type");
}
};
static_for<0, PackedCnt, 1>{}([&](auto i) {
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
i,
pk_int4_to_compute_v2(
bit_cast<thread_buffer<pk_int4_t, 4>>(quant_weight_tensor[quant_idx_k])
.at(i),
bit_cast<float>(uscale)));
});
return 0;
}
};
using DequantOp = typename std::conditional<std::is_same_v<BDataType, ck_tile::pk_fp4_t>, DequantizeMxFP4, DequantizeINT4>::type;
CK_TILE_HOST_DEVICE static constexpr auto
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
{
@@ -747,6 +876,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp> dequant_B_n;
/*
auto dequant_mxfp4 = [&](const auto& quant_weight_tensor,
const auto& scale_tensor,
auto xdl_nIter,
@@ -816,6 +947,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
});
#endif
};
*/
// MAIN LOOP
index_t iCounter = (num_loop - 1) / 2;
@@ -877,7 +1009,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
if constexpr(mIter == 0)
dequant_mxfp4(
DequantOp{}(
dequant_B_n,
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
@@ -997,7 +1130,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
if constexpr(mIter == 0)
dequant_mxfp4(
DequantOp{}(
dequant_B_n,
b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
@@ -1124,7 +1258,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
if constexpr(mIter == 0)
dequant_mxfp4(
DequantOp{}(
dequant_B_n,
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
@@ -1185,7 +1320,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
if constexpr(mIter == 0)
dequant_mxfp4(
DequantOp{}(
dequant_B_n,
b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
@@ -1236,7 +1372,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
if constexpr(mIter == 0)
dequant_mxfp4(
DequantOp{}(
dequant_B_n,
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),