mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Dev/a8w4 and a8w8splitk (#3447)
* Ck moe bs splitk pr (#3440) * splitk kick-off. Compilation fail * splitk hack pass * fix scale offset calc. * clang-format for a8w8_moe_blk_gemm1 splitk change * fix testcase error --------- Co-authored-by: oscar <huaiguxu@amd.com> Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com> * Zan/moe a8w4 (#3441) * update * update * update ck moe a8w4 * update * update * update * compile pass * update * update * python3 op_tests/test_moe_2stage.py -t 16 -e 1 -k 1 -dim 256,256 ready * support new a8w4 kernel * update * update ck_tile * re format * update * update * fix conflict * fix build * update ck_tile moe * fix clang format * fix the problem * fix accuracy issue * fix --------- Co-authored-by: oscar <huaiguxu@amd.com> Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com> Co-authored-by: Zzz9990 <zanzhang@amd.com> Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
@@ -218,6 +218,44 @@ struct tile_scatter_gather
|
||||
pre_computed_coords_(iCoord) =
|
||||
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
|
||||
});
|
||||
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global)
|
||||
{
|
||||
auto partition_index = get_partition_index(tile_distribution);
|
||||
|
||||
auto use_lane_id_0 = partition_index;
|
||||
use_lane_id_0[1] = 0;
|
||||
const auto window_adaptor_thread_coord_tmp_warp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(use_lane_id_0, array<index_t, NDimY>{0}));
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp_warp =
|
||||
window_origin + window_adaptor_thread_coord_tmp_warp.get_bottom_index();
|
||||
bottom_tensor_thread_origin_idx_tmp_warp(HsGatherDim) = 0;
|
||||
const auto bottom_tensor_thread_coord_tmp_warp =
|
||||
make_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
bottom_tensor_thread_origin_idx_tmp_warp);
|
||||
|
||||
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
|
||||
// future load/store() calls (might allocate more registers)
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp_warp;
|
||||
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp_warp;
|
||||
|
||||
constexpr auto idx_diff_ys =
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
pre_computed_warp_coords_(iCoord) =
|
||||
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Number of dimensions of the bottom (underlying) tensor this window maps onto.
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
@@ -602,6 +640,135 @@ struct tile_scatter_gather
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: fix with swizzle
//
// Issue asynchronous vectorized loads from this window's bottom tensor into an
// LDS tile window, applying an extra runtime element `offset` to every source
// address.
//
// @param offset    extra linear offset added (together with any static Y-step
//                  offset) to the source address of every access
// @param lds_tile  destination LDS tile window; its window origin and tensor
//                  descriptor define the SMEM address for each access
//
// Template flags:
//  - oob_conditional_check: forwarded to async_get_vectorized_elements to guard
//    out-of-bounds reads.
//  - static_move_ys: when true, per-access Y offsets are computed at compile
//    time (coordinates are NOT mutated between accesses); when false, the
//    thread/warp coordinates are stepped forward after each access instead.
//
// NOTE(review): reads pre_computed_warp_coords_, which is only populated when
// the bottom tensor lives in global address space — presumably callers only use
// this path for global-memory windows; confirm.
template <typename LdsTileWindow_,
          index_t i_access_unsupport_ = -1,
          bool oob_conditional_check = true,
          bool static_move_ys = false,
          typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
CK_TILE_DEVICE void async_load_with_offset(index_t offset,
                                           LdsTileWindow_&& lds_tile,
                                           number<i_access_unsupport_> = {},
                                           bool_constant<oob_conditional_check> = {},
                                           bool_constant<static_move_ys> = {}) const
{
    using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
    using LdsDataType   = typename LdsTileWindow::DataType;

    using Traits = load_store_traits;

    using vector_t = typename Traits::vector_t;
    using SFC_Ys   = typename Traits::SFC_Ys;

    // Precompute invariant values outside loops
    const auto window_origin        = lds_tile.get_window_origin();
    const auto& bottom_tensor_view  = lds_tile.get_bottom_tensor_view();
    const auto& tensor_descriptor   = bottom_tensor_view.get_tensor_descriptor();
    auto lds_base_ptr               = bottom_tensor_view.get_buffer_view().p_data_;

    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        // Per-thread source coordinates, bundled at construction time.
        auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

        // Warp-level coordinates (lane id forced to 0 when precomputed); used
        // to derive the LDS destination address below.
        auto window_adaptor_warp_coord = pre_computed_warp_coords_[iCoord][I0];
        auto bottom_tensor_warp_coord  = pre_computed_warp_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // Compile-time bottom-tensor index delta for this access, obtained
            // by pushing the Y-space step through the PsYs->Xs adaptor with all
            // P components zeroed.
            constexpr auto idx_ys_offset = [&]() {
                constexpr auto idx_off_ys = SFC_Ys::get_step_between(number<0>{}, iAccess);
                constexpr auto adapter_ys_offset = make_tensor_adaptor_coordinate(
                    StaticTileDistribution_{}.get_ps_ys_to_xs_adaptor(),
                    container_concat(array<index_t, NDimP>{0},
                                     to_array<index_t, idx_off_ys.size()>(idx_off_ys)));
                return adapter_ys_offset.get_bottom_index();
            }();
            // Static LDS-side offset for this access; 0 when coordinates are
            // stepped dynamically instead (static_move_ys == false).
            const auto lds_ys_offset = [&]() {
                if constexpr(static_move_ys)
                {
                    const auto coord_ys_offset =
                        make_tensor_coordinate(tensor_descriptor, idx_ys_offset);
                    return coord_ys_offset.get_offset();
                }
                else
                    return 0;
            }();

            // Use precomputed window origin & tensor descriptor
            auto lds_bottom_tensor_thread_idx =
                window_origin + window_adaptor_warp_coord.get_bottom_index();
            const auto lds_coord =
                make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);

            // Calculate SMEM address using base pointer
            // (offsets are divided by PackedSize since p_data_ is typed in
            // packed elements).
            CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
                                                 lds_coord.get_offset() / Traits::PackedSize +
                                                 lds_ys_offset / Traits::PackedSize;

            // Static DRAM-side offset mirroring lds_ys_offset, but resolved
            // against this window's own (source) tensor descriptor.
            const auto dram_ys_offset = [&]() {
                if constexpr(static_move_ys)
                {
                    const auto coord_ys_offset = make_tensor_coordinate(
                        this->get_bottom_tensor_view().get_tensor_descriptor(), idx_ys_offset);
                    return coord_ys_offset.get_offset();
                }
                else
                    return 0;
            }();

            // Gather: pick the page offset for this access's position along
            // the gather dimension of Y-space.
            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
            constexpr auto idx_gather   = idx_ys_start[number<YsGatherDim>{}];
            const auto page_offset      = page_idx_[idx_gather];

            // Apply the page offset to hidden index 0 of a copy of the thread
            // coordinate, leaving the precomputed coordinate intact.
            auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
            mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;

            if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
            {
                // No per-row validity mask configured.
                this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                    smem,
                    mixed_bottom_thread_coord,
                    offset + dram_ys_offset,
                    bool_constant<oob_conditional_check>{});
            }
            else
            {
                // Masked variant: valids_[idx_gather] gates this access.
                this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                    smem,
                    mixed_bottom_thread_coord,
                    offset + dram_ys_offset,
                    valids_[idx_gather],
                    bool_constant<oob_conditional_check>{});
            }

            // Move thread coordinate if not last access
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                // Zero the step along the gather dimension: that dimension is
                // addressed via page_idx_ above, not by walking the coordinate.
                constexpr auto forward_step_scatter = generate_tuple(
                    [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
                    number<NDimY>{});

                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
                    forward_step_scatter);

                // Only step coordinates dynamically when static offsets are
                // not in use.
                if constexpr(!static_move_ys)
                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord,
                        bottom_tensor_thread_coord,
                        idx_diff_ps_ys);

                if constexpr(!static_move_ys)
                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
            }
        });
    });
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
@@ -788,6 +955,15 @@ struct tile_scatter_gather
|
||||
pre_computed_coords_(iCoord)(I1),
|
||||
step_new);
|
||||
});
|
||||
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global)
|
||||
{
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
pre_computed_warp_coords_(iCoord)(I1),
|
||||
step_new);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the page-index table used to compute gather offsets on future loads.
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
|
||||
@@ -892,6 +1068,11 @@ struct tile_scatter_gather
|
||||
// per-thread coordinate for window adaptor
|
||||
// per-thread coordinate for bottom tensor
|
||||
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
|
||||
std::conditional_t<BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global,
|
||||
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord>,
|
||||
std::byte>
|
||||
pre_computed_warp_coords_;
|
||||
};
|
||||
|
||||
// TODO: use strategy
|
||||
@@ -906,7 +1087,7 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
const StaticPageIndexArray_& page_idx,
|
||||
const StaticPageIndexArray_& page_idx, // perbytes
|
||||
number<HsGatherDim> = {},
|
||||
number<NumCoord> = {})
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user