let more integral_constant->constant, and formating

This commit is contained in:
carlushuang
2024-03-13 18:33:10 +00:00
parent b1dbf64c91
commit 616932068d
9 changed files with 36 additions and 30 deletions

View File

@@ -55,3 +55,4 @@
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

View File

@@ -1562,25 +1562,25 @@ CK_TILE_HOST_DEVICE constexpr auto
make_pad_transform(const LowLength& low_length,
const LeftPad& left_pad,
const RightPad& right_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
bool_constant<SkipIsValidCheck> = bool_constant<false>{})
{
return pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
}
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
CK_TILE_HOST_DEVICE constexpr auto make_left_pad_transform(
const LowLength& low_length,
const LeftPadLength& left_pad_,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
CK_TILE_HOST_DEVICE constexpr auto
make_left_pad_transform(const LowLength& low_length,
const LeftPadLength& left_pad_,
bool_constant<SkipIsValidCheck> = bool_constant<false>{})
{
return left_pad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad_};
}
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
CK_TILE_HOST_DEVICE constexpr auto make_right_pad_transform(
const LowLength& low_length,
const RightPadLength& right_pad_,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
CK_TILE_HOST_DEVICE constexpr auto
make_right_pad_transform(const LowLength& low_length,
const RightPadLength& right_pad_,
bool_constant<SkipIsValidCheck> = bool_constant<false>{})
{
return right_pad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad_};
}
@@ -1615,9 +1615,9 @@ CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths& low_le
}
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(
const UpLengths& up_lengths,
integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{})
CK_TILE_HOST_DEVICE constexpr auto
make_unmerge_transform(const UpLengths& up_lengths,
bool_constant<Use24BitIntegerCalculation> = bool_constant<false>{})
{
return unmerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}

View File

@@ -60,7 +60,7 @@ struct sequence
CK_TILE_HOST_DEVICE static constexpr auto get()
{
static_assert(I < size(), "wrong! I too large");
return number<impl::at_index_t<I, integral_constant<value_type, Is>...>{}>{};
return number<impl::at_index_t<I, constant<Is>...>{}>{};
}
template <index_t I>
@@ -81,7 +81,7 @@ struct sequence
CK_TILE_HOST_DEVICE static constexpr auto at()
{
static_assert(I < size(), "wrong! I too large");
return number<impl::at_index_t<I, integral_constant<value_type, Is>...>{}>{};
return number<impl::at_index_t<I, constant<Is>...>{}>{};
}
template <index_t I>
@@ -384,7 +384,7 @@ template <index_t... Ids, index_t... Ns>
struct seq_reverse<sequence<Ids...>, Ns...>
{
template <index_t I>
using element = impl::at_index_t<I, integral_constant<index_t, Ns>...>;
using element = impl::at_index_t<I, constant<Ns>...>;
using type = sequence<element<(sizeof...(Ns) - 1 - Ids)>::value...>;
};
} // namespace impl

View File

@@ -274,6 +274,17 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
#undef TP_COM_
};
template <typename>
struct vector_traits;
// specialization for array
template <typename... T>
struct vector_traits<tuple<T...>>
{
using scalar_type = __type_pack_element<0, T...>;
static constexpr index_t vector_size = sizeof...(T);
};
// template <class... T>
// CK_TILE_HOST_DEVICE constexpr
// tuple<T...>

View File

@@ -52,8 +52,7 @@ struct magic_division32_bit_range
constexpr uint32_t multiplier = tmp[number<0>{}];
constexpr uint32_t shift = tmp[number<1>{}];
return make_tuple(integral_constant<uint32_t, multiplier>{},
integral_constant<uint32_t, shift>{});
return make_tuple(constant<multiplier>{}, constant<shift>{});
}
// magic division for uint32_t
@@ -116,25 +115,15 @@ struct magic_division16_bit_range
}
// integral_constant<uint32_t, .>
template <uint32_t Divisor>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_magic_numbers(integral_constant<uint32_t, Divisor>)
template <auto Divisor>
CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
{
constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});
constexpr uint32_t multiplier = tmp[number<0>{}];
constexpr uint32_t shift = tmp[number<1>{}];
return make_tuple(integral_constant<uint32_t, multiplier>{},
integral_constant<uint32_t, shift>{});
}
// integral_constant<int32_t, .>
template <int32_t Divisor>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_magic_numbers(integral_constant<int32_t, Divisor>)
{
return calculate_magic_numbers(integral_constant<uint32_t, Divisor>{});
return make_tuple(constant<multiplier>{}, constant<shift>{});
}
// magic division for uint32_t

View File

@@ -20,3 +20,4 @@
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp"

View File

@@ -4,3 +4,4 @@
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -4,3 +4,5 @@
#pragma once
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -18,3 +18,4 @@
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"