diff --git a/CMakeLists.txt b/CMakeLists.txt index bb0c254e06..b7337a7f83 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -533,11 +533,6 @@ include_directories(BEFORE ${HIP_INCLUDE_DIRS} ) -SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") -if(BUILD_DEV) - add_compile_options(-Werror) - add_compile_options(-Weverything) -endif() message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index fb2b38d688..d5bcd6f978 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,6 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt diff --git a/example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc b/example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc index db049c9cc8..5b1e838fd1 100644 --- a/example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc +++ b/example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc @@ -254,22 +254,22 @@ int run_moe_gemm_example_with_layouts(int argc, K, 1 /*kbatch*/, max_accumulated_value); c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); - for(int im = 0; im < M; im++) - { - for(int in = 0; in < N; in++) - { - // if (static_cast(static_cast(p_c)[im * N + in]) != 0) - printf("c[%d][%d]: %f ", - im, - in, - static_cast(static_cast(p_c)[im * N + in])); - printf("ref[%d][%d]: %f \n", - im, - in, - static_cast( - static_cast(c_m_n_host_ref.data())[im * N + in])); - } - } + // for(int im = 0; im < M; im++) + // { + // for(int in = 0; in < N; in++) + // { + // // if (static_cast(static_cast(p_c)[im * N + in]) != 0) + // printf("c[%d][%d]: %f ", + // im, + // in, + // static_cast(static_cast(p_c)[im * N + in])); + // printf("ref[%d][%d]: %f \n", + // im, + // in, + // static_cast( + // static_cast(c_m_n_host_ref.data())[im * N + in])); + // } + // } pass = ck_tile::check_err(c_m_n_tensor, c_m_n_host_ref, diff --git 
a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index d363622c3d..fa7eb7b089 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -104,30 +104,30 @@ CK_TILE_DEVICE void store_tile( tile_window.store(dstr_tensor, number<-1>{}); } -template -CK_TILE_DEVICE void -store_tile(tile_window_with_static_lengths& tile_window_tmp, - const static_distributed_tensor& dstr_tensor, - const T& offsets) -{ - using DataType = remove_cvref_t; - using TileDstr = remove_cvref_t; +// template +// CK_TILE_DEVICE void +// store_tile(tile_window_with_static_lengths& tile_window_tmp, +// const static_distributed_tensor& dstr_tensor, +// const T& offsets) +// { +// using DataType = remove_cvref_t; +// using TileDstr = remove_cvref_t; - static_assert(std::is_same_v, DataType>, "wrong!"); +// static_assert(std::is_same_v, DataType>, "wrong!"); - constexpr auto tile_dstr = TileDstr{}; +// constexpr auto tile_dstr = TileDstr{}; - auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), - tile_window_tmp.get_window_lengths(), - tile_window_tmp.get_window_origin(), - tile_dstr); +// auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), +// tile_window_tmp.get_window_lengths(), +// tile_window_tmp.get_window_origin(), +// tile_dstr); - tile_window.store(dstr_tensor, offsets); -} +// tile_window.store(dstr_tensor, offsets); +// } template - CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, - const statically_indexed_array offsets, - number = {}, - bool_constant = {}) const - { - using Traits = load_store_traits; + // template + // CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + // const statically_indexed_array& offsets, + // number = {}, + // bool_constant = {}) const + // { + // using Traits = load_store_traits; - // using vector_type_t = typename Traits::vector_type_t; - using vector_t = typename 
Traits::vector_t; - using SFC_Ys = typename Traits::SFC_Ys; + // // using vector_type_t = typename Traits::vector_type_t; + // using vector_t = typename Traits::vector_t; + // using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + // constexpr auto tile_dstr = TileDstr{}; - // loop over thread tensor space [y0, y1, ...] - static_for<0, NumCoord, 1>{}([&](auto iCoord) { - auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; - // auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + // // loop over thread tensor space [y0, y1, ...] + // static_for<0, NumCoord, 1>{}([&](auto iCoord) { + // auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + // // auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; - BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = - window_origin_ + - tuple(0, window_adaptor_thread_coord.get_bottom_index()[1]); + // BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + // window_origin_ + + // tuple(0, window_adaptor_thread_coord.get_bottom_index()[1]); - auto bottom_tensor_thread_coord = make_tensor_coordinate( - bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + // auto bottom_tensor_thread_coord = make_tensor_coordinate( + // bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); - static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = number{}; + // static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + // constexpr auto iAccess = number{}; - // data index [y0, y1, ...] - constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_m = idx_ys_start[number<0>{}]; - const auto offset = offsets[idx_m]; + // // data index [y0, y1, ...] 
+ // constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + // constexpr auto idx_m = idx_ys_start[number<0>{}]; + // const auto offset = offsets[idx_m]; - // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n", - // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0); + // // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n", + // // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0); - // read from distributed tensor - // vector_type_t vec; - vector_t vec_value; + // // read from distributed tensor + // // vector_type_t vec; + // vector_t vec_value; - static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { - constexpr auto idx_ys = generate_tuple( - [&](auto jj) { - return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) - : idx_ys_start[jj]; - }, - number{}); + // static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + // constexpr auto idx_ys = generate_tuple( + // [&](auto jj) { + // return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + // : idx_ys_start[jj]; + // }, + // number{}); - constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - Traits::PackedSize; - // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j); - vec_value.template get_as()(j / Traits::PackedSize) = - dstr_tensor.get_thread_buffer().template at(); - }); + // constexpr index_t d = + // tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + // Traits::PackedSize; + // // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j); + // vec_value.template get_as()(j / Traits::PackedSize) = + // dstr_tensor.get_thread_buffer().template at(); + // }); - // const vector_t vec_value = vec.template get_as().template at<0>(); + // // const vector_t vec_value = vec.template get_as().template at<0>(); - // write into bottom tensor - get_bottom_tensor_view().template set_vectorized_elements( - bottom_tensor_thread_coord, - offset, - vec_value, - bool_constant{}); - // 
printf("coord_offset:%d, scatter_offset:%d \n", - // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate - if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) - { - constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + // // write into bottom tensor + // get_bottom_tensor_view().template set_vectorized_elements( + // bottom_tensor_thread_coord, + // offset, + // vec_value, + // bool_constant{}); + // // printf("coord_offset:%d, scatter_offset:%d \n", + // // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate + // if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + // { + // constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); - constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return i == 0 ? 0 : idx_diff_ys[i]; }, number{}); + // constexpr auto forward_step_scatter = generate_tuple( + // [&](auto i) { return i == 0 ? 0 : idx_diff_ys[i]; }, number{}); - constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), - forward_step_scatter); + // constexpr auto idx_diff_ps_ys = container_concat( + // generate_tuple([&](auto) { return number<0>{}; }, number{}), + // forward_step_scatter); - move_window_adaptor_and_bottom_tensor_thread_coordinate( - window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); - } - }); - }); - } + // move_window_adaptor_and_bottom_tensor_thread_coordinate( + // window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + // } + // }); + // }); + // } template CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, diff --git a/include/ck_tile/core/tensor/tile_window_paged.hpp b/include/ck_tile/core/tensor/tile_window_paged.hpp new file mode 100644 index 0000000000..5cf77a1ecd --- /dev/null +++ b/include/ck_tile/core/tensor/tile_window_paged.hpp @@ -0,0 +1,686 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved. + +#pragma once + +#include "ck_tile/core/arch/utility.hpp" +#include "ck_tile/core/algorithm/space_filling_curve.hpp" +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/tensor/static_distributed_tensor.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/** + * @brief This class provides tile (windowed) view and access to the device memory. + * + * @note This tile window does not support single issue; use the tile_window_linear + * structure for that purpose. + * + * @tparam BottomTensorView_ Class describing & holding device tensor memory. + * @tparam WindowLengths_ Spatial sizes of windowed view on tensor. 
+ * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions + * @tparam NumCoord Number of pre-computed coordinate bundles; must be 1 (see static_assert below) + */ +template +struct page_tile_with_static_distribution +{ + using BottomTensorView = remove_reference_t; + using WindowLengths = remove_cvref_t; + using TileDstr = remove_cvref_t; + using PageIdxArray = remove_cvref_t; + using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; + + using DataType = remove_cvref_t; + + static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); + static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); + + static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); + static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static_assert(NumCoord == 1); + + // TODO: check WindowLengths and StaticTileDistribution are consistent + + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + static_assert(TileDstr::is_static(), "wrong!"); + + static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), + "wrong! 
inconsistent # of diemsnions"); + + using AdaptorTopIndex = array; + using BottomTensorIndex = array; + + using WindowAdaptorCoord = + decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); + + using BottomTensorCoord = + decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); + + struct load_store_traits + { + private: + static constexpr auto get_vector_dim_y_scalar_per_vector() + { + const auto [ys_vector_lengths, ys_vector_strides] = + page_tile_with_static_distribution:: + get_window_adaptor_ys_safe_vector_length_strides(); + + index_t VectorDimY_ = 0; + index_t ScalarPerVector_ = 1; + + for(index_t i = 0; i < NDimY; ++i) + { + if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) + { + ScalarPerVector_ = ys_vector_lengths[i]; + VectorDimY_ = i; + } + } + + return make_tuple(VectorDimY_, ScalarPerVector_); + } + + public: + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); + static constexpr index_t ScalarPerVector = + get_vector_dim_y_scalar_per_vector().template at<1>(); + + // using vector_type_t = vector_type_maker_t; + // using vector_t = typename vector_type_t::type; + using vector_t = thread_buffer; + + private: + static constexpr auto scalars_per_access_ = [] { + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == VectorDimY) ? 
ScalarPerVector : 1; }, number{}); + + /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() + constexpr auto NDimY_ = NDimY; + + return TO_SEQUENCE(scalars_per_access_arr, NDimY_); + }(); + + static constexpr auto get_space_filling_curve() + { + constexpr auto tile_dstr = TileDstr{}; + + constexpr auto thread_tensor_lengths_ys = + to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths()); + + // FIXME: need logic to judge dim access order + using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; + + return space_filling_curve{}; + } + + public: + using SFC_Ys = decltype(get_space_filling_curve()); + + static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); + + static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); + static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord"); + }; + + static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord; + + CK_TILE_DEVICE constexpr page_tile_with_static_distribution() = default; + + CK_TILE_DEVICE constexpr page_tile_with_static_distribution( + const BottomTensorView& bottom_tensor_view, + const WindowLengths& window_lengths, + const BottomTensorIndex& window_origin, + const TileDstr& tile_distribution, + const PageIdxArray& page_idx) + : bottom_tensor_view_{bottom_tensor_view}, + window_lengths_{window_lengths}, + window_origin_{window_origin}, + tile_dstr_{tile_distribution}, + page_idx_{page_idx}, + pre_computed_coords_{} + { +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if 
constexpr(NDimP == 2) + { + window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(tile_distribution), + array{0})); +#endif + + // BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + // window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin + tuple(0, window_adaptor_thread_coord_tmp.get_bottom_index()[1]); + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; + using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } + + CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } + + CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() 
+ { + return TileDstr::is_static(); + } + + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } + + CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } + + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + + CK_TILE_DEVICE constexpr void + set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) + { + bottom_tensor_view_.buf_.p_data_ = data; + } + + // move thread's window adaptor coordinate and bottom tensor coordinate + // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] + template + CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( + WindowAdaptorCoord& window_adaptor_thread_coord, + BottomTensorCoord& bottom_tensor_thread_coord, + const ATopIndex& idx_diff_adaptor_top) const + { + array idx_diff_adaptor_bottom; + + move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + window_adaptor_thread_coord, + idx_diff_adaptor_top, + idx_diff_adaptor_bottom); + + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + bottom_tensor_thread_coord, + idx_diff_adaptor_bottom); + } + + // return vector dimension among [y0, y1, ...] + CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() + { + // bottom tensor top dimension vector lengths and strides + const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] = + BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); + + // window vector lengths/strides + const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths; + const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides; + + // window adaptor [p0, p1, ..., y0, y1, ...] 
+ array window_adaptor_vector_lengths{ + -1}; + array window_adaptor_vector_strides{ + -1}; + + constexpr auto window_adaptor_bottom_dims = + WindowAdaptor::get_bottom_dimension_hidden_ids(); + + set_container_subset(window_adaptor_vector_lengths, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_lengths); + set_container_subset(window_adaptor_vector_strides, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_strides); + + const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = + WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( + window_adaptor_vector_lengths, window_adaptor_vector_strides); + + // [y0, y1, ...] + constexpr auto y_dims = typename arithmetic_sequence_gen::type{}; + + return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), + get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); + } + + CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; } + + template + CK_TILE_DEVICE auto load(number = {}, + bool_constant = {}) const + { + constexpr auto tile_dstr = TileDstr{}; + auto dst_tensor = make_static_distributed_tensor(tile_dstr); + load(dst_tensor, number{}, bool_constant{}); + return dst_tensor; + } + + template + CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, + number = {}, + bool_constant = {}) const + { + using Traits = load_store_traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] 
+ static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_m = idx_ys_start[number<0>{}]; + const auto page_offset = page_idx_[idx_m]; + + // read from bottom tensor + const vector_t vec_value = + get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, page_offset, bool_constant{}); +#if 1 + // write into distributed tensor + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + + dst_tensor.get_thread_buffer().template at() = + vec_value.template get_as()[j / Traits::PackedSize]; + }); +#else + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + static_assert(d % Traits::ScalarPerVector == 0); + + dst_tensor.get_thread_buffer().template get_as()( + number{}) = bit_cast(vec_value); +#endif + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return i == 0 ? 
0 : idx_diff_ys[i]; }, number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + // Store dstr_tensor via set_vectorized_elements(), passing page_idx_[idx_m] with each access. + template + CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + number = {}, + bool_constant = {}) const + { + using Traits = load_store_traits; + + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + // BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + // window_origin_ + + // tuple(0, window_adaptor_thread_coord.get_bottom_index()[1]); + + // auto bottom_tensor_thread_coord = make_tensor_coordinate( + // bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_m = idx_ys_start[number<0>{}]; + const auto page_offset = page_idx_[idx_m]; + + // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n", + // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0); + + // read from distributed tensor + // vector_type_t vec; + vector_t vec_value; + + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? 
(idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j); + vec_value.template get_as()(j / Traits::PackedSize) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // const vector_t vec_value = vec.template get_as().template at<0>(); + + // write into bottom tensor + get_bottom_tensor_view().template set_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + vec_value, + bool_constant{}); + // printf("coord_offset:%d, scatter_offset:%d \n", + // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return i == 0 ? 0 : idx_diff_ys[i]; }, number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + // move thread's bottom tensor coordinate + // [x0', x1', ... 
] ==> [offset] + // also move window-origin + CK_TILE_DEVICE void move(const BottomTensorIndex& step) + { + window_origin_ += step; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + pre_computed_coords_(iCoord)(I1), + step); + }); + } + + CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) + { + // window_origin_ += step; + + // static_for<0, NumCoord, 1>{}([&](auto iCoord) { + // move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + // pre_computed_coords_(iCoord)(I1), + // step); + // }); + page_idx_ = new_idx; + } +// CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) +// { +// window_origin_ = new_window_origin; + +// #if 0 // debug +// // TODO: this use more register for FA, but less register for GEMM +// // need investigation +// // only support warp-tile and block-tile +// static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + +// WindowAdaptorCoord window_adaptor_thread_coord_tmp; + +// if constexpr(NDimP == 1) +// { +// window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( +// tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); +// } +// else if constexpr(NDimP == 2) +// { +// window_adaptor_thread_coord_tmp = +// make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), +// AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); +// } +// #else +// // TODO: this use less register for FA, but more register for GEMM +// // need investigation +// const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( +// tile_dstr_.get_ps_ys_to_xs_adaptor(), +// container_concat(detail::get_partition_index(tile_dstr_), array{0})); +// #endif + +// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = +// window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + +// const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( +// 
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + +// // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up +// // future load/store() calls (might allocate more registers) +// using Traits = load_store_traits; +// using SFC_Ys = typename Traits::SFC_Ys; + +// static_for<0, NumCoord, 1>{}([&](auto iCoord) { +// auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; +// auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + +// constexpr auto idx_diff_ys = +// SFC_Ys::get_step_between(number<0>{}, number{}); + +// constexpr auto idx_diff_ps_ys = container_concat( +// generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); + +// move_window_adaptor_and_bottom_tensor_thread_coordinate( +// window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + +// pre_computed_coords_(iCoord) = +// make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); +// }); +// } + + CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } + + // this is the bottom tensor view + // [x0', x1', ...] ==> [offset] + BottomTensorView bottom_tensor_view_; + + // + WindowLengths window_lengths_; + + // origin ([x0', x1', ...]) of window on bottom tensor + BottomTensorIndex window_origin_; + + // Tile tensor distribution, which contains: + // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] + // 2. thread descriptor for thread tensor in register: [y0, y1, ...] 
==> [d] + TileDstr tile_dstr_; + + PageIdxArray page_idx_; + + // this contains: + // per-thread coordinate for window adaptor + // per-thread coordinate for bottom tensor + array, NumCoord> pre_computed_coords_; +}; + +// TODO: use strategy +template +CK_TILE_DEVICE constexpr auto +make_tile_window_paged(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + const StaticPageIndexArray_& page_idx, + number = {}) +{ + return page_tile_with_static_distribution, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NumCoord>{ + tensor_view, window_lengths, origin, tile_distribution, page_idx}; +} + +// this version can't be called in a constexpr context +// template +// CK_TILE_DEVICE auto +// make_tile_window_raw(const TensorView_& tensor_view, +// const WindowLengths_& window_lengths, +// const multi_index& origin, +// const StaticTileDistribution_& tile_distribution, +// number = {}) +// { +// auto w = page_tile_with_static_distribution, +// remove_cvref_t, +// remove_cvref_t, +// NumCoord>{ +// tensor_view, window_lengths, origin, tile_distribution}; +// w.init_raw(); +// return w; +// } + +// template +// CK_TILE_DEVICE void move_tile_window( +// page_tile_with_static_distribution& window, +// const typename page_tile_with_static_distribution::BottomTensorIndex& step) +// { +// window.move(step); +// } + + +template +CK_TILE_DEVICE constexpr auto +make_tile_window_paged(const tile_window_with_static_lengths& tile_window, + const multi_index& origin, + const StaticTileDistribution& tile_distribution, + const StaticPageIndexArray& page_idx) +{ + return make_tile_window_paged(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + origin, + tile_distribution, + page_idx); +} + +template +CK_TILE_DEVICE constexpr auto +make_tile_window_paged(const tile_window_with_static_lengths& tile_window, + const StaticTileDistribution& tile_distribution, 
const StaticPageIndexArray& page_idx) +{ + return make_tile_window_paged(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution, + page_idx); +} + +// template +// CK_TILE_DEVICE constexpr auto +// make_tile_window_raw(const tile_window_with_static_lengths& tile_window, +// const StaticTileDistribution& tile_distribution) +// { +// auto w = make_tile_window_paged(tile_window.get_bottom_tensor_view(), +// tile_window.get_window_lengths(), +// tile_window.get_window_origin(), +// tile_distribution); +// w.init_raw(); +// return w; +// } + + +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_fused_single_moe_gemm.hpp b/include/ck_tile/host/reference/reference_fused_single_moe_gemm.hpp index 59ff88c4f2..d98a1e899a 100644 --- a/include/ck_tile/host/reference/reference_fused_single_moe_gemm.hpp +++ b/include/ck_tile/host/reference/reference_fused_single_moe_gemm.hpp @@ -97,7 +97,7 @@ __global__ void naive_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_, int idx = blockIdx.x * blockDim.x + threadIdx.x; int row = idx / N; // Compute row index int col = idx % N; // Compute column index - + (void)Num_tokens; // assert(p_sorted_expert_ids_ != nullptr); // assert(TopK == 1); // assert(Num_tokens == 128); diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 6ed0de57ab..fb3a9e82bb 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -6,7 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" - +#include "ck_tile/core/tensor/tile_window_paged.hpp" namespace ck_tile { template