Merge commit 'a2f01141aadedc9bfcd5edad75dbaf67d3d5990a' into develop

This commit is contained in:
assistant-librarian[bot]
2025-06-18 09:15:28 +00:00
parent 60a1cf775c
commit 2e67382832
17 changed files with 1523 additions and 1 deletions

View File

@@ -10,6 +10,7 @@
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
#include "ck_tile/core/arch/amd_transpose_load_encoding.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/arch/utility.hpp"
@@ -39,6 +40,7 @@
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
#include "ck_tile/core/tensor/load_tile.hpp"
#include "ck_tile/core/tensor/load_tile_transpose.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/shuffle_tile.hpp"

View File

@@ -2784,6 +2784,40 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
#endif
}
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
{
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else
{
static_assert(false, "not implemented");
}
}
} // namespace ck_tile
#endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN

View File

@@ -0,0 +1,86 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
namespace ck_tile {
// this generate wave level tile distribution
template <typename T, typename = void>
struct LaneGroupTransposeTraits;
template <typename T>
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 2>>
{
// before transpose, 4x16
static constexpr index_t ksecondDim = 4;
static constexpr index_t kleadDim = 16;
// after transpose, 16x4
static constexpr index_t ksecondDimT = 16;
static constexpr index_t kleadDimT = 4;
template <index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
using TileDistribution =
tile_distribution_encoding<sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
sequence<kInnerDistDim0, kInnerDistDim1, 4, 4>>,
tuple<sequence<1, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2>>,
sequence<2, 1, 2>,
sequence<1, 1, 3>>;
};
template <typename T>
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
{
static constexpr index_t ksecondDim = 8;
static constexpr index_t kleadDim = 16;
static constexpr index_t ksecondDimT = 16;
static constexpr index_t kleadDimT = 8;
template <index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
using TileDistribution =
tile_distribution_encoding<sequence<>,
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
sequence<kInnerDistDim0, kInnerDistDim1, 2, 8>>,
tuple<sequence<1, 2, 1, 2>>,
tuple<sequence<0, 0, 2, 2>>,
sequence<2, 1, 2>,
sequence<1, 1, 3>>;
};
/*
* @brief This function is used to generate the transposed distribution encoding
* for the given data type and distribution dimensions.
*
* @tparam T The data type of the elements in the tensor.
* @tparam kOuterDistDim0 The outer distribution dimension 0, which is outer dimension for stride.
* @tparam kOuterDistDim1 The outer distribution dimension 1, which is inner dimension for stride.
* @tparam kInnerDistDim0 The inner distribution dimension 0, which is outer dimension for
* consecutive.
* @tparam kInnerDistDim1 The inner distribution dimension 1, which is inner dimension for
* consecutive.
*/
template <typename T,
index_t kOuterDistDim0,
index_t kOuterDistDim1,
index_t kInnerDistDim0,
index_t kInnerDistDim1>
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
{
using xdllevel_dstr_encoding = typename LaneGroupTransposeTraits<T>::
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>;
return xdllevel_dstr_encoding{};
}
} // namespace ck_tile

View File

@@ -18,6 +18,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/ignore.hpp"
namespace ck_tile {
@@ -133,6 +134,28 @@ struct buffer_view<address_space_enum::generic,
}
}
/*
In the generic address space, we do not support the transpose instruction in the buffer view.
Will report compilation error when developer wants to use it.
*/
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
static_assert(false, "Error: transpose load not supported in global memory space.");
ignore = i;
ignore = linear_offset;
ignore = is_valid_element;
return;
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
@@ -359,6 +382,28 @@ struct buffer_view<address_space_enum::global,
}
}
/*
In the global memory address space, we do not support the transpose instruction in the buffer
view. Will report compilation error when developer wants to use it.
*/
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
static_assert(false, "Error: transpose load not supported in global memory space.");
ignore = i;
ignore = linear_offset;
ignore = is_valid_element;
return;
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
@@ -852,6 +897,43 @@ struct buffer_view<address_space_enum::lds,
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
}
template <typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
transpose_get(index_t i, index_t linear_offset, bool is_valid_element) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if(is_valid_element)
{
constexpr address_space_enum addr_space = get_address_space();
return amd_transpose_load_to_vgpr<remove_cvref_t<T>, t_per_x, addr_space>(
p_data_ + i + linear_offset);
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return X{numeric<remove_cvref_t<T>>::zero()};
}
else
{
return X{invalid_element_value_};
}
}
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,

View File

