mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
support a16_wint4 moe
This commit is contained in:
@@ -86,7 +86,8 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>; // Preshuffle_
|
||||
|
||||
constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;
|
||||
// TODO(yadai): rename to W4_Pipeline
|
||||
constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t> | std::is_same_v<BDataType, ck_tile::pk_int4_t>;
|
||||
|
||||
if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
@@ -444,6 +445,22 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[])
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "fp16xint4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xint4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
ck_tile::pk_int4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
|
||||
@@ -151,6 +151,16 @@ CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
|
||||
return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
|
||||
}
|
||||
|
||||
|
||||
|
||||
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x, float scale)
|
||||
{
|
||||
auto float_vec2 = pk_int4_t_to_fp32x2_t(x);
|
||||
float_vec2.x = float_vec2.x * scale;
|
||||
float_vec2.y = float_vec2.y * scale;
|
||||
return fp16x2_t{type_convert<fp16_t>(float_vec2.x), type_convert<fp16_t>(float_vec2.y)};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
|
||||
{
|
||||
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
|
||||
@@ -166,6 +176,14 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x, float scale)
|
||||
{
|
||||
auto float_vec2 = pk_int4_t_to_fp32x2_t(x);
|
||||
float_vec2.x = float_vec2.x * scale;
|
||||
float_vec2.y = float_vec2.y * scale;
|
||||
return bf16x2_t{type_convert<bf16_t>(float_vec2.x), type_convert<bf16_t>(float_vec2.y)};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x)
|
||||
{
|
||||
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
|
||||
|
||||
@@ -241,7 +241,7 @@ struct MoeFlatmmKernel
|
||||
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
|
||||
|
||||
// MXF4_Pipeline only has the of scale B and granularityK is 32
|
||||
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
|
||||
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t> || std::is_same_v<BDataType, pk_int4_t>;
|
||||
static constexpr int MXFP4N_Pack = 2;
|
||||
static constexpr int MXFP4K_Pack = 2;
|
||||
|
||||
|
||||
@@ -187,6 +187,135 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
struct DequantizeMxFP4 {
|
||||
|
||||
CK_TILE_DEVICE auto operator()(statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp>& dequant_B_n,
|
||||
const auto& quant_weight_tensor,
|
||||
const auto& scale_tensor,
|
||||
auto xdl_nIter,
|
||||
auto xdl_kIter) {
|
||||
|
||||
auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
|
||||
|
||||
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
|
||||
auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
|
||||
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
|
||||
|
||||
auto scale = scale_tensor.get_thread_buffer()[scale_offset];
|
||||
|
||||
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
|
||||
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
|
||||
constexpr int float_mantissa = 23;
|
||||
|
||||
uint32_t uscale = uint32_t(scale.data) << float_mantissa;
|
||||
|
||||
using ComputeV2Type =
|
||||
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
auto pk_mxfp4x4_to_compute_v2 = [](auto pk_mxfp4x4, float fscale, auto byte_idx) {
|
||||
if constexpr(std::is_same_v<ComputeType, half_t>)
|
||||
{
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(
|
||||
pk_mxfp4x4, fscale, int(byte_idx));
|
||||
}
|
||||
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
|
||||
{
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(
|
||||
pk_mxfp4x4, fscale, int(byte_idx));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(pk_mxfp4x4) == 0, "unsupported compute type");
|
||||
}
|
||||
};
|
||||
static_for<0, PackedCnt, 1>{}([&](auto i) {
|
||||
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
|
||||
i,
|
||||
pk_mxfp4x4_to_compute_v2(
|
||||
quant_weight_tensor[quant_idx_k], bit_cast<float>(uscale), i));
|
||||
});
|
||||
#else
|
||||
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
|
||||
if constexpr(std::is_same_v<ComputeType, half_t>)
|
||||
{
|
||||
return pk_fp4_to_fp16x2(pk_mxfp4, fscale);
|
||||
}
|
||||
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
|
||||
{
|
||||
return pk_fp4_to_bf16x2(pk_mxfp4, fscale);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type");
|
||||
}
|
||||
};
|
||||
static_for<0, PackedCnt, 1>{}([&](auto i) {
|
||||
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
|
||||
i,
|
||||
pk_mxfp4_to_compute_v2(
|
||||
bit_cast<thread_buffer<pk_fp4_t, 4>>(quant_weight_tensor[quant_idx_k])
|
||||
.at(i),
|
||||
bit_cast<float>(uscale)));
|
||||
});
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct DequantizeINT4 {
|
||||
|
||||
CK_TILE_DEVICE auto operator()(statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp>& dequant_B_n,
|
||||
const auto& quant_weight_tensor,
|
||||
const auto& scale_tensor,
|
||||
auto xdl_nIter,
|
||||
auto xdl_kIter) {
|
||||
|
||||
auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
|
||||
|
||||
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
|
||||
auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
|
||||
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
|
||||
|
||||
auto scale = scale_tensor.get_thread_buffer()[scale_offset];
|
||||
|
||||
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
|
||||
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
|
||||
constexpr int float_mantissa = 23;
|
||||
|
||||
uint32_t uscale = uint32_t(scale.data) << float_mantissa;
|
||||
|
||||
using ComputeV2Type =
|
||||
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
|
||||
|
||||
auto pk_int4_to_compute_v2 = [](auto pk_int4, float fscale) {
|
||||
if constexpr(std::is_same_v<ComputeType, half_t>)
|
||||
{
|
||||
return pk_int4_t_to_halfx2_t(pk_int4, fscale);
|
||||
}
|
||||
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
|
||||
{
|
||||
return pk_int4_t_to_bfloat16x2_t(pk_int4, fscale);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(pk_int4) == 0, "unsupported compute type");
|
||||
}
|
||||
};
|
||||
static_for<0, PackedCnt, 1>{}([&](auto i) {
|
||||
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
|
||||
i,
|
||||
pk_int4_to_compute_v2(
|
||||
bit_cast<thread_buffer<pk_int4_t, 4>>(quant_weight_tensor[quant_idx_k])
|
||||
.at(i),
|
||||
bit_cast<float>(uscale)));
|
||||
});
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
using DequantOp = typename std::conditional<std::is_same_v<BDataType, ck_tile::pk_fp4_t>, DequantizeMxFP4, DequantizeINT4>::type;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
|
||||
{
|
||||
@@ -747,6 +876,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp> dequant_B_n;
|
||||
|
||||
|
||||
/*
|
||||
auto dequant_mxfp4 = [&](const auto& quant_weight_tensor,
|
||||
const auto& scale_tensor,
|
||||
auto xdl_nIter,
|
||||
@@ -816,6 +947,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
#endif
|
||||
};
|
||||
*/
|
||||
|
||||
// MAIN LOOP
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
@@ -877,7 +1009,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
DequantOp{}(
|
||||
dequant_B_n,
|
||||
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
@@ -997,7 +1130,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
DequantOp{}(
|
||||
dequant_B_n,
|
||||
b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
@@ -1124,7 +1258,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
DequantOp{}(
|
||||
dequant_B_n,
|
||||
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
@@ -1185,7 +1320,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
DequantOp{}(
|
||||
dequant_B_n,
|
||||
b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
@@ -1236,7 +1372,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
DequantOp{}(
|
||||
dequant_B_n,
|
||||
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
|
||||
Reference in New Issue
Block a user