mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
update
This commit is contained in:
@@ -157,12 +157,12 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
const auto scale_b_flat_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(scale_n.ptr,
|
||||
make_tuple(FlatScaleN, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const e8m0_t*>(scale_n.ptr),
|
||||
make_tuple(FlatScaleN, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
return make_tuple(
|
||||
a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
|
||||
@@ -297,7 +297,11 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
|
||||
{i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
|
||||
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
|
||||
return make_tuple(a_block_window,
|
||||
b_flat_block_window,
|
||||
ds_block_window,
|
||||
e_block_window,
|
||||
scale_block_window);
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
|
||||
@@ -326,7 +330,7 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& scale_block_window = gemm_tile_windows.at(I3);
|
||||
const auto& scale_block_window = gemm_tile_windows.at(I4);
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window,
|
||||
b_flat_block_window,
|
||||
scale_block_window,
|
||||
|
||||
@@ -588,8 +588,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
constexpr int ScaleB_BlockK =
|
||||
flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
|
||||
constexpr int ScaleB_BlockK = 16 * 2 * 4;
|
||||
// flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
|
||||
|
||||
auto scale_b_flat_dram_window = make_tile_window(
|
||||
scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
@@ -640,8 +640,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter, kIter * KFlatPerBlockPerIter});
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
@@ -690,6 +691,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
auto perm_scale = [&](auto lane_scale, auto xdl_k_idx) {
|
||||
#if defined(__gfx942__)
|
||||
lane_scale = __builtin_amdgcn_ds_bpermute(((get_lane_id() % 16) + 16 * xdl_k_idx) * 4,
|
||||
lane_scale);
|
||||
return lane_scale;
|
||||
#endif
|
||||
auto v2scale = __builtin_amdgcn_permlane32_swap(lane_scale, lane_scale, 0, 0);
|
||||
@@ -705,12 +708,13 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
v2scale = __builtin_amdgcn_permlane16_swap(lane_scale, lane_scale, 0, 0);
|
||||
if constexpr(xdl_k_idx % 2 == 0)
|
||||
{
|
||||
return v2scale[0];
|
||||
lane_scale = v2scale[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return v2scale[1];
|
||||
lane_scale = v2scale[1];
|
||||
}
|
||||
return lane_scale;
|
||||
};
|
||||
|
||||
auto deq_fn = [&](const auto& quant_weight_tensor,
|
||||
@@ -721,15 +725,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
auto scale_idx_n = xdl_nIter % number<NRepeatPerScaleLoad>{};
|
||||
|
||||
uint32_t packed_scale = scale_tensor.get_thread_buffer().template get_as<uint32_t>(I0);
|
||||
packed_scale = perm_scale(packed_scale, b_idx_k);
|
||||
auto scale = scale_tensor.get_thread_buffer()[scale_idx_n];
|
||||
|
||||
e8m0_t* scale_ptr = reinterpret_cast<e8m0_t*>(&packed_scale);
|
||||
auto use_scale = scale;
|
||||
use_scale.data = perm_scale(scale.data, b_idx_k);
|
||||
|
||||
if constexpr(xdl_nIter % 2 != 0)
|
||||
{
|
||||
scale_ptr++;
|
||||
}
|
||||
if constexpr(xdl_nIter == 0)
|
||||
if(blockIdx.x == 0 && threadIdx.x < 64 && get_lane_id() % 16 == 0)
|
||||
{
|
||||
printf("laneid = %2u xdl-k=%2d use-scale = "
|
||||
"%.2f\n",
|
||||
threadIdx.x,
|
||||
int(xdl_kIter),
|
||||
float(use_scale));
|
||||
}
|
||||
|
||||
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
|
||||
static_for<0, ScalarCnt / 2, 1>{}([&](auto i) {
|
||||
@@ -737,7 +746,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
number<i>{},
|
||||
pk_fp4_to_fp16x2(
|
||||
quant_weight_tensor.get_thread_buffer()[b_idx_k * ScalarCnt / 2 + i],
|
||||
*scale_ptr));
|
||||
static_cast<float>(use_scale)));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -748,6 +757,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter % NRepeatPerScaleLoad == 0)
|
||||
{
|
||||
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
}
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
@@ -828,6 +851,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter % NRepeatPerScaleLoad == 0)
|
||||
{
|
||||
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
}
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
@@ -910,6 +947,20 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// prefetch B(loopK)
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter % NRepeatPerScaleLoad == 0)
|
||||
{
|
||||
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
}
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
|
||||
Reference in New Issue
Block a user