@@ -0,0 +1,362 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
namespace util {
template <typename Suffix, typename Sequence>
struct is_sequence_suffix
{
static constexpr bool size_check = (Suffix::size() <= Sequence::size());
static constexpr index_t start_pos = Sequence::size() - Suffix::size();
using extract_indices = typename arithmetic_sequence_gen<start_pos, Sequence::size(), 1>::type;
static constexpr bool value =
size_check && (Suffix{} == decltype(Sequence::extract(extract_indices{})){});
};
template <index_t... Xs>
struct is_sequence_suffix<sequence<>, sequence<Xs...>>
{
static constexpr bool value = true;
};
template <typename Suffix, typename Sequence>
constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::value;
} // namespace util
// Default policy: Retains original 2D transpose behavior
template <typename DataType>
struct DefaultTranspose
{
struct Quad16
{
using InputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<4>, sequence<4, 4>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using OutputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<16>, sequence<4>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
};
struct Quad8
{
using InputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<8>, sequence<2, 8>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
using OutputEncoding = tile_distribution_encoding<sequence<>,
tuple<sequence<16>, sequence<8>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
};
// Select based on data size
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
typename Quad16::InputEncoding,
typename Quad8::InputEncoding>;
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
typename Quad16::OutputEncoding,
typename Quad8::OutputEncoding>;
// Always swap last two dimensions
static constexpr auto transpose_dims = sequence<1, 0>{};
// Programmable: Element grouping function
static constexpr auto group_func = [](auto idx) {
return idx; // Identity mapping
};
template <typename InDstrEncode>
struct ValidationTraits
{
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
static constexpr auto quad_hs_lengthss = QuadInputEncoding::hs_lengthss_;
// 1. Must be 2D tensor
static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
// 2. Quad pattern must be suffix of input pattern
static constexpr bool suffix_valid_dim0 =
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<0>()),
decltype(input_hs_lengthss.template get<0>())>;
static constexpr bool suffix_valid_dim1 =
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<1>()),
decltype(input_hs_lengthss.template get<1>())>;
// 3. PS→RHS mapping constraints
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1;
static constexpr index_t ndimp_inner =
input_ps_to_rhss_major[number<ndimp_outer>{}].size() - 1;
static constexpr bool ps_mapping_valid =
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner>{}] == 2) &&
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner>{}] ==
input_hs_lengthss[number<1>{}].size() - 2) &&
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] == 1) &&
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] ==
input_hs_lengthss[number<0>{}].size() - 1);
// 4. YS→RHS mapping constraints
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
static constexpr bool ys_mapping_valid =
(input_ys_to_rhs_major.back() == 2) &&
(input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) &&
(input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) &&
(input_ys_to_rhs_minor[input_ys_to_rhs_minor.size() - 2] ==
input_hs_lengthss[number<0>{}].size() - 2);
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
ps_mapping_valid && ys_mapping_valid;
};
};
template <typename TileDistribution_, typename DataType_, typename Policy>
struct TransposeTileDistrChecker
{
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
using Validator = typename Policy::template ValidationTraits<InDstrEncode>;
static constexpr bool distr_encoding_valid = Validator::value;
};
// this is used to generate the transposed output tile distribution encoding
// based on the input tile distribution encoding
template <typename TileDistribution_,
typename DataType_,
typename Policy = DefaultTranspose<DataType_>>
struct OutputTileDistributionTraits
{
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
static constexpr auto quad_input_hs_lengthss = Policy::QuadInputEncoding::hs_lengthss_;
static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_;
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_;
static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_;
// for transpose load
// append the reversed quad output hs lengths to the input hs lengthss after removing
// the quad_input_hs_lengthss
// then reverse the whole sequence to get the dst_out_hs_lengthss
static constexpr auto reversed_quad_output_hs_lengthss = tuple_reverse(quad_output_hs_lengthss);
static constexpr auto full_out_hs_lengthss = generate_tuple(
[](auto i) {
return input_hs_lengthss[i]
.extract(typename arithmetic_sequence_gen<0,
input_hs_lengthss[i].size() -
quad_input_hs_lengthss[i].size(),
1>::type{})
.push_back(reversed_quad_output_hs_lengthss[i]);
},
number<InDstrEncode::NDimX>{});
static constexpr auto dst_out_hs_lengthss = tuple_reverse(full_out_hs_lengthss);
// for PS→RHS mapping(both major and minor), we need to modify the last element of the major
// sequence
static constexpr auto modified_ps_to_rhss_major = generate_tuple(
[](auto i) {
if constexpr(i == input_ps_to_rhss_major.size() - 1)
{
constexpr auto current_size = input_ps_to_rhss_major[i].size();
constexpr auto reduce_size = quad_ps_to_rhss_major[number<0>{}].size();
constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
return reduced_ps_to_rhss_major.push_back(number<2>{});
}
else
{
// For all other sequences, keep them unchanged
return input_ps_to_rhss_major[i];
}
},
number<input_ps_to_rhss_major.size()>{});
static constexpr auto minor_last_index =
full_out_hs_lengthss[number<InDstrEncode::NDimX - 1>{}].size() - 1;
static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1;
static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
[](auto i) {
if constexpr(i == input_ps_to_rhss_minor.size() - 1)
{
constexpr auto current_size = input_ps_to_rhss_minor[i].size();
constexpr auto reduce_size = quad_ps_to_rhss_minor[number<0>{}].size();
constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract(
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
return reduced_ps_to_rhss_minor.push_back(number<minor_last_index>{});
}
else
{
// For all other sequences, keep them unchanged
return input_ps_to_rhss_minor[i];
}
},
number<input_ps_to_rhss_minor.size()>{});
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
static constexpr auto swap_one_and_two = [](const index_t idx) {
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
};
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
[](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); },
number<modified_ps_to_rhss_major.size()>{});
static constexpr auto modified_input_ys_to_rhs_major =
input_ys_to_rhs_major.pop_back().push_back(number<1>{});
static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2(
[](auto i) { return number<swap_one_and_two(modified_input_ys_to_rhs_major[i])>{}; },
number<modified_input_ys_to_rhs_major.size()>{});
static constexpr auto dst_ys_to_rhs_minor =
input_ys_to_rhs_minor.pop_back().push_back(number<major_last_index>{});
using OutDstrEncode = tile_distribution_encoding<typename InDstrEncode::RsLengths,
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
};
template <typename InnerEncode,
index_t kLeadIterPerWarp,
index_t kSecondIterPerWarp,
index_t kLeadNumWarps,
index_t kSecondNumWarps>
CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
{
constexpr auto block_outer_dst_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<kSecondIterPerWarp, kSecondNumWarps>,
sequence<kLeadIterPerWarp, kLeadNumWarps>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 1>>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto blk_distr_encode =
detail::make_embed_tile_distribution_encoding(block_outer_dst_encoding, InnerEncode{});
return blk_distr_encode;
}
/**
* @brief transpose loads tile from a tensor and returns the resulting tensor with a new
* (transposed) tile distribution. use SFINAE to ensure the tile distribution encoding is valid.
*
* This function is intended for use with statically distributed tensor tiles, where the input
* and output tile distributions differ due to the transpose operation. It ensures that the
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
* @tparam NumCoord The number of coordinates (dimensions).
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
* the last is SFINAE to ensure the tile distribution encoding is valid.
*
* @param tile_window The tile window with static distribution to load and transpose.
*
* @return A statically distributed tensor containing the transposed tile data.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE auto
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
{
using OutTileDstrEncode =
typename OutputTileDistributionTraits<TileDistribution_,
typename BottomTensorView_::DataType>::OutDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
auto trans_tensor = tile_window.template load_transpose<Policy>();
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y();
constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();
constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size();
constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();
static_assert(y_in_element_space_size == y_out_element_space_size,
"the element space size is not the same!");
static_assert(y_in_lengths[NDimYIn - 1] == y_out_lengths[NDimYOut - 1],
"the vector length is not the same!");
constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
constexpr index_t num_of_access =
reduce_on_sequence(y_in_lengths, multiplies{}, number<1>{}) / vecLoadSize;
using DataVec = array<typename BottomTensorView_::DataType, vecLoadSize>;
static_for<0, num_of_access, 1>{}([&](auto iAccess) {
out_tensor.get_thread_buffer().template set_as<DataVec>(
number<iAccess>{},
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
});
return out_tensor;
}
} // namespace ck_tile

View File

@@ -251,6 +251,33 @@ struct tensor_view
bool_constant<pre_nop>{});
}
template <typename X,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const
{
return buf_.template transpose_get<X>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}
template <typename X,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_transpose_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element // flag
) const
{
return buf_.template transpose_get<X>(coord.get_offset(), linear_offset, is_valid_element);
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,

View File

@@ -407,6 +407,82 @@ struct tile_window_with_static_distribution
});
}
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose() const
{
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
this->template load_transpose<Policy>(
dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
return dst_tensor;
}
template <typename Policy,
typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = typename Base::TileDstr{};
constexpr auto group_func = Policy::group_func;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from bottom tensor
const vector_t vec_value =
this->get_bottom_tensor_view()
.template get_transpose_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, 0);
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto orig_idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<Base::NDimY>{});
constexpr auto grouped_idx_ys = group_func(orig_idx_ys);
constexpr index_t linear_distributed_index =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
vec_value.template get_as<typename Base::DataType>()[j];
});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
typename Base::TileDstr>& dstr_tensor,
@@ -415,7 +491,6 @@ struct tile_window_with_static_distribution
{
using Traits = typename Base::Traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;

View File

@@ -613,6 +613,60 @@ struct tile_window_linear
WINDOW_DISPATCH_ISSUE();
}
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose() const
{
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
this->template load_transpose_linear<Policy>(
dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
return dst_tensor;
}
template <typename Policy,
typename DistributedTensor,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose_linear(DistributedTensor& dst_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = typename Base::TileDstr{};
constexpr auto group_func = Policy::group_func;
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
// read from bottom tensor
const vector_t vec_value =
this->get_bottom_tensor_view().template get_transpose_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, 0);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<Base::NDimY>{});
constexpr index_t linear_distributed_index =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
vec_value.template get_as<typename Base::DataType>()[j];
});
};
WINDOW_DISPATCH_ISSUE();
}
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
typename Base::TileDstr>& dstr_tensor,