mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] Multiple-ABD GEMM example (#2788)
* Multi ABD - initial commit * Clang-format fix * block gemm, unify the name of CDataType * Apply changes to mem-pipeline * Rollback prefix for DType and Layout * Gemm Kernel Basic, rename * WMMA config * Grouped GEMM * Clang-format * Dropout, name * Review v2 * Move element_wise fn to unary, remove old ones fn * clang-format * Fix issue review * WP operator adjust to universal gemm * v2 prepare * Remove unused comment * Remove vectorsize * Rollback * Adjust pipeline for abd * Shuffle argument * CI-fail fix quant * Fix ag_br pipeline * Failing tests * Typo * Single argument support
This commit is contained in:
@@ -26,6 +26,29 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
|
||||
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
/**
 * @brief Load tile with elementwise function
 *
 * @note This function is a modification of the existing load function.
 * It has been extended with two additional parameters: it takes a tuple as input
 * and an elementwise function. For each A = A0, A1… AN, the elementwise function
 * is additionally applied during a single read.
 *
 * @tparam TileWindow_           Tuple of tile windows (A0, A1, … AN) read in lock-step.
 * @tparam ElementWise_          Callable combining the per-window loaded values.
 * @tparam i_access              Access index forwarded to the underlying load (-1 = all).
 * @tparam oob_conditional_check Whether out-of-bounds accesses are guarded.
 *
 * @param tile_window  Tuple of tile windows; the first element drives the load and
 *                     receives the whole tuple so all windows are read together.
 * @param elementwise  Elementwise functor applied during the read.
 * @return The distributed tensor produced by the underlying elementwise load.
 */
template <typename TileWindow_,
          typename ElementWise_,
          index_t i_access = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
                                               ElementWise_ elementwise,
                                               number<i_access> = {},
                                               bool_constant<oob_conditional_check> = {})
{
    // TODO: Tile windows should work with an unknown number of params.
    // The elementwise load API works only when the input is a tuple type,
    // hence the dispatch through the first tuple element below.
    return tile_window[number<0>{}].load(
        tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
|
||||
@@ -120,6 +120,116 @@ struct tile_window_with_static_distribution
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Load tile with elementwise function
|
||||
*
|
||||
* @note Load tile with elementwise — during value loading, an
|
||||
* elementwise function is executed for each A0, A1, … AN.
|
||||
* The values A0, A1, … AN are read by the same thread. In this way, we
|
||||
* reduce the amount of information loaded into the registers.
|
||||
* The same thread, during vectorized reading, accesses the same set of
|
||||
* data from A0, A1, A2, … AN.
|
||||
*/
|
||||
template <typename TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
load(dst_tensor,
|
||||
tile_window,
|
||||
elementwise,
|
||||
number<i_access_unsupport_>{},
|
||||
bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
/**
 * @brief Load into an existing distributed tensor with an elementwise function.
 *
 * Walks the thread's Y-space via the space-filling curve, reads one vector
 * from every window in the tuple per access, applies @p elementwise to each
 * scalar lane, and writes the result into @p dst_tensor's thread buffer.
 *
 * @param dst_tensor  Destination static distributed tensor (written in place).
 * @param tile_window Tuple of tile windows A0, A1, … AN read in lock-step;
 *                    precomputed coordinates are taken from the first window,
 *                    so all windows are assumed to share the same access
 *                    pattern — TODO confirm with callers.
 * @param elementwise Functor invoked as elementwise(dst_scalar, a0, a1, …, aN).
 */
template <typename DistributedTensor,
          typename TileWindow_,
          typename ElementWise_,
          index_t i_access_unsupport_ = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
                         const TileWindow_& tile_window,
                         ElementWise_ elementwise,
                         number<i_access_unsupport_> = {},
                         bool_constant<oob_conditional_check> = {}) const
{

    using Traits = typename Base::Traits;
    using vector_t = typename Traits::vector_t;
    using SFC_Ys = typename Traits::SFC_Ys;

    constexpr auto tile_dstr = typename Base::TileDstr{};
    // Number of windows in the input tuple (one per A operand).
    constexpr auto sizeOfTuple = TileWindow_::size();
    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        /// TODO: use structure binding (to be captured later) if compiled in C++20
        // Coordinates are taken from the first window; the same bottom-tensor
        // coordinate is used to index every window in the tuple below.
        auto window_adaptor_thread_coord =
            tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord =
            tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

            // read from bottom tensor
            // One vectorized read per window, all at the same thread coordinate.
            const auto idx_vec_value = generate_tuple(
                [&](auto jj) {
                    return tile_window[number<jj>{}]
                        .get_bottom_tensor_view()
                        .template get_vectorized_elements<vector_t>(
                            bottom_tensor_thread_coord,
                            0,
                            bool_constant<oob_conditional_check>{});
                },
                number<sizeOfTuple>{});

            // write into distributed tensor
            // Step by PackedSize so packed sub-elements are handled once per pack.
            static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
                // Shift only the vectorized Y dimension by the lane index j.
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                        : idx_ys_start[jj];
                    },
                    number<Base::NDimY>{});

                // Linearized destination offset within the thread buffer.
                constexpr index_t d =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                    Traits::PackedSize;

                // Expand the per-window vectors and combine lane j of each
                // into the destination scalar via the elementwise functor.
                ck_tile::apply(
                    [&](auto&&... t) {
                        elementwise(dst_tensor.get_thread_buffer().template at<d>(),
                                    t.template get_as<
                                        typename Base::DataType>()[j / Traits::PackedSize]...);
                    },
                    idx_vec_value);
            });
            // move thread coordinate
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                // Prepend zero P-dimension steps; only Y dimensions advance.
                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
                    idx_diff_ys);

                Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
            }
        });
    });
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
@@ -857,6 +967,39 @@ CK_TILE_DEVICE void move_tile_window(
|
||||
window.move(step);
|
||||
}
|
||||
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t NumCoord>
|
||||
CK_TILE_DEVICE void move_tile_window(
|
||||
tuple<tile_window_with_static_distribution<TensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>>& window,
|
||||
const typename tile_window_with_static_distribution<TensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>::BottomTensorIndex& step)
|
||||
{
|
||||
using T = tuple<tile_window_with_static_distribution<TensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>>;
|
||||
|
||||
static constexpr auto N = T::size();
|
||||
static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
|
||||
}
|
||||
|
||||
template <typename TileWindowWithStaticDistributionType,
|
||||
typename StepType,
|
||||
typename std::enable_if_t<
|
||||
is_detected<is_tuple, TileWindowWithStaticDistributionType>::value>* = nullptr>
|
||||
CK_TILE_DEVICE void move_tile_window(TileWindowWithStaticDistributionType& window, StepType& step)
|
||||
{
|
||||
static constexpr auto N = TileWindowWithStaticDistributionType::size();
|
||||
static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This class provides description of tile windowed view on the device memory.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user