Use v4i32 as the storage type for B to avoid the repack operation

This commit is contained in:
Feng Shijie
2025-08-20 13:53:32 +00:00
parent 81899bd920
commit 9fbcc8f8a4

View File

@@ -595,18 +595,25 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
scale_b_flat_window.get_window_origin(),
scale_b_flat_distribution);
using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
// Use v4i32 as the data type passed between basic blocks to avoid unpack and repack operations.
using V4UInt_Buffer = thread_buffer<uint32_t, XDL_PerWeightK>;
// Scratch union that overlays the tile type produced by load_tile(b_flat_dram_window)
// (MXFP4_Buffer) with a thread_buffer of uint32_t words (V4UInt_Buffer).
// The B load sites in this pipeline assign a freshly loaded tile to `mxfp4` and then
// copy `u` into the ping/pong storage arrays.
// NOTE(review): that reads a union member other than the one last written (type punning
// through a union); presumably both members have identical size and the target compiler
// supports this -- TODO confirm, e.g. with a static_assert on the two sizes.
union UnionB
{
// Raw 32-bit-word view; carries the default member initializer `= 0`.
V4UInt_Buffer u = 0;
// Typed view with the exact type returned by load_tile(b_flat_dram_window).
MXFP4_Buffer mxfp4;
} ub;
// pingpong buffer for B
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), MXFP4KPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), MXFP4KPerWarp>,
NIterPerWarp>
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, MXFP4KPerWarp>,
NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), MXFP4KPerWarp>,
NIterPerWarp>
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, MXFP4KPerWarp>,
NIterPerWarp>
b_warp_tensor_pong;
statically_indexed_array<
@@ -652,7 +659,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_ping(nIter)(kIter) = ub.u;
});
});
// move B window to next flat K
@@ -733,11 +742,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
i,
pk_mxfp4x4_to_compute_v2(
reinterpret_cast<const thread_buffer<uint32_t, XDL_PerWeightK>&>(
quant_weight_tensor)
.get(quant_idx_k),
bit_cast<float>(uscale),
i));
quant_weight_tensor[quant_idx_k], bit_cast<float>(uscale), i));
});
#else
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
@@ -758,7 +763,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
i,
pk_mxfp4_to_compute_v2(
quant_weight_tensor.get_thread_buffer()[quant_idx_k * PackedCnt + i],
bit_cast<thread_buffer<pk_fp4_t, 4>>(quant_weight_tensor[quant_idx_k])
.at(i),
bit_cast<float>(uscale)));
});
#endif
@@ -798,7 +804,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = ub.u;
});
});
@@ -901,7 +908,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_ping(nIter)(kIter) = ub.u;
});
});
@@ -1008,7 +1016,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = ub.u;
});
});