mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
use v4i32 as the storage type for B to avoid repack operation
This commit is contained in:
@@ -595,18 +595,25 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
scale_b_flat_window.get_window_origin(),
|
||||
scale_b_flat_distribution);
|
||||
|
||||
using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
|
||||
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
|
||||
using V4UInt_Buffer = thread_buffer<uint32_t, XDL_PerWeightK>;
|
||||
union UnionB
|
||||
{
|
||||
V4UInt_Buffer u = 0;
|
||||
MXFP4_Buffer mxfp4;
|
||||
} ub;
|
||||
|
||||
// pingpong buffer for B
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), MXFP4KPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), MXFP4KPerWarp>,
|
||||
NIterPerWarp>
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, MXFP4KPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), MXFP4KPerWarp>,
|
||||
NIterPerWarp>
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, MXFP4KPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
statically_indexed_array<
|
||||
@@ -652,7 +659,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
|
||||
packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_ping(nIter)(kIter) = ub.u;
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
@@ -733,11 +742,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
|
||||
i,
|
||||
pk_mxfp4x4_to_compute_v2(
|
||||
reinterpret_cast<const thread_buffer<uint32_t, XDL_PerWeightK>&>(
|
||||
quant_weight_tensor)
|
||||
.get(quant_idx_k),
|
||||
bit_cast<float>(uscale),
|
||||
i));
|
||||
quant_weight_tensor[quant_idx_k], bit_cast<float>(uscale), i));
|
||||
});
|
||||
#else
|
||||
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
|
||||
@@ -758,7 +763,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
|
||||
i,
|
||||
pk_mxfp4_to_compute_v2(
|
||||
quant_weight_tensor.get_thread_buffer()[quant_idx_k * PackedCnt + i],
|
||||
bit_cast<thread_buffer<pk_fp4_t, 4>>(quant_weight_tensor[quant_idx_k])
|
||||
.at(i),
|
||||
bit_cast<float>(uscale)));
|
||||
});
|
||||
#endif
|
||||
@@ -798,7 +804,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_pong(nIter)(kIter) = ub.u;
|
||||
});
|
||||
});
|
||||
|
||||
@@ -901,7 +908,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_ping(nIter)(kIter) = ub.u;
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1008,7 +1016,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_pong(nIter)(kIter) = ub.u;
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user