This commit is contained in:
Feng Shijie
2025-08-11 11:24:34 +00:00
parent 200a11afc8
commit edb58d0680
7 changed files with 112 additions and 50 deletions

View File

@@ -157,12 +157,12 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
const auto scale_b_flat_view =
make_naive_tensor_view<address_space_enum::global>(scale_n.ptr,
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const e8m0_t*>(scale_n.ptr),
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
return make_tuple(
a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
@@ -297,7 +297,11 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
{i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
return make_tuple(a_block_window,
b_flat_block_window,
ds_block_window,
e_block_window,
scale_block_window);
}
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
@@ -326,7 +330,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& scale_block_window = gemm_tile_windows.at(I3);
const auto& scale_block_window = gemm_tile_windows.at(I4);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window,
b_flat_block_window,
scale_block_window,

View File

@@ -588,8 +588,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
constexpr int ScaleB_BlockK =
flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
constexpr int ScaleB_BlockK = 16 * 2 * 4;
// flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
auto scale_b_flat_dram_window = make_tile_window(
scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
@@ -640,8 +640,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
move_tile_window(scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter, kIter * KFlatPerBlockPerIter});
move_tile_window(
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
@@ -690,6 +691,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
auto perm_scale = [&](auto lane_scale, auto xdl_k_idx) {
#if defined(__gfx942__)
lane_scale = __builtin_amdgcn_ds_bpermute(((get_lane_id() % 16) + 16 * xdl_k_idx) * 4,
lane_scale);
return lane_scale;
#endif
auto v2scale = __builtin_amdgcn_permlane32_swap(lane_scale, lane_scale, 0, 0);
@@ -705,12 +708,13 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
v2scale = __builtin_amdgcn_permlane16_swap(lane_scale, lane_scale, 0, 0);
if constexpr(xdl_k_idx % 2 == 0)
{
return v2scale[0];
lane_scale = v2scale[0];
}
else
{
return v2scale[1];
lane_scale = v2scale[1];
}
return lane_scale;
};
auto deq_fn = [&](const auto& quant_weight_tensor,
@@ -721,15 +725,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
auto scale_idx_n = xdl_nIter % number<NRepeatPerScaleLoad>{};
uint32_t packed_scale = scale_tensor.get_thread_buffer().template get_as<uint32_t>(I0);
packed_scale = perm_scale(packed_scale, b_idx_k);
auto scale = scale_tensor.get_thread_buffer()[scale_idx_n];
e8m0_t* scale_ptr = reinterpret_cast<e8m0_t*>(&packed_scale);
auto use_scale = scale;
use_scale.data = perm_scale(scale.data, b_idx_k);
if constexpr(xdl_nIter % 2 != 0)
{
scale_ptr++;
}
if constexpr(xdl_nIter == 0)
if(blockIdx.x == 0 && threadIdx.x < 64 && get_lane_id() % 16 == 0)
{
printf("laneid = %2u xdl-k=%2d use-scale = "
"%.2f\n",
threadIdx.x,
int(xdl_kIter),
float(use_scale));
}
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
static_for<0, ScalarCnt / 2, 1>{}([&](auto i) {
@@ -737,7 +746,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
number<i>{},
pk_fp4_to_fp16x2(
quant_weight_tensor.get_thread_buffer()[b_idx_k * ScalarCnt / 2 + i],
*scale_ptr));
static_cast<float>(use_scale)));
});
};
@@ -748,6 +757,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// prefetch B(2i+1)
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter % NRepeatPerScaleLoad == 0)
{
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
move_tile_window(
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
}
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
@@ -828,6 +851,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// prefetch B(2i+2)
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter % NRepeatPerScaleLoad == 0)
{
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
move_tile_window(
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
}
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
@@ -910,6 +947,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// prefetch B(loopK)
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter % NRepeatPerScaleLoad == 0)
{
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
move_tile_window(
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
}
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),