[CK_TILE] Multiple-ABD GEMM example (#2788)

* Multi ABD - initial commit

* Clang-foramt fix

* block gemm, unify the name of CDataType

* Apply chnages to mem-pipeline

* Rollback prefix for DType and Layout

* Gemm Kernel Basic, rename

* WMMA config

* Grouped GEMM

* Clang-format

* Dropout, name

* Review v2

* Move element_wise fn to unnary, remov old ones fn

* clang-format

* Fix issue review

* WP operator adjust to universal gemm

* v2 prepare

* Remove unused comment

* Remove vectorsize

* Rollback

* Adjust pipeline for abd

* Shuffle argument

* CI-fail fix quant

* Fix ag_br pipeline

* Failing tests

* Typo

* Single argument support
This commit is contained in:
Mateusz Ozga
2025-09-19 01:14:11 +02:00
committed by GitHub
parent 14bbc545ea
commit 30ab1d6a71
41 changed files with 3603 additions and 552 deletions

View File

@@ -26,6 +26,29 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
/**
* @brief Load tile with elementwise function
*
* @note This function is a modification of the existing load function.
* It has been extended with two additional parameters: it takes a tuple as input
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
* is additionally applied during a single read.
*/
template <typename TileWindow_,
typename ElementWise_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
ElementWise_ elementwise,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
// TODO: Tile windows should works with unknow number of params
// Load element_wise API works only when the input typle is a tuple-tyupe
return tile_window[number<0>{}].load(
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename TileWindow_,
index_t i_access = -1,

View File

@@ -120,6 +120,116 @@ struct tile_window_with_static_distribution
return dst_tensor;
}
/**
* @brief Load tile with elementwise function
*
* @note Load tile with elementwise — during value loading, an
* elementwise function is executed for each A0, A1, … AN.
* The values A0, A1, … AN are read by the same thread. In this way, we
* reduce the amount of information loaded into the registers.
* The same thread, during vectorized reading, accesses the same set of
* data from A0, A1, A2, … AN.
*/
template <typename TileWindow_,
typename ElementWise_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
ElementWise_ elementwise,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
load(dst_tensor,
tile_window,
elementwise,
number<i_access_unsupport_>{},
bool_constant<oob_conditional_check>{});
return dst_tensor;
}
template <typename DistributedTensor,
typename TileWindow_,
typename ElementWise_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
const TileWindow_& tile_window,
ElementWise_ elementwise,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = typename Base::TileDstr{};
constexpr auto sizeOfTuple = TileWindow_::size();
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord =
tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord =
tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from bottom tensor
const auto idx_vec_value = generate_tuple(
[&](auto jj) {
return tile_window[number<jj>{}]
.get_bottom_tensor_view()
.template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
0,
bool_constant<oob_conditional_check>{});
},
number<sizeOfTuple>{});
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<Base::NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
ck_tile::apply(
[&](auto&&... t) {
elementwise(dst_tensor.get_thread_buffer().template at<d>(),
t.template get_as<
typename Base::DataType>()[j / Traits::PackedSize]...);
},
idx_vec_value);
});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
@@ -857,6 +967,39 @@ CK_TILE_DEVICE void move_tile_window(
window.move(step);
}
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord>
CK_TILE_DEVICE void move_tile_window(
tuple<tile_window_with_static_distribution<TensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>>& window,
const typename tile_window_with_static_distribution<TensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>::BottomTensorIndex& step)
{
using T = tuple<tile_window_with_static_distribution<TensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>>;
static constexpr auto N = T::size();
static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
}
template <typename TileWindowWithStaticDistributionType,
typename StepType,
typename std::enable_if_t<
is_detected<is_tuple, TileWindowWithStaticDistributionType>::value>* = nullptr>
CK_TILE_DEVICE void move_tile_window(TileWindowWithStaticDistributionType& window, StepType& step)
{
static constexpr auto N = TileWindowWithStaticDistributionType::size();
static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
}
/**
* @brief This class provides description of tile windowed view on the device memory.
*