paged window run ok

This commit is contained in:
coderfeli
2025-04-05 13:45:46 +00:00
parent 45a0463f1f
commit fe2ea699e5
10 changed files with 805 additions and 116 deletions

View File

@@ -533,11 +533,6 @@ include_directories(BEFORE
${HIP_INCLUDE_DIRS}
)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")

View File

@@ -66,7 +66,6 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt

View File

@@ -254,22 +254,22 @@ int run_moe_gemm_example_with_layouts(int argc,
K, 1 /*kbatch*/, max_accumulated_value);
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
for(int im = 0; im < M; im++)
{
for(int in = 0; in < N; in++)
{
// if (static_cast<float>(static_cast<CDataType*>(p_c)[im * N + in]) != 0)
printf("c[%d][%d]: %f ",
im,
in,
static_cast<float>(static_cast<CDataType*>(p_c)[im * N + in]));
printf("ref[%d][%d]: %f \n",
im,
in,
static_cast<float>(
static_cast<CDataType*>(c_m_n_host_ref.data())[im * N + in]));
}
}
// for(int im = 0; im < M; im++)
// {
// for(int in = 0; in < N; in++)
// {
// // if (static_cast<float>(static_cast<CDataType*>(p_c)[im * N + in]) != 0)
// printf("c[%d][%d]: %f ",
// im,
// in,
// static_cast<float>(static_cast<CDataType*>(p_c)[im * N + in]));
// printf("ref[%d][%d]: %f \n",
// im,
// in,
// static_cast<float>(
// static_cast<CDataType*>(c_m_n_host_ref.data())[im * N + in]));
// }
// }
pass = ck_tile::check_err(c_m_n_tensor,
c_m_n_host_ref,

View File

@@ -104,30 +104,30 @@ CK_TILE_DEVICE void store_tile(
tile_window.store(dstr_tensor, number<-1>{});
}
template <typename T,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
const T& offsets)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
// template <typename T,
// typename BottomTensorView_,
// typename WindowLengths_,
// typename TileDistribution_,
// typename DataType_>
// CK_TILE_DEVICE void
// store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
// const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
// const T& offsets)
// {
// using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
// using TileDstr = remove_cvref_t<TileDistribution_>;
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
// static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
constexpr auto tile_dstr = TileDstr{};
// constexpr auto tile_dstr = TileDstr{};
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
tile_window_tmp.get_window_lengths(),
tile_window_tmp.get_window_origin(),
tile_dstr);
// auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
// tile_window_tmp.get_window_lengths(),
// tile_window_tmp.get_window_origin(),
// tile_dstr);
tile_window.store(dstr_tensor, offsets);
}
// tile_window.store(dstr_tensor, offsets);
// }
template <typename BottomTensorView_,
typename WindowLengths_,

View File

@@ -609,92 +609,92 @@ struct tile_window_with_static_distribution
});
}
template <typename statically_indexed_array,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
const statically_indexed_array offsets,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
// template <typename statically_indexed_array,
// index_t i_access_unsupport_ = -1,
// bool oob_conditional_check = true>
// CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
// const statically_indexed_array& offsets,
// number<i_access_unsupport_> = {},
// bool_constant<oob_conditional_check> = {}) const
// {
// using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// // using vector_type_t = typename Traits::vector_type_t;
// using vector_t = typename Traits::vector_t;
// using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
// auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
// // loop over thread tensor space [y0, y1, ...]
// static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
// // auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin_ +
tuple<index_t, index_t>(0, window_adaptor_thread_coord.get_bottom_index()[1]);
// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
// window_origin_ +
// tuple<index_t, index_t>(0, window_adaptor_thread_coord.get_bottom_index()[1]);
auto bottom_tensor_thread_coord = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// auto bottom_tensor_thread_coord = make_tensor_coordinate(
// bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
// constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_m = idx_ys_start[number<0>{}];
const auto offset = offsets[idx_m];
// // data index [y0, y1, ...]
// constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// constexpr auto idx_m = idx_ys_start[number<0>{}];
// const auto offset = offsets[idx_m];
// printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
// idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
// // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
// // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
// read from distributed tensor
// vector_type_t vec;
vector_t vec_value;
// // read from distributed tensor
// // vector_type_t vec;
// vector_t vec_value;
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
// static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
// constexpr auto idx_ys = generate_tuple(
// [&](auto jj) {
// return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
// : idx_ys_start[jj];
// },
// number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
// printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// constexpr index_t d =
// tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
// Traits::PackedSize;
// // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
// vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
// dstr_tensor.get_thread_buffer().template at<d>();
// });
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
offset,
vec_value,
bool_constant<oob_conditional_check>{});
// printf("coord_offset:%d, scatter_offset:%d \n",
// bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
// // write into bottom tensor
// get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
// bottom_tensor_thread_coord,
// offset,
// vec_value,
// bool_constant<oob_conditional_check>{});
// // printf("coord_offset:%d, scatter_offset:%d \n",
// // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
// if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
// {
// constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == 0 ? 0 : idx_diff_ys[i]; }, number<NDimY>{});
// constexpr auto forward_step_scatter = generate_tuple(
// [&](auto i) { return i == 0 ? 0 : idx_diff_ys[i]; }, number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
// constexpr auto idx_diff_ps_ys = container_concat(
// generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
// forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
// move_window_adaptor_and_bottom_tensor_thread_coordinate(
// window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
// }
// });
// });
// }
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,

View File

@@ -0,0 +1,686 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
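/**
* @brief This class provides a paged tile (windowed) view of, and access to, device memory.
*
* @note This tile window does not support single-issue access; use the tile_window_linear
* structure for that purpose.
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
* @tparam StaticPageIndexArray_ Array of page offsets; load()/store() index it with the first Y
* index of each access and pass the selected entry to the bottom tensor view.
* @tparam NumCoord Number of pre-computed coordinate bundles (currently must be 1).
*/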
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
index_t NumCoord>
struct page_tile_with_static_distribution
{
// Bottom tensor view type with any reference stripped from the template argument.
using BottomTensorView = remove_reference_t<BottomTensorView_>;
// Template arguments with cv/ref qualifiers stripped.
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
// Type of the page-offset array held in page_idx_.
using PageIdxArray = remove_cvref_t<StaticPageIndexArray_>;
// Adaptor type exported by the tile distribution (maps [p..., y...] to [x...]).
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
// Element type of the bottom tensor view.
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
// Dimension counts queried from the adaptor, the bottom descriptor and the distribution.
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
// Compile-time indices used to select the two members of each pre-computed coordinate bundle.
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
// Only a single pre-computed coordinate bundle is supported by this window type.
static_assert(NumCoord == 1);
// TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileDstr::is_static(), "wrong!");
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
"wrong! inconsistent # of diemsnions");
// Index types for the adaptor's top space and for the bottom tensor.
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
// Coordinate types, deduced from the factory functions that create them.
using WindowAdaptorCoord =
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
// Compile-time traits shared by load() and store(): which Y dimension is vectorized, how many
// scalars each access moves, the vector type used for an access, and the space-filling curve
// that enumerates the accesses over the thread's Y space.
struct load_store_traits
{
private:
// Picks the Y dimension to vectorize along: among the Y dimensions whose safe vector
// stride is 1, the one with the largest safe vector length (the first such dimension on ties).
// Returns (VectorDimY, ScalarPerVector); falls back to (0, 1) when no dimension qualifies.
static constexpr auto get_vector_dim_y_scalar_per_vector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
page_tile_with_static_distribution::
get_window_adaptor_ys_safe_vector_length_strides();
index_t VectorDimY_ = 0;
index_t ScalarPerVector_ = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
{
ScalarPerVector_ = ys_vector_lengths[i];
VectorDimY_ = i;
}
}
return make_tuple(VectorDimY_, ScalarPerVector_);
}
public:
// PackedSize as reported by numeric_traits for the element type.
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
// using vector_t = typename vector_type_t::type;
// Per-access vector: ScalarPerVector / PackedSize elements of DataType.
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
private:
// Scalars moved per access along each Y dimension: ScalarPerVector along VectorDimY,
// 1 along every other dimension.
static constexpr auto scalars_per_access_ = [] {
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr auto NDimY_ = NDimY;
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
}();
// Builds the space-filling curve over the thread tensor lengths, visiting the Y
// dimensions in natural order (0, 1, ..., NDimY-1).
static constexpr auto get_space_filling_curve()
{
constexpr auto tile_dstr = TileDstr{};
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
// FIXME: need logic to judge dim access order
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
return space_filling_curve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access_)>{};
}
public:
using SFC_Ys = decltype(get_space_filling_curve());
// Total number of vector accesses each thread performs per load/store.
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
};
// Number of accesses served by each pre-computed coordinate bundle.
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
CK_TILE_DEVICE constexpr page_tile_with_static_distribution() = default;
// Builds a paged window over bottom_tensor_view at window_origin.
// Stores copies of all arguments, then pre-computes the per-thread
// (window adaptor coordinate, bottom tensor coordinate) bundles used by load()/store().
CK_TILE_DEVICE constexpr page_tile_with_static_distribution(
const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution,
const PageIdxArray& page_idx)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
page_idx_{page_idx},
pre_computed_coords_{}
{
#if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this use less register for FA, but more register for GEMM
// need investigation
// Active branch: adaptor coordinate at this thread's partition index with all Y indices 0.
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_distribution),
array<index_t, NDimY>{0}));
#endif
// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
// window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
// Only component [1] of the thread's bottom index is added to the window origin;
// component [0] is forced to 0. Presumably the dim-0 position is supplied through
// page_idx_ at load/store time -- TODO confirm.
// NOTE(review): the two-element tuple looks like it assumes NDimBottomTensor == 2 -- confirm.
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin + tuple<index_t, index_t>(0, window_adaptor_thread_coord_tmp.get_bottom_index()[1]);
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
// Step from access 0 to the first access owned by bundle iCoord.
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
// Prepend zeros for the P dimensions: only the Y part of the coordinate moves.
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
// Number of dimensions of the bottom tensor.
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
// Forwards TileDstr::is_static().
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
{
return TileDstr::is_static();
}
// Accessors; each returns its member by value (return type is plain `auto`).
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
// Overwrites the data pointer held by the bottom tensor view's buffer; the window origin
// and the pre-computed coordinates are left untouched.
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// Advance one thread's coordinates by a step given in the adaptor's top space.
// The adaptor coordinate is moved first; the bottom-space displacement it reports is then
// applied to the bottom tensor coordinate:
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
    WindowAdaptorCoord& window_adaptor_thread_coord,
    BottomTensorCoord& bottom_tensor_thread_coord,
    const ATopIndex& idx_diff_adaptor_top) const
{
    // displacement in bottom (x) space, filled in by the adaptor move below
    BottomTensorIndex bottom_step;

    move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
                                   window_adaptor_thread_coord,
                                   idx_diff_adaptor_top,
                                   bottom_step);

    move_tensor_coordinate(
        bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord, bottom_step);
}
// Returns a tuple (lengths, strides) holding, for each Y dimension [y0, y1, ...], the safe
// vector length and stride obtained by propagating the bottom tensor's values through the
// window adaptor.
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
    // Safe vector lengths/strides of the bottom tensor's top dimensions. These are used
    // directly as the lengths/strides of the window adaptor's bottom dimensions.
    const auto [bottom_lengths, bottom_strides] =
        BottomTensorDesc::get_top_dimension_safe_vector_length_strides();

    // One slot per hidden dimension of the adaptor, initialized from -1.
    using HiddenIndex = array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>;
    HiddenIndex hidden_lengths{-1};
    HiddenIndex hidden_strides{-1};

    // Seed the slots that correspond to the adaptor's bottom dimensions.
    constexpr auto bottom_hidden_ids = WindowAdaptor::get_bottom_dimension_hidden_ids();
    set_container_subset(hidden_lengths, bottom_hidden_ids, bottom_lengths);
    set_container_subset(hidden_strides, bottom_hidden_ids, bottom_strides);

    // Propagate to the adaptor's top dimensions [p0, p1, ..., y0, y1, ...].
    const auto [ps_ys_lengths, ps_ys_strides] =
        WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(hidden_lengths,
                                                                     hidden_strides);

    // Keep only the Y part: top dimensions [NDimP, NDimWindowAdaptorTop).
    using YDims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
                                                   NDimWindowAdaptorTop,
                                                   1>::type;
    constexpr auto y_dims = YDims{};

    return make_tuple(get_container_subset(ps_ys_lengths, y_dims),
                      get_container_subset(ps_ys_strides, y_dims));
}
// Total number of vector accesses per thread for one load/store (load_store_traits::NumAccess).
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
// Convenience overload: creates a distributed tensor for this window's tile distribution,
// fills it through the tensor-taking load() overload and returns it by value.
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
                         bool_constant<oob_conditional_check> = {}) const
{
    auto loaded_tile = make_static_distributed_tensor<DataType>(TileDstr{});

    load(loaded_tile, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});

    return loaded_tile;
}
// Reads this thread's part of the window from the bottom tensor into dst_tensor.
// For every access on the space-filling curve, a vector is read at the thread's bottom
// tensor coordinate plus a page offset taken from page_idx_, then scattered element by
// element into dst_tensor's thread buffer.
// The return type is `auto`, but the body has no return statement, so it deduces to void.
template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
// Work on local copies so the pre-computed coordinates stay unchanged.
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// The first Y index of this access selects the page offset.
constexpr auto idx_m = idx_ys_start[number<0>{}];
const auto page_offset = page_idx_[idx_m];
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, page_offset, bool_constant<oob_conditional_check>{});
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
// Y index of the j-th scalar of this access: offset j along VectorDimY only.
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
// Position in the thread buffer, in units of PackedSize scalars.
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
});
#else
// Disabled alternative: store the whole vector into the thread buffer at once.
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
// Component 0 of the forward step is zeroed, so the coordinate never moves along
// Y dimension 0. Presumably that movement is covered by page_offset -- TODO confirm.
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == 0 ? 0 : idx_diff_ys[i]; }, number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
// Writes this thread's part of dstr_tensor to the bottom tensor through the paged window.
// For every access on the space-filling curve, a vector is gathered element by element from
// the distributed tensor's thread buffer and written at the thread's bottom tensor
// coordinate plus a page offset taken from page_idx_ (selected by the first Y index of the
// access).
//
// dstr_tensor:            distributed tensor to write; must use this window's DataType and
//                         tile distribution.
// number<...>:            unused tag argument.
// oob_conditional_check:  forwarded to set_vectorized_elements().
//
// Note: this function must not print. A device-side printf here would run in every thread
// on every store.
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                          number<i_access_unsupport_> = {},
                          bool_constant<oob_conditional_check> = {}) const
{
    using Traits   = load_store_traits;
    using vector_t = typename Traits::vector_t;
    using SFC_Ys   = typename Traits::SFC_Ys;

    constexpr auto tile_dstr = TileDstr{};

    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        // Work on local copies so the pre-computed coordinates stay unchanged.
        auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

            // The first Y index of this access selects the page offset.
            constexpr auto idx_m   = idx_ys_start[number<0>{}];
            const auto page_offset = page_idx_[idx_m];

            // read from distributed tensor
            vector_t vec_value;

            static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
                // Y index of the j-th scalar of this access: offset j along VectorDimY only.
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                        : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                // Position in the thread buffer, in units of PackedSize scalars.
                constexpr index_t d =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                    Traits::PackedSize;

                vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
                bottom_tensor_thread_coord,
                page_offset,
                vec_value,
                bool_constant<oob_conditional_check>{});

            // move thread coordinate
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                // Component 0 of the forward step is zeroed, so the coordinate never moves
                // along Y dimension 0. Presumably that movement is covered by page_offset
                // -- TODO confirm.
                constexpr auto forward_step_scatter = generate_tuple(
                    [&](auto i) { return i == 0 ? 0 : idx_diff_ys[i]; }, number<NDimY>{});

                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
                    forward_step_scatter);

                move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
            }
        });
    });
}
// Move the thread's bottom tensor coordinate:
// [x0', x1', ... ] ==> [offset]
// The window origin is moved by the same step. The window adaptor coordinates and
// page_idx_ are not modified here.
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
{
window_origin_ += step;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// (I1) selects the bottom tensor coordinate of bundle iCoord.
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
pre_computed_coords_(iCoord)(I1),
step);
});
}
// Replaces the stored page-offset array with new_idx.
// The window origin and the pre-computed coordinates are left unchanged; subsequent
// load()/store() calls read their page offsets from the new array.
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx)
{
page_idx_ = new_idx;
}
// CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
// {
// window_origin_ = new_window_origin;
// #if 0 // debug
// // TODO: this use more register for FA, but less register for GEMM
// // need investigation
// // only support warp-tile and block-tile
// static_assert(NDimP == 1 or NDimP == 2, "wrong!");
// WindowAdaptorCoord window_adaptor_thread_coord_tmp;
// if constexpr(NDimP == 1)
// {
// window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
// tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
// }
// else if constexpr(NDimP == 2)
// {
// window_adaptor_thread_coord_tmp =
// make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
// AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
// }
// #else
// // TODO: this use less register for FA, but more register for GEMM
// // need investigation
// const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
// tile_dstr_.get_ps_ys_to_xs_adaptor(),
// container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
// #endif
// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
// window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
// const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
// bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// // future load/store() calls (might allocate more registers)
// using Traits = load_store_traits;
// using SFC_Ys = typename Traits::SFC_Ys;
// static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
// auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
// constexpr auto idx_diff_ys =
// SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
// constexpr auto idx_diff_ps_ys = container_concat(
// generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
// move_window_adaptor_and_bottom_tensor_thread_coordinate(
// window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
// pre_computed_coords_(iCoord) =
// make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
// });
// }
// Forwards to init_raw() of the bottom tensor view.
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
// lengths of the window (required to be known at compile time, see the static_assert above)
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
// page offsets; load()/store() index this with the first Y index of each access and pass
// the selected entry to the bottom tensor view alongside the thread coordinate
PageIdxArray page_idx_;
// this contains:
// per-thread coordinate for window adaptor
// per-thread coordinate for bottom tensor
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
};
// TODO: use strategy
// Factory for page_tile_with_static_distribution.
// Strips cv/ref qualifiers from the deduced argument types and brace-constructs the window
// from (tensor_view, window_lengths, origin, tile_distribution, page_idx). The trailing
// number<NumCoord> argument only carries NumCoord as a compile-time value.
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename StaticPageIndexArray_,
          index_t NumCoord = 1>
CK_TILE_DEVICE constexpr auto
make_tile_window_paged(const TensorView_& tensor_view,
                       const WindowLengths_& window_lengths,
                       const multi_index<TensorView_::get_num_of_dimension()>& origin,
                       const StaticTileDistribution_& tile_distribution,
                       const StaticPageIndexArray_& page_idx,
                       number<NumCoord> = {})
{
    using PagedWindow = page_tile_with_static_distribution<remove_cvref_t<TensorView_>,
                                                           remove_cvref_t<WindowLengths_>,
                                                           remove_cvref_t<StaticTileDistribution_>,
                                                           remove_cvref_t<StaticPageIndexArray_>,
                                                           NumCoord>;

    return PagedWindow{tensor_view, window_lengths, origin, tile_distribution, page_idx};
}
// this version can't be called in a constexpr context
// template <typename TensorView_,
// typename WindowLengths_,
// typename StaticTileDistribution_,
// index_t NumCoord = 1>
// CK_TILE_DEVICE auto
// make_tile_window_raw(const TensorView_& tensor_view,
// const WindowLengths_& window_lengths,
// const multi_index<TensorView_::get_num_of_dimension()>& origin,
// const StaticTileDistribution_& tile_distribution,
// number<NumCoord> = {})
// {
// auto w = page_tile_with_static_distribution<remove_cvref_t<TensorView_>,
// remove_cvref_t<WindowLengths_>,
// remove_cvref_t<StaticTileDistribution_>,
// NumCoord>{
// tensor_view, window_lengths, origin, tile_distribution};
// w.init_raw();
// return w;
// }
// template <typename TensorView_,
// typename WindowLengths_,
// typename StaticTileDistribution_,
// index_t NumCoord>
// CK_TILE_DEVICE void move_tile_window(
// page_tile_with_static_distribution<TensorView_,
// WindowLengths_,
// StaticTileDistribution_,
// NumCoord>& window,
// const typename page_tile_with_static_distribution<TensorView_,
// WindowLengths_,
// StaticTileDistribution_,
// NumCoord>::BottomTensorIndex& step)
// {
// window.move(step);
// }
// Builds a paged window from an existing static-lengths window, reusing its bottom tensor
// view and window lengths but placing the new window at the caller-supplied origin.
template <typename TensorView,
          typename WindowLengths,
          typename StaticTileDistribution,
          typename StaticPageIndexArray>
CK_TILE_DEVICE constexpr auto make_tile_window_paged(
    const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
    const multi_index<TensorView::get_num_of_dimension()>& origin,
    const StaticTileDistribution& tile_distribution,
    const StaticPageIndexArray& page_idx)
{
    const auto& source_view    = tile_window.get_bottom_tensor_view();
    const auto& source_lengths = tile_window.get_window_lengths();

    return make_tile_window_paged(
        source_view, source_lengths, origin, tile_distribution, page_idx);
}
// Builds a paged window from an existing static-lengths window, reusing its bottom tensor
// view, window lengths and window origin.
template <typename TensorView,
          typename WindowLengths,
          typename StaticTileDistribution,
          typename StaticPageIndexArray>
CK_TILE_DEVICE constexpr auto make_tile_window_paged(
    const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
    const StaticTileDistribution& tile_distribution,
    const StaticPageIndexArray& page_idx)
{
    const auto& source_view    = tile_window.get_bottom_tensor_view();
    const auto& source_lengths = tile_window.get_window_lengths();
    const auto& source_origin  = tile_window.get_window_origin();

    return make_tile_window_paged(
        source_view, source_lengths, source_origin, tile_distribution, page_idx);
}
// template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
// CK_TILE_DEVICE constexpr auto
// make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
// const StaticTileDistribution& tile_distribution)
// {
// auto w = make_tile_window_paged(tile_window.get_bottom_tensor_view(),
// tile_window.get_window_lengths(),
// tile_window.get_window_origin(),
// tile_distribution);
// w.init_raw();
// return w;
// }
} // namespace ck_tile

View File

@@ -97,7 +97,7 @@ __global__ void naive_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N; // Compute row index
int col = idx % N; // Compute column index
(void)Num_tokens;
// assert(p_sorted_expert_ids_ != nullptr);
// assert(TopK == 1);
// assert(Num_tokens == 128);

View File

@@ -6,7 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/core/tensor/tile_window_paged.hpp"
namespace ck_tile {
template <typename ADataType_,
@@ -212,7 +212,16 @@ struct CShuffleEpilogue
if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor, offsets);
auto tile_window = make_tile_window_paged(out_dram_window.get_bottom_tensor_view(),
out_dram_window.get_window_lengths(),
out_dram_window.get_window_origin(),
dram_tile_distribution,
offsets);
tile_window.store(c_out_tensor);
// store_tile(out_dram_window, c_out_tensor, offsets);
}
else
{

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/core/tensor/tile_window_paged.hpp"
namespace ck_tile {

View File

@@ -19,7 +19,6 @@ cmake
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \