Dev/a8w4 and a8w8splitk (#3447)

* Ck moe bs splitk pr (#3440)

* splitk kick-off. Compilation fail

* splitk hack pass

* fix scale offset calc.

* clang-format for a8w8_moe_blk_gemm1 splitk change

* fix testcase error

---------

Co-authored-by: oscar <huaiguxu@amd.com>
Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com>

* Zan/moe a8w4 (#3441)

* update

* update

* update ck moe a8w4

* update

* update

* update

* compile pass

* update

* update

* python3 op_tests/test_moe_2stage.py -t 16 -e 1 -k 1 -dim 256,256 ready

* support new a8w4 kernel

* update

* update ck_tile

* re format

* update

* update

* fix conflict

* fix build

* update ck_tile moe

* fix clang format

* fix the problem

* fix accuracy issue

* fix

---------

Co-authored-by: oscar <huaiguxu@amd.com>
Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com>
Co-authored-by: Zzz9990 <zanzhang@amd.com>
Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
yadaish
2025-12-19 09:26:52 +08:00
committed by GitHub
parent ba897f8435
commit c0ee71d735
13 changed files with 2911 additions and 139 deletions

View File

@@ -218,6 +218,44 @@ struct tile_scatter_gather
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global)
{
auto partition_index = get_partition_index(tile_distribution);
auto use_lane_id_0 = partition_index;
use_lane_id_0[1] = 0;
const auto window_adaptor_thread_coord_tmp_warp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(use_lane_id_0, array<index_t, NDimY>{0}));
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp_warp =
window_origin + window_adaptor_thread_coord_tmp_warp.get_bottom_index();
bottom_tensor_thread_origin_idx_tmp_warp(HsGatherDim) = 0;
const auto bottom_tensor_thread_coord_tmp_warp =
make_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
bottom_tensor_thread_origin_idx_tmp_warp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp_warp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp_warp;
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_warp_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
}
/// Rank of the bottom (backing) tensor addressed by this window.
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension()
{
    return NDimBottomTensor;
}
@@ -602,6 +640,135 @@ struct tile_scatter_gather
});
}
// TODO: fix with swizzle
///
/// Asynchronously load this window's (scatter/gather) tile from the bottom
/// tensor view into an LDS tile window, applying an extra runtime element
/// `offset` to every source access.
///
/// @param offset    Runtime offset added to each source access
///                  (`offset + dram_ys_offset` is forwarded to
///                  `async_get_vectorized_elements`). Units follow that API's
///                  contract — NOTE(review): element vs. byte units not
///                  visible here, confirm against the tensor-view
///                  implementation.
/// @param lds_tile  Destination tile window; its buffer is addressed through
///                  `p_data_` as a `CK_TILE_LDS_ADDR` pointer.
/// @tparam i_access_unsupport_   Accepted for interface parity; not used here.
/// @tparam oob_conditional_check Forwarded to the vectorized async load for
///                               out-of-bounds predication.
/// @tparam static_move_ys        When true, per-access Y offsets are computed
///                               at compile time (`lds_ys_offset` /
///                               `dram_ys_offset`) instead of moving the
///                               thread coordinates between accesses.
template <typename LdsTileWindow_,
          index_t i_access_unsupport_ = -1,
          bool oob_conditional_check = true,
          bool static_move_ys = false,
          typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
