mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
eliminate repeat dequant
This commit is contained in:
@@ -708,7 +708,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto dequant_B = typename WG::BWarpTensor{};
|
||||
statically_indexed_array<typename WG::BWarpTensor, NIterPerWarp> dequant_B_n;
|
||||
|
||||
auto dequant_mxfp4 = [&](const auto& quant_weight_tensor,
|
||||
const auto& scale_tensor,
|
||||
@@ -743,7 +743,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
|
||||
|
||||
static_for<0, PackedCnt, 1>{}([&](auto i) {
|
||||
dequant_B.get_thread_buffer().template set_as<ComputeV2Type>(
|
||||
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
|
||||
i,
|
||||
pk_mxfp4_to_compute_v2(
|
||||
quant_weight_tensor.get_thread_buffer()[quant_idx_k * PackedCnt + i],
|
||||
@@ -811,14 +811,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -912,14 +914,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
dequant_mxfp4(b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1014,14 +1018,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1068,14 +1074,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
dequant_mxfp4(b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1114,13 +1122,15 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
|
||||
Reference in New Issue
Block a user