mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
upgrade from clang-format-12 to clang-format-18 (#2568)
* upgrade to clang-format-18
* update to clang-format-18 in pre-commit-config
This commit is contained in:
@@ -1259,7 +1259,7 @@ struct slice : public base_transform<1, 1>
|
||||
|
||||
printf("}");
|
||||
} // namespace ck
|
||||
}; // namespace ck
|
||||
}; // namespace ck
|
||||
|
||||
/*
|
||||
* \brief lower_idx = upper_idx % modulus.
|
||||
|
||||
@@ -100,10 +100,8 @@ struct space_filling_curve
|
||||
// Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
|
||||
// idim-th element of multidimensional index.
|
||||
// All constexpr variables have to be captured by VALUE.
|
||||
constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
|
||||
{
|
||||
constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
|
||||
{
|
||||
constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
|
||||
constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
|
||||
auto res = idx_1d.value;
|
||||
auto id = 0;
|
||||
|
||||
|
||||
@@ -302,12 +302,12 @@ struct buffer_load_if<16, pre_nop>
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
|
||||
static_assert(sizeof(mbuf_t) == sizeof(T));
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
@@ -336,12 +336,12 @@ struct buffer_load_if<8, pre_nop>
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
@@ -369,12 +369,12 @@ struct buffer_load_if<4, pre_nop>
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
@@ -402,12 +402,12 @@ struct buffer_load_if<2, pre_nop>
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
@@ -435,12 +435,12 @@ struct buffer_load_if<1, pre_nop>
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n"
|
||||
"v_cmpx_le_u32 exec, 1, %4\n"
|
||||
@@ -624,7 +624,7 @@ struct buffer_store_if<16>
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = fp32x4_t;
|
||||
using mbuf_t = fp32x4_t;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
@@ -681,7 +681,7 @@ struct buffer_store_if<4>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = float;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
@@ -709,7 +709,7 @@ struct buffer_store_if<2>
|
||||
{
|
||||
static_assert(sizeof(T) == 2);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = short;
|
||||
using mbuf_t = short;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
@@ -737,7 +737,7 @@ struct buffer_store_if<1>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = float;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
|
||||
#define CK_TILE_VMCNT(cnt) \
|
||||
([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \
|
||||
((cnt)&0b1111) | (((cnt)&0b110000) << 10))
|
||||
((cnt) & 0b1111) | (((cnt) & 0b110000) << 10))
|
||||
#define CK_TILE_EXPCNT(cnt) \
|
||||
([]() { static_assert(!((cnt) >> 3), "EXP only has 3 bits"); }(), ((cnt) << 4))
|
||||
#define CK_TILE_LGKMCNT(cnt) \
|
||||
|
||||
@@ -16,7 +16,7 @@ template <typename TData, index_t NSize>
|
||||
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
|
||||
{
|
||||
array<TData, NSize + 1> r;
|
||||
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
|
||||
static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; });
|
||||
r[number<NSize>{}] = x;
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -1236,9 +1236,8 @@ constexpr auto reverse_slice_sequence(Seq,
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
|
||||
constexpr auto slice_sequence(Seq,
|
||||
number<SliceSize>,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
constexpr auto
|
||||
slice_sequence(Seq, number<SliceSize>, Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
constexpr auto r =
|
||||
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());
|
||||
|
||||
@@ -75,7 +75,7 @@ struct alignas(1) float8_e4m3_t
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
static constexpr int bias = 7; // OCP
|
||||
#else
|
||||
static constexpr int bias = 8; // FNUZ
|
||||
static constexpr int bias = 8; // FNUZ
|
||||
#endif
|
||||
using raw_type = uint8_t;
|
||||
raw_type data;
|
||||
|
||||
@@ -31,8 +31,8 @@ struct scales
|
||||
CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}
|
||||
|
||||
template <typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
|
||||
-> decltype(std::declval<const Scale&>() * rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
operator()(const Right& rhs) const -> decltype(std::declval<const Scale&>() * rhs)
|
||||
{
|
||||
return lhs_ * rhs;
|
||||
}
|
||||
@@ -43,13 +43,13 @@ struct scales
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
template <typename Scale>
|
||||
__host__ __device__ scales(Scale)->scales<Scale>;
|
||||
__host__ __device__ scales(Scale) -> scales<Scale>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct plus
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs + rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs + rhs)
|
||||
{
|
||||
return lhs + rhs;
|
||||
}
|
||||
@@ -59,21 +59,21 @@ template <>
|
||||
struct plus<void, void>
|
||||
{
|
||||
template <typename Left, typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs + rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs + rhs)
|
||||
{
|
||||
return lhs + rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ plus()->plus<void, void>;
|
||||
__host__ __device__ plus() -> plus<void, void>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct minus
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs - rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs - rhs)
|
||||
{
|
||||
return lhs - rhs;
|
||||
}
|
||||
@@ -83,21 +83,21 @@ template <>
|
||||
struct minus<void, void>
|
||||
{
|
||||
template <typename Left, typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs - rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs - rhs)
|
||||
{
|
||||
return lhs - rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ minus()->minus<void, void>;
|
||||
__host__ __device__ minus() -> minus<void, void>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct multiplies
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs * rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs * rhs)
|
||||
{
|
||||
return lhs * rhs;
|
||||
}
|
||||
@@ -107,15 +107,15 @@ template <>
|
||||
struct multiplies<void, void>
|
||||
{
|
||||
template <typename Left, typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs * rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs * rhs)
|
||||
{
|
||||
return lhs * rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ multiplies()->multiplies<void, void>;
|
||||
__host__ __device__ multiplies() -> multiplies<void, void>;
|
||||
|
||||
template <typename T>
|
||||
struct maximize
|
||||
@@ -327,8 +327,8 @@ CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct equal
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs == rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs == rhs)
|
||||
{
|
||||
return lhs == rhs;
|
||||
}
|
||||
@@ -338,15 +338,15 @@ template <>
|
||||
struct equal<void, void>
|
||||
{
|
||||
template <typename Left, typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs == rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs == rhs)
|
||||
{
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ equal()->equal<void, void>;
|
||||
__host__ __device__ equal() -> equal<void, void>;
|
||||
|
||||
template <>
|
||||
struct equal<float, float>
|
||||
@@ -369,8 +369,8 @@ struct equal<double, double>
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct less
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs < rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs < rhs)
|
||||
{
|
||||
return lhs < rhs;
|
||||
}
|
||||
@@ -380,21 +380,21 @@ template <>
|
||||
struct less<void, void>
|
||||
{
|
||||
template <typename Left, typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs < rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs < rhs)
|
||||
{
|
||||
return lhs < rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ less()->less<void, void>;
|
||||
__host__ __device__ less() -> less<void, void>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct less_equal
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs <= rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs <= rhs)
|
||||
{
|
||||
return lhs <= rhs;
|
||||
}
|
||||
@@ -404,15 +404,15 @@ template <>
|
||||
struct less_equal<void, void>
|
||||
{
|
||||
template <typename Left, typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
|
||||
-> decltype(lhs <= rhs)
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs,
|
||||
const Right& rhs) const -> decltype(lhs <= rhs)
|
||||
{
|
||||
return lhs <= rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ less_equal()->less_equal<void, void>;
|
||||
__host__ __device__ less_equal() -> less_equal<void, void>;
|
||||
|
||||
template <>
|
||||
struct less_equal<float, float>
|
||||
|
||||
@@ -117,8 +117,8 @@ struct DefaultTranspose
|
||||
struct ValidationTraitsImpl
|
||||
{
|
||||
using QuadEncoding = std::conditional_t<ReverseDirection,
|
||||
QuadOutputEncoding<LaneGroupSize>,
|
||||
QuadInputEncoding<LaneGroupSize>>;
|
||||
QuadOutputEncoding<LaneGroupSize>,
|
||||
QuadInputEncoding<LaneGroupSize>>;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto input_hs = InDstrEncode::hs_lengthss_;
|
||||
@@ -396,9 +396,9 @@ template <
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE auto
|
||||
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
|
||||
@@ -303,6 +303,6 @@ struct tile_sweeper
|
||||
template <typename T,
|
||||
typename F,
|
||||
typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper<T, F, U>;
|
||||
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper<T, F, U>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -81,7 +81,7 @@ struct tensor_adaptor
|
||||
|
||||
template <index_t IDimHidden>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
get_transform_and_its_upper_dimension(number<IDimHidden>)
|
||||
get_transform_and_its_upper_dimension(number<IDimHidden>)
|
||||
{
|
||||
// FIXME: length of bottom dimension is not known, since info about lower dim length is not
|
||||
// saved in transformation
|
||||
@@ -119,13 +119,13 @@ struct tensor_adaptor
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
|
||||
{
|
||||
constexpr auto all_low_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
LowerDimensionHiddenIdss{});
|
||||
constexpr auto all_low_dim_ids =
|
||||
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
LowerDimensionHiddenIdss{});
|
||||
|
||||
constexpr auto all_up_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
UpperDimensionHiddenIdss{});
|
||||
constexpr auto all_up_dim_ids =
|
||||
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
UpperDimensionHiddenIdss{});
|
||||
|
||||
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
|
||||
|
||||
@@ -461,7 +461,7 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
|
||||
sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));
|
||||
|
||||
constexpr auto up_dim_hidden_idss = generate_tuple(
|
||||
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
|
||||
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
|
||||
return
|
||||
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
|
||||
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
|
||||
@@ -470,8 +470,8 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
|
||||
number<num_new_transform>{});
|
||||
|
||||
// new top dimension's hidden ids
|
||||
constexpr auto unordered_new_top_dim_hidden_ids = unpack(
|
||||
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
|
||||
constexpr auto unordered_new_top_dim_hidden_ids =
|
||||
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
|
||||
|
||||
constexpr auto new_top_dim_unordered2ordered = unpack(
|
||||
[](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});
|
||||
@@ -595,8 +595,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];
|
||||
|
||||
// sequence in, sequence out
|
||||
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
|
||||
{
|
||||
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
|
||||
auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
|
||||
|
||||
// shift hidden id so every dim id is unique
|
||||
@@ -619,8 +618,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
});
|
||||
|
||||
return low_dim_hidden_ids_1_mod_;
|
||||
}
|
||||
();
|
||||
}();
|
||||
|
||||
return generate_sequence_v2(
|
||||
[&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
|
||||
@@ -643,8 +641,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];
|
||||
|
||||
// sequence in, constexpr tuple out
|
||||
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
|
||||
{
|
||||
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
|
||||
auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
|
||||
|
||||
// shift hidden id
|
||||
@@ -653,8 +650,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
});
|
||||
|
||||
return up_dim_hidden_ids_1_mod_;
|
||||
}
|
||||
();
|
||||
}();
|
||||
|
||||
// constexpr tuple to sequence
|
||||
return generate_sequence_v2(
|
||||
|
||||
@@ -202,7 +202,7 @@ struct tile_distribution
|
||||
// FIXME: it's hacky to get Y index from Distributed-Index
|
||||
template <typename DistributedIndices>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
get_y_indices_from_distributed_indices(DistributedIndices)
|
||||
get_y_indices_from_distributed_indices(DistributedIndices)
|
||||
{
|
||||
constexpr auto ys_idx_arr = [] {
|
||||
array<index_t, NDimY> ys_idx;
|
||||
@@ -266,7 +266,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t
|
||||
// this returns a constexpr encoding of tile_distribution
|
||||
template <typename StaticTileDistributionEncoding_>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
|
||||
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
|
||||
{
|
||||
using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
|
||||
using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
|
||||
@@ -614,8 +614,7 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
constexpr auto src_y_maps = src_y_info[number<1>{}];
|
||||
constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
|
||||
|
||||
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
|
||||
{
|
||||
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr {
|
||||
auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
|
||||
auto y_slice_lengths = Encoding::detail::ys_lengths_;
|
||||
constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks();
|
||||
@@ -685,8 +684,7 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
|
||||
|
||||
return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
|
||||
}
|
||||
();
|
||||
}();
|
||||
|
||||
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
|
||||
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
|
||||
|
||||
@@ -327,9 +327,8 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
|
||||
template <typename DstType, typename SrcTensor>
|
||||
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
|
||||
{
|
||||
if constexpr((std::is_same_v<DstType, fp8_t> ||
|
||||
std::is_same_v<DstType, bf8_t>)&&std::is_same_v<typename SrcTensor::DataType,
|
||||
float> &&
|
||||
if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
|
||||
std::is_same_v<typename SrcTensor::DataType, float> &&
|
||||
(SrcTensor::get_thread_buffer_size() % 4 == 0))
|
||||
{
|
||||
return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
|
||||
|
||||
@@ -74,8 +74,9 @@ struct tile_window_linear
|
||||
static constexpr auto get_num_non_linear_access()
|
||||
{
|
||||
constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
|
||||
using ys_to_rhs_major = typename decltype(
|
||||
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
using ys_to_rhs_major =
|
||||
typename decltype(typename Base::TileDstr{}
|
||||
.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
|
||||
constexpr auto non_linear = [&]() {
|
||||
index_t cnt = 1;
|
||||
@@ -109,8 +110,9 @@ struct tile_window_linear
|
||||
static constexpr auto get_non_linear_access_map()
|
||||
{
|
||||
constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
|
||||
using ys_to_rhs_major = typename decltype(
|
||||
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
using ys_to_rhs_major =
|
||||
typename decltype(typename Base::TileDstr{}
|
||||
.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
constexpr auto non_linear_map = [&]() {
|
||||
array<index_t, Base::Traits::NumAccess> m_{0};
|
||||
index_t cumulative_len_ = 1;
|
||||
@@ -244,8 +246,9 @@ struct tile_window_linear
|
||||
{
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
constexpr auto idx_ys = SFC_Ys::get_index(number<i_access>{});
|
||||
using ys_to_rhs_major = typename decltype(
|
||||
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
using ys_to_rhs_major =
|
||||
typename decltype(typename Base::TileDstr{}
|
||||
.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
|
||||
constexpr auto modified_idx_ys = generate_tuple(
|
||||
[&](auto i_dim_y) {
|
||||
|
||||
@@ -48,7 +48,7 @@ struct str_literal
|
||||
|
||||
template <size_t... Idx>
|
||||
constexpr std::tuple<std::integral_constant<size_t, Idx>...>
|
||||
makeTuple(std::index_sequence<Idx...>) noexcept
|
||||
makeTuple(std::index_sequence<Idx...>) noexcept
|
||||
{
|
||||
return {};
|
||||
}
|
||||
@@ -113,8 +113,8 @@ struct CK_PRINTF<ConvertTo,
|
||||
std::integer_sequence<index_t, Is...>) const
|
||||
{
|
||||
using FMT1 = std::conditional_t<sizeof...(FMTChars) == 0,
|
||||
decltype(default_format<Y>()),
|
||||
str_literal<FMTChars...>>;
|
||||
decltype(default_format<Y>()),
|
||||
str_literal<FMTChars...>>;
|
||||
constexpr auto fmt_v = FMT1::template duplicate_n<N>(make_str_literal(" "));
|
||||
constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix();
|
||||
|
||||
|
||||
@@ -58,8 +58,8 @@ struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
|
||||
|
||||
struct nonesuch
|
||||
{
|
||||
~nonesuch() = delete;
|
||||
nonesuch(nonesuch const&) = delete;
|
||||
~nonesuch() = delete;
|
||||
nonesuch(nonesuch const&) = delete;
|
||||
void operator=(nonesuch const&) = delete;
|
||||
};
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ struct composes<F>
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
template <typename... Ts>
|
||||
__host__ __device__ composes(Ts&&...)->composes<remove_cvref_t<Ts>...>;
|
||||
__host__ __device__ composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;
|
||||
|
||||
template <typename SaturateType>
|
||||
struct saturates
|
||||
@@ -57,8 +57,8 @@ struct saturates
|
||||
// NOTE: this function does not return SaturateType value
|
||||
// it is the user's responsibility to do a further cast or not
|
||||
template <typename AccType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const AccType& a_) const
|
||||
-> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
operator()(const AccType& a_) const -> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
|
||||
{
|
||||
return clamp(a_,
|
||||
type_convert<AccType>(numeric<SaturateType>::lowest()),
|
||||
|
||||
Reference in New Issue
Block a user