eliminate repeat dequant

This commit is contained in:
Feng Shijie
2025-08-14 09:34:12 +00:00
parent 53e8c0c533
commit 714b341797

View File

@@ -708,7 +708,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
});
__builtin_amdgcn_sched_barrier(0);
auto dequant_B = typename WG::BWarpTensor{};
statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp> dequant_B_n;
auto dequant_mxfp4 = [&](const auto& quant_weight_tensor,
const auto& scale_tensor,
@@ -743,7 +743,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
static_for<0, PackedCnt, 1>{}([&](auto i) {
dequant_B.get_thread_buffer().template set_as<ComputeV2Type>(
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
i,
pk_mxfp4_to_compute_v2(
quant_weight_tensor.get_thread_buffer()[quant_idx_k * PackedCnt + i],
@@ -811,14 +811,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
if constexpr(mIter == 0)
dequant_mxfp4(
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -912,14 +914,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
dequant_mxfp4(b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
if constexpr(mIter == 0)
dequant_mxfp4(
b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1014,14 +1018,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
if constexpr(mIter == 0)
dequant_mxfp4(
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1068,14 +1074,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
dequant_mxfp4(b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
if constexpr(mIter == 0)
dequant_mxfp4(
b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1114,13 +1122,15 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
if constexpr(mIter == 0)
dequant_mxfp4(
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(