mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Merge commit 'a2f01141aadedc9bfcd5edad75dbaf67d3d5990a' into develop
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#include "ck_tile/core/arch/amd_transpose_load_encoding.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
@@ -39,6 +40,7 @@
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/tensor/buffer_view.hpp"
|
||||
#include "ck_tile/core/tensor/load_tile.hpp"
|
||||
#include "ck_tile/core/tensor/load_tile_transpose.hpp"
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/null_tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/shuffle_tile.hpp"
|
||||
|
||||
@@ -2784,6 +2784,40 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
{
|
||||
|
||||
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
|
||||
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
|
||||
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "not implemented");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
86
include/ck_tile/core/arch/amd_transpose_load_encoding.hpp
Normal file
86
include/ck_tile/core/arch/amd_transpose_load_encoding.hpp
Normal file
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this generate wave level tile distribution
|
||||
template <typename T, typename = void>
|
||||
struct LaneGroupTransposeTraits;
|
||||
|
||||
template <typename T>
|
||||
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 2>>
|
||||
{
|
||||
// before transpose, 4x16
|
||||
static constexpr index_t ksecondDim = 4;
|
||||
static constexpr index_t kleadDim = 16;
|
||||
// after transpose, 16x4
|
||||
static constexpr index_t ksecondDimT = 16;
|
||||
static constexpr index_t kleadDimT = 4;
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, 4, 4>>,
|
||||
tuple<sequence<1, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 3>>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
|
||||
{
|
||||
static constexpr index_t ksecondDim = 8;
|
||||
static constexpr index_t kleadDim = 16;
|
||||
|
||||
static constexpr index_t ksecondDimT = 16;
|
||||
static constexpr index_t kleadDimT = 8;
|
||||
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, 2, 8>>,
|
||||
tuple<sequence<1, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 3>>;
|
||||
};
|
||||
|
||||
/*
|
||||
* @brief This function is used to generate the transposed distribution encoding
|
||||
* for the given data type and distribution dimensions.
|
||||
*
|
||||
* @tparam T The data type of the elements in the tensor.
|
||||
* @tparam kOuterDistDim0 The outer distribution dimension 0, which is outer dimension for stride.
|
||||
* @tparam kOuterDistDim1 The outer distribution dimension 1, which is inner dimension for stride.
|
||||
* @tparam kInnerDistDim0 The inner distribution dimension 0, which is outer dimension for
|
||||
* consecutive.
|
||||
* @tparam kInnerDistDim1 The inner distribution dimension 1, which is inner dimension for
|
||||
* consecutive.
|
||||
*/
|
||||
template <typename T,
|
||||
index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
|
||||
{
|
||||
using xdllevel_dstr_encoding = typename LaneGroupTransposeTraits<T>::
|
||||
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>;
|
||||
return xdllevel_dstr_encoding{};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -133,6 +134,28 @@ struct buffer_view<address_space_enum::generic,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
In the generic address space, we do not support the transpose instruction in the buffer view.
|
||||
Will report compilation error when developer wants to use it.
|
||||
*/
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
static_assert(false, "Error: transpose load not supported in global memory space.");
|
||||
ignore = i;
|
||||
ignore = linear_offset;
|
||||
ignore = is_valid_element;
|
||||
return;
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
@@ -359,6 +382,28 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
In the global memory address space, we do not support the transpose instruction in the buffer
|
||||
view. Will report compilation error when developer wants to use it.
|
||||
*/
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
static_assert(false, "Error: transpose load not supported in global memory space.");
|
||||
ignore = i;
|
||||
ignore = linear_offset;
|
||||
ignore = is_valid_element;
|
||||
return;
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
@@ -852,6 +897,43 @@ struct buffer_view<address_space_enum::lds,
|
||||
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
transpose_get(index_t i, index_t linear_offset, bool is_valid_element) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
if(is_valid_element)
|
||||
{
|
||||
constexpr address_space_enum addr_space = get_address_space();
|
||||
return amd_transpose_load_to_vgpr<remove_cvref_t<T>, t_per_x, addr_space>(
|
||||
p_data_ + i + linear_offset);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return X{numeric<remove_cvref_t<T>>::zero()};
|
||||
}
|
||||
else
|
||||
{
|
||||
return X{invalid_element_value_};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
|
||||
362
include/ck_tile/core/tensor/load_tile_transpose.hpp
Normal file
362
include/ck_tile/core/tensor/load_tile_transpose.hpp
Normal file
@@ -0,0 +1,362 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace util {
|
||||
template <typename Suffix, typename Sequence>
|
||||
struct is_sequence_suffix
|
||||
{
|
||||
static constexpr bool size_check = (Suffix::size() <= Sequence::size());
|
||||
|
||||
static constexpr index_t start_pos = Sequence::size() - Suffix::size();
|
||||
using extract_indices = typename arithmetic_sequence_gen<start_pos, Sequence::size(), 1>::type;
|
||||
|
||||
static constexpr bool value =
|
||||
size_check && (Suffix{} == decltype(Sequence::extract(extract_indices{})){});
|
||||
};
|
||||
|
||||
template <index_t... Xs>
|
||||
struct is_sequence_suffix<sequence<>, sequence<Xs...>>
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename Suffix, typename Sequence>
|
||||
constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::value;
|
||||
|
||||
} // namespace util
|
||||
|
||||
// Default policy: Retains original 2D transpose behavior
|
||||
template <typename DataType>
|
||||
struct DefaultTranspose
|
||||
{
|
||||
struct Quad16
|
||||
{
|
||||
using InputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4>, sequence<4, 4>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using OutputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<4>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
struct Quad8
|
||||
{
|
||||
using InputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<8>, sequence<2, 8>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using OutputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<8>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
// Select based on data size
|
||||
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16::InputEncoding,
|
||||
typename Quad8::InputEncoding>;
|
||||
|
||||
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16::OutputEncoding,
|
||||
typename Quad8::OutputEncoding>;
|
||||
|
||||
// Always swap last two dimensions
|
||||
static constexpr auto transpose_dims = sequence<1, 0>{};
|
||||
|
||||
// Programmable: Element grouping function
|
||||
static constexpr auto group_func = [](auto idx) {
|
||||
return idx; // Identity mapping
|
||||
};
|
||||
|
||||
template <typename InDstrEncode>
|
||||
struct ValidationTraits
|
||||
{
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_hs_lengthss = QuadInputEncoding::hs_lengthss_;
|
||||
// 1. Must be 2D tensor
|
||||
static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
|
||||
// 2. Quad pattern must be suffix of input pattern
|
||||
static constexpr bool suffix_valid_dim0 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<0>()),
|
||||
decltype(input_hs_lengthss.template get<0>())>;
|
||||
static constexpr bool suffix_valid_dim1 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<1>()),
|
||||
decltype(input_hs_lengthss.template get<1>())>;
|
||||
|
||||
// 3. PS→RHS mapping constraints
|
||||
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
|
||||
static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1;
|
||||
static constexpr index_t ndimp_inner =
|
||||
input_ps_to_rhss_major[number<ndimp_outer>{}].size() - 1;
|
||||
|
||||
static constexpr bool ps_mapping_valid =
|
||||
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner>{}] == 2) &&
|
||||
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner>{}] ==
|
||||
input_hs_lengthss[number<1>{}].size() - 2) &&
|
||||
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] == 1) &&
|
||||
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] ==
|
||||
input_hs_lengthss[number<0>{}].size() - 1);
|
||||
|
||||
// 4. YS→RHS mapping constraints
|
||||
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr bool ys_mapping_valid =
|
||||
(input_ys_to_rhs_major.back() == 2) &&
|
||||
(input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) &&
|
||||
(input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) &&
|
||||
(input_ys_to_rhs_minor[input_ys_to_rhs_minor.size() - 2] ==
|
||||
input_hs_lengthss[number<0>{}].size() - 2);
|
||||
|
||||
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
|
||||
ps_mapping_valid && ys_mapping_valid;
|
||||
};
|
||||
};
|
||||
template <typename TileDistribution_, typename DataType_, typename Policy>
|
||||
struct TransposeTileDistrChecker
|
||||
{
|
||||
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
|
||||
|
||||
using Validator = typename Policy::template ValidationTraits<InDstrEncode>;
|
||||
|
||||
static constexpr bool distr_encoding_valid = Validator::value;
|
||||
};
|
||||
|
||||
// this is used to generate the transposed output tile distribution encoding
|
||||
// based on the input tile distribution encoding
|
||||
template <typename TileDistribution_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
struct OutputTileDistributionTraits
|
||||
{
|
||||
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_input_hs_lengthss = Policy::QuadInputEncoding::hs_lengthss_;
|
||||
static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_;
|
||||
|
||||
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_;
|
||||
static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_;
|
||||
|
||||
// for transpose load
|
||||
// append the reversed quad output hs lengths to the input hs lengthss after removing
|
||||
// the quad_input_hs_lengthss
|
||||
// then reverse the whole sequence to get the dst_out_hs_lengthss
|
||||
static constexpr auto reversed_quad_output_hs_lengthss = tuple_reverse(quad_output_hs_lengthss);
|
||||
|
||||
static constexpr auto full_out_hs_lengthss = generate_tuple(
|
||||
[](auto i) {
|
||||
return input_hs_lengthss[i]
|
||||
.extract(typename arithmetic_sequence_gen<0,
|
||||
input_hs_lengthss[i].size() -
|
||||
quad_input_hs_lengthss[i].size(),
|
||||
1>::type{})
|
||||
.push_back(reversed_quad_output_hs_lengthss[i]);
|
||||
},
|
||||
number<InDstrEncode::NDimX>{});
|
||||
|
||||
static constexpr auto dst_out_hs_lengthss = tuple_reverse(full_out_hs_lengthss);
|
||||
|
||||
// for PS→RHS mapping(both major and minor), we need to modify the last element of the major
|
||||
// sequence
|
||||
static constexpr auto modified_ps_to_rhss_major = generate_tuple(
|
||||
[](auto i) {
|
||||
if constexpr(i == input_ps_to_rhss_major.size() - 1)
|
||||
{
|
||||
constexpr auto current_size = input_ps_to_rhss_major[i].size();
|
||||
constexpr auto reduce_size = quad_ps_to_rhss_major[number<0>{}].size();
|
||||
constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
|
||||
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
|
||||
return reduced_ps_to_rhss_major.push_back(number<2>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences, keep them unchanged
|
||||
return input_ps_to_rhss_major[i];
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_major.size()>{});
|
||||
|
||||
static constexpr auto minor_last_index =
|
||||
full_out_hs_lengthss[number<InDstrEncode::NDimX - 1>{}].size() - 1;
|
||||
static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1;
|
||||
|
||||
static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
|
||||
[](auto i) {
|
||||
if constexpr(i == input_ps_to_rhss_minor.size() - 1)
|
||||
{
|
||||
constexpr auto current_size = input_ps_to_rhss_minor[i].size();
|
||||
constexpr auto reduce_size = quad_ps_to_rhss_minor[number<0>{}].size();
|
||||
constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract(
|
||||
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
|
||||
return reduced_ps_to_rhss_minor.push_back(number<minor_last_index>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences, keep them unchanged
|
||||
return input_ps_to_rhss_minor[i];
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_minor.size()>{});
|
||||
|
||||
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
|
||||
static constexpr auto swap_one_and_two = [](const index_t idx) {
|
||||
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
|
||||
};
|
||||
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
|
||||
[](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); },
|
||||
number<modified_ps_to_rhss_major.size()>{});
|
||||
|
||||
static constexpr auto modified_input_ys_to_rhs_major =
|
||||
input_ys_to_rhs_major.pop_back().push_back(number<1>{});
|
||||
|
||||
static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2(
|
||||
[](auto i) { return number<swap_one_and_two(modified_input_ys_to_rhs_major[i])>{}; },
|
||||
number<modified_input_ys_to_rhs_major.size()>{});
|
||||
|
||||
static constexpr auto dst_ys_to_rhs_minor =
|
||||
input_ys_to_rhs_minor.pop_back().push_back(number<major_last_index>{});
|
||||
|
||||
using OutDstrEncode = tile_distribution_encoding<typename InDstrEncode::RsLengths,
|
||||
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
|
||||
};
|
||||
|
||||
template <typename InnerEncode,
|
||||
index_t kLeadIterPerWarp,
|
||||
index_t kSecondIterPerWarp,
|
||||
index_t kLeadNumWarps,
|
||||
index_t kSecondNumWarps>
|
||||
CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
|
||||
{
|
||||
constexpr auto block_outer_dst_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kSecondIterPerWarp, kSecondNumWarps>,
|
||||
sequence<kLeadIterPerWarp, kLeadNumWarps>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto blk_distr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(block_outer_dst_encoding, InnerEncode{});
|
||||
|
||||
return blk_distr_encode;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief transpose loads tile from a tensor and returns the resulting tensor with a new
|
||||
* (transposed) tile distribution. use SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* This function is intended for use with statically distributed tensor tiles, where the input
|
||||
* and output tile distributions differ due to the transpose operation. It ensures that the
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
* @tparam NumCoord The number of coordinates (dimensions).
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
*
|
||||
* @return A statically distributed tensor containing the transposed tile data.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE auto
|
||||
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window)
|
||||
{
|
||||
using OutTileDstrEncode =
|
||||
typename OutputTileDistributionTraits<TileDistribution_,
|
||||
typename BottomTensorView_::DataType>::OutDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
auto trans_tensor = tile_window.template load_transpose<Policy>();
|
||||
constexpr auto input_distr = TileDistribution_{};
|
||||
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
|
||||
|
||||
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
|
||||
|
||||
constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y();
|
||||
constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();
|
||||
|
||||
constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
|
||||
|
||||
constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size();
|
||||
constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();
|
||||
static_assert(y_in_element_space_size == y_out_element_space_size,
|
||||
"the element space size is not the same!");
|
||||
static_assert(y_in_lengths[NDimYIn - 1] == y_out_lengths[NDimYOut - 1],
|
||||
"the vector length is not the same!");
|
||||
constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
|
||||
constexpr index_t num_of_access =
|
||||
reduce_on_sequence(y_in_lengths, multiplies{}, number<1>{}) / vecLoadSize;
|
||||
|
||||
using DataVec = array<typename BottomTensorView_::DataType, vecLoadSize>;
|
||||
static_for<0, num_of_access, 1>{}([&](auto iAccess) {
|
||||
out_tensor.get_thread_buffer().template set_as<DataVec>(
|
||||
number<iAccess>{},
|
||||
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
|
||||
});
|
||||
|
||||
return out_tensor;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -251,6 +251,33 @@ struct tensor_view
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const
|
||||
{
|
||||
return buf_.template transpose_get<X>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element // flag
|
||||
) const
|
||||
{
|
||||
return buf_.template transpose_get<X>(coord.get_offset(), linear_offset, is_valid_element);
|
||||
}
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
|
||||
@@ -407,6 +407,82 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose() const
|
||||
{
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
this->template load_transpose<Policy>(
|
||||
dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename Policy,
|
||||
typename DistributedTensor,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
constexpr auto group_func = Policy::group_func;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
this->get_bottom_tensor_view()
|
||||
.template get_transpose_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, 0);
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto orig_idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr auto grouped_idx_ys = group_func(orig_idx_ys);
|
||||
|
||||
constexpr index_t linear_distributed_index =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
|
||||
vec_value.template get_as<typename Base::DataType>()[j];
|
||||
});
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
|
||||
typename Base::TileDstr>& dstr_tensor,
|
||||
@@ -415,7 +491,6 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
|
||||
@@ -613,6 +613,60 @@ struct tile_window_linear
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose() const
|
||||
{
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
this->template load_transpose_linear<Policy>(
|
||||
dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename Policy,
|
||||
typename DistributedTensor,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose_linear(DistributedTensor& dst_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
constexpr auto group_func = Policy::group_func;
|
||||
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
|
||||
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess];
|
||||
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
this->get_bottom_tensor_view().template get_transpose_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, 0);
|
||||
// write into distributed tensor
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
},
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t linear_distributed_index =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
|
||||
vec_value.template get_as<typename Base::DataType>()[j];
|
||||
});
|
||||
};
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
template <index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
|
||||
typename Base::TileDstr>& dstr_tensor,
|
||||
|
||||
Reference in New Issue
Block a user