CK_TILE_DEVICE void async_load_with_offset(index_t offset,
                                           LdsTileWindow_&& lds_tile,
                                           number<i_access_unsupport_> = {},
                                           bool_constant<oob_conditional_check> = {},
                                           bool_constant<static_move_ys> = {}) const
{
    using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
    using LdsDataType   = typename LdsTileWindow::DataType;
    using Traits        = load_store_traits;
    using vector_t      = typename Traits::vector_t;
    using SFC_Ys        = typename Traits::SFC_Ys;

    // Precompute invariant values outside loops
    const auto window_origin       = lds_tile.get_window_origin();
    const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
    const auto& tensor_descriptor  = bottom_tensor_view.get_tensor_descriptor();
    auto lds_base_ptr              = bottom_tensor_view.get_buffer_view().p_data_;

    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        // Per-thread coordinates (source side) and per-warp coordinates
        // (destination/LDS side) precomputed at window construction.
        auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];
        auto window_adaptor_warp_coord   = pre_computed_warp_coords_[iCoord][I0];
        auto bottom_tensor_warp_coord    = pre_computed_warp_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // Compile-time bottom-index delta from access 0 to this access,
            // obtained by pushing the Ys step through the P/Ys->Xs adaptor
            // with all P components zeroed.
            constexpr auto idx_ys_offset = [&]() {
                constexpr auto idx_off_ys = SFC_Ys::get_step_between(number<0>{}, iAccess);
                constexpr auto adapter_ys_offset = make_tensor_adaptor_coordinate(
                    StaticTileDistribution_{}.get_ps_ys_to_xs_adaptor(),
                    container_concat(array<index_t, NDimP>{0},
                                     to_array<index_t, idx_off_ys.size()>(idx_off_ys)));
                return adapter_ys_offset.get_bottom_index();
            }();

            // LDS-side linear offset for this access; 0 when the coordinates
            // are moved incrementally instead (static_move_ys == false).
            const auto lds_ys_offset = [&]() {
                if constexpr(static_move_ys)
                {
                    const auto coord_ys_offset =
                        make_tensor_coordinate(tensor_descriptor, idx_ys_offset);
                    return coord_ys_offset.get_offset();
                }
                else
                    return 0;
            }();

            // Use precomputed window origin & tensor descriptor
            auto lds_bottom_tensor_thread_idx =
                window_origin + window_adaptor_warp_coord.get_bottom_index();
            const auto lds_coord =
                make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);

            // Calculate SMEM address using base pointer
            // (offsets are divided by PackedSize to convert to packed-element
            // units for sub-byte data types)
            CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
                                                 lds_coord.get_offset() / Traits::PackedSize +
                                                 lds_ys_offset / Traits::PackedSize;

            // Source-side linear offset mirroring lds_ys_offset, but computed
            // against this window's own tensor descriptor.
            const auto dram_ys_offset = [&]() {
                if constexpr(static_move_ys)
                {
                    const auto coord_ys_offset = make_tensor_coordinate(
                        this->get_bottom_tensor_view().get_tensor_descriptor(), idx_ys_offset);
                    return coord_ys_offset.get_offset();
                }
                else
                    return 0;
            }();

            // Gather: the Y index along YsGatherDim selects a page from
            // page_idx_, which is added to the hidden index of the source
            // coordinate (dimension 0 of the hidden index).
            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
            constexpr auto idx_gather   = idx_ys_start[number<YsGatherDim>{}];
            const auto page_offset      = page_idx_[idx_gather];

            auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
            mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;

            // Issue the async vectorized load; pass the per-row validity mask
            // only when a ValidArray was supplied to the window.
            if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
            {
                this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                    smem,
                    mixed_bottom_thread_coord,
                    offset + dram_ys_offset,
                    bool_constant<oob_conditional_check>{});
            }
            else
            {
                this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                    smem,
                    mixed_bottom_thread_coord,
                    offset + dram_ys_offset,
                    valids_[idx_gather],
                    bool_constant<oob_conditional_check>{});
            }

            // Move thread coordinate if not last access
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                // Zero the step along the gather dimension: movement along
                // YsGatherDim is realized through page_idx_, not through the
                // coordinate itself.
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
                constexpr auto forward_step_scatter = generate_tuple(
                    [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
                    number<NDimY>{});
                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
                    forward_step_scatter);

                // NOTE(review): window_adaptor_thread_coord is advanced here
                // but never read afterwards in this function — possibly kept
                // for symmetry with the non-offset load path; confirm before
                // removing.
                if constexpr(!static_move_ys)
                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord,
                        bottom_tensor_thread_coord,
                        idx_diff_ps_ys);
                if constexpr(!static_move_ys)
                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
            }
        });
    });
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
@@ -788,6 +955,15 @@ struct tile_scatter_gather
pre_computed_coords_(iCoord)(I1),
step_new);
});
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global)
{
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
pre_computed_warp_coords_(iCoord)(I1),
step_new);
});
}
}
/// Replace the gather page-index table consulted by subsequent load/store
/// calls on this window.
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx)
{
    page_idx_ = new_idx;
}
@@ -892,6 +1068,11 @@ struct tile_scatter_gather
// per-thread coordinate for window adaptor
// per-thread coordinate for bottom tensor
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
std::conditional_t<BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global,
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord>,
std::byte>
pre_computed_warp_coords_;
};
// TODO: use strategy
@@ -906,7 +1087,7 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx,
const StaticPageIndexArray_& page_idx, // perbytes
number<HsGatherDim> = {},
number<NumCoord> = {})
{