mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
update
This commit is contained in:
@@ -588,8 +588,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
constexpr int ScaleB_BlockK =
|
||||
flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
|
||||
constexpr int ScaleB_BlockK = 16 * 2 * 4;
|
||||
// flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
|
||||
|
||||
auto scale_b_flat_dram_window = make_tile_window(
|
||||
scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
@@ -640,8 +640,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter, kIter * KFlatPerBlockPerIter});
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
@@ -690,6 +691,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
auto perm_scale = [&](auto lane_scale, auto xdl_k_idx) {
|
||||
#if defined(__gfx942__)
|
||||
lane_scale = __builtin_amdgcn_ds_bpermute(((get_lane_id() % 16) + 16 * xdl_k_idx) * 4,
|
||||
lane_scale);
|
||||
return lane_scale;
|
||||
#endif
|
||||
auto v2scale = __builtin_amdgcn_permlane32_swap(lane_scale, lane_scale, 0, 0);
|
||||
@@ -705,12 +708,13 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
v2scale = __builtin_amdgcn_permlane16_swap(lane_scale, lane_scale, 0, 0);
|
||||
if constexpr(xdl_k_idx % 2 == 0)
|
||||
{
|
||||
return v2scale[0];
|
||||
lane_scale = v2scale[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return v2scale[1];
|
||||
lane_scale = v2scale[1];
|
||||
}
|
||||
return lane_scale;
|
||||
};
|
||||
|
||||
auto deq_fn = [&](const auto& quant_weight_tensor,
|
||||
@@ -721,15 +725,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
auto scale_idx_n = xdl_nIter % number<NRepeatPerScaleLoad>{};
|
||||
|
||||
uint32_t packed_scale = scale_tensor.get_thread_buffer().template get_as<uint32_t>(I0);
|
||||
packed_scale = perm_scale(packed_scale, b_idx_k);
|
||||
auto scale = scale_tensor.get_thread_buffer()[scale_idx_n];
|
||||
|
||||
e8m0_t* scale_ptr = reinterpret_cast<e8m0_t*>(&packed_scale);
|
||||
auto use_scale = scale;
|
||||
use_scale.data = perm_scale(scale.data, b_idx_k);
|
||||
|
||||
if constexpr(xdl_nIter % 2 != 0)
|
||||
{
|
||||
scale_ptr++;
|
||||
}
|
||||
if constexpr(xdl_nIter == 0)
|
||||
if(blockIdx.x == 0 && threadIdx.x < 64 && get_lane_id() % 16 == 0)
|
||||
{
|
||||
printf("laneid = %2u xdl-k=%2d use-scale = "
|
||||
"%.2f\n",
|
||||
threadIdx.x,
|
||||
int(xdl_kIter),
|
||||
float(use_scale));
|
||||
}
|
||||
|
||||
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
|
||||
static_for<0, ScalarCnt / 2, 1>{}([&](auto i) {
|
||||
@@ -737,7 +746,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
number<i>{},
|
||||
pk_fp4_to_fp16x2(
|
||||
quant_weight_tensor.get_thread_buffer()[b_idx_k * ScalarCnt / 2 + i],
|
||||
*scale_ptr));
|
||||
static_cast<float>(use_scale)));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -748,6 +757,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter % NRepeatPerScaleLoad == 0)
|
||||
{
|
||||
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
}
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
@@ -828,6 +851,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter % NRepeatPerScaleLoad == 0)
|
||||
{
|
||||
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
}
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
@@ -910,6 +947,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// prefetch B(loopK)
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter % NRepeatPerScaleLoad == 0)
|
||||
{
|
||||
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
}
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
|
||||
Reference in New Issue
Block a user