diff --git a/CMakeLists.txt b/CMakeLists.txt index bdeba33eac..2a36665552 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,8 @@ set(version 1.1.0) project(composable_kernel VERSION ${version} LANGUAGES CXX) include(CTest) +find_package(Python3 3.7 COMPONENTS Interpreter REQUIRED) + list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") if (DTYPES) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index d3b229daae..9536df9863 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -14,7 +14,7 @@ add_custom_command( --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) -set(EXAMPLE_FMHA_FWD "example_fmha_fwd") +set(EXAMPLE_FMHA_FWD "ck_tile_example_fmha_fwd") add_example_executable(${EXAMPLE_FMHA_FWD} fmha_fwd.cpp) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 34d7421bbd..b3be008f09 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -5,7 +5,7 @@ import argparse import itertools from pathlib import Path -from typing import List, Optional, tuple +from typing import List, Optional from dataclasses import dataclass import copy diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index 2e26fcb897..107f1f61d0 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -7,7 +7,7 @@ #include #include "ck_tile/core.hpp" -#include "ck_tile/fmha.hpp" +#include "ck_tile/ops/fmha.hpp" enum class mask_enum { diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt new file mode 100644 index 0000000000..d2b086e043 --- /dev/null +++ b/example/ck_tile/CMakeLists.txt @@ -0,0 +1,5 @@ +include_directories(AFTER + ${CMAKE_CURRENT_LIST_DIR} +) + +add_subdirectory(01_fmha) diff --git 
a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index b8efe049c1..7fa7eb9590 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -224,7 +224,7 @@ struct pad : public base_transform<1, 1> CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { - return ck_tile::ck_tile::is_known_at_compile_time::value && + return ck_tile::is_known_at_compile_time::value && ck_tile::is_known_at_compile_time::value && ck_tile::is_known_at_compile_time::value; } @@ -577,7 +577,7 @@ struct merge_v2_magic_division : public base_transform using UpperIndex = multi_index<1>; using UpLengths = - decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, number<1>{}))); + decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{}))); using LowLengthsMagicDivisor = decltype(generate_tuple( lambda_merge_generate_MagicDivision_calculate_magic_divisor{}, @@ -597,7 +597,7 @@ struct merge_v2_magic_division : public base_transform low_lengths_magic_divisor_{generate_tuple( [&](auto i) { return magic_division::calculate_magic_numbers(low_lengths[i]); }, number{})}, - up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, I1))} + up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, I1))} { static_assert(LowerIndex::size() == NDimLow, "wrong!"); } @@ -722,10 +722,10 @@ struct merge_v3_division_mod : public base_transform using UpperIndex = multi_index<1>; using LowLengthsScan = - decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, number<1>{})); + decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number<1>{})); using UpLengths = - decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, number<1>{}))); + decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{}))); LowLengths low_lengths_; LowLengthsScan 
low_lengths_scan_; @@ -736,8 +736,8 @@ struct merge_v3_division_mod : public base_transform CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths& low_lengths) : low_lengths_{low_lengths}, low_lengths_scan_{ - container_reverse_exclusive_scan(low_lengths, math::multiplies{}, number<1>{})}, - up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, number<1>{}))} + container_reverse_exclusive_scan(low_lengths, multiplies{}, number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, number<1>{}))} { static_assert(LowerIndex::size() == NDimLow, "wrong!"); } @@ -855,7 +855,7 @@ struct unmerge : public base_transform<1, UpLengths::size()> using UpperIndex = multi_index; using UpLengthsScan = - decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, number<1>{})); + decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number<1>{})); UpLengths up_lengths_; UpLengthsScan up_lengths_scan_; @@ -864,8 +864,7 @@ struct unmerge : public base_transform<1, UpLengths::size()> CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths& up_lengths) : up_lengths_{up_lengths}, - up_lengths_scan_{ - container_reverse_exclusive_scan(up_lengths, math::multiplies{}, number<1>{})} + up_lengths_scan_{container_reverse_exclusive_scan(up_lengths, multiplies{}, number<1>{})} { } @@ -944,7 +943,7 @@ struct unmerge : public base_transform<1, UpLengths::size()> { if(low_vector_lengths[0] != -1) { - up_vector_lengths(NDimUp - 1) = math::gcd(low_vector_lengths[0], up_length_last); + up_vector_lengths(NDimUp - 1) = gcd(low_vector_lengths[0], up_length_last); } } @@ -979,7 +978,7 @@ struct freeze : public base_transform<1, 0> CK_TILE_HOST_DEVICE constexpr freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {} - CK_TILE_HOST_DEVICE static constexpr auto get_upper_lengths() { return Tuple<>{}; } + CK_TILE_HOST_DEVICE static constexpr auto get_upper_lengths() { return tuple<>{}; } template CK_TILE_HOST_DEVICE 
constexpr void calculate_lower_index(LowIdx& idx_low, @@ -1428,7 +1427,7 @@ struct xor_t : public base_transform<2, 2> { if(low_vector_lengths[1] != -1) { - up_vector_lengths(1) = math::gcd(low_vector_lengths[1], math::abs(right_shift_)); + up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_)); } } @@ -1546,35 +1545,35 @@ struct offset : public base_transform<1, 1> template CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength& low_length) { - return PassThrough{low_length}; + return pass_through{low_length}; } -template +template CK_TILE_HOST_DEVICE constexpr auto make_pad_transform(const LowLength& low_length, - const left_pad& left_pad, - const right_pad& right_pad, + const LeftPad& left_pad, + const RightPad& right_pad, integral_constant = integral_constant{}) { - return pad{low_length, left_pad, right_pad}; + return pad{low_length, left_pad, right_pad}; } template CK_TILE_HOST_DEVICE constexpr auto make_left_pad_transform( const LowLength& low_length, - const LeftPadLength& left_pad, + const LeftPadLength& left_pad_, integral_constant = integral_constant{}) { - return left_pad{low_length, left_pad}; + return left_pad{low_length, left_pad_}; } template CK_TILE_HOST_DEVICE constexpr auto make_right_pad_transform( const LowLength& low_length, - const RightPadLength& right_pad, + const RightPadLength& right_pad_, integral_constant = integral_constant{}) { - return right_pad{low_length, right_pad}; + return right_pad{low_length, right_pad_}; } template {}); #endif constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + statically_indexed_array forward_sweep_; forward_sweep_(I0) = true; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 6d922dc973..cfba73f74d 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1754,7 +1754,7 @@ 
amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, constexpr index_t vector_size = scalar_type::vector_size; -#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK +#if CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = [&]() { if constexpr(oob_conditional_check) return src_thread_element_valid ? 0 : 0x80000000; @@ -1876,7 +1876,7 @@ amd_buffer_store(const typename vector_type_maker::type::type src_thread_d using scalar_t = typename scalar_type::type; constexpr index_t vector_size = scalar_type::vector_size; -#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK +#if CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = [&]() { if constexpr(oob_conditional_check) return dst_thread_element_valid ? 0 : 0x80000000; @@ -1951,7 +1951,7 @@ amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thr using scalar_t = typename scalar_type::type; constexpr index_t vector_size = scalar_type::vector_size; -#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK +#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; amd_buffer_atomic_add_impl( @@ -1986,7 +1986,7 @@ amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thr using scalar_t = typename scalar_type::type; constexpr index_t vector_size = scalar_type::vector_size; -#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK +#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; amd_buffer_atomic_max_impl( @@ -2028,7 +2028,7 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); const index_t global_offset_bytes = is_valid ? 
global_offset * sizeof(T) : 0x80000000; -#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM +#if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM T* lds_ptr = lds_base_ptr + lds_offset; auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index d8e89b9190..55cca88c8c 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -3,6 +3,11 @@ #pragma once +#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS +#include "hip/hip_runtime.h" +#include "hip/hip_fp16.h" +#endif + #ifdef __HIPCC__ #define CK_TILE_HOST __host__ #define CK_TILE_DEVICE __device__ @@ -54,3 +59,64 @@ #define CK_TILE_MAX_THREAD_PER_BLOCK 256 #define CK_TILE_MIN_BLOCK_PER_CU 2 + +#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK +#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 +#endif + +#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK +#define CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 +#endif + +#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK +#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 +#endif + +#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK +#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1 +#endif + +#ifndef CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM +#define CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1 +#endif + +#ifndef CK_TILE_USE_AMD_BUFFER_LOAD +#define CK_TILE_USE_AMD_BUFFER_LOAD 1 +#endif + +#ifndef CK_TILE_USE_AMD_BUFFER_STORE +#define CK_TILE_USE_AMD_BUFFER_STORE 1 +#endif + +#ifndef CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER +#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1 +#endif + +// buffer atomic add: floating point +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 +#elif defined(__gfx908__) || defined(__gfx90a__) || 
defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code +#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 +#else // for GPU code +#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 +#endif + +#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__)) // for GPU code +#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 +#else +#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 +#endif + +#ifndef CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS +#define CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0 +#endif + +#ifndef CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE +#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 +#endif + +#ifndef CK_TILE_DEBUG_LOG +#define CK_TILE_DEBUG_LOG 0 +#endif diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp index 67d0379afc..fd910abac9 100644 --- a/include/ck_tile/core/container/array.hpp +++ b/include/ck_tile/core/container/array.hpp @@ -6,6 +6,8 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/utility/functional.hpp" namespace ck_tile { diff --git a/include/ck_tile/core/container/map.hpp b/include/ck_tile/core/container/map.hpp index 25e065c3c1..87b180cafc 100644 --- a/include/ck_tile/core/container/map.hpp +++ b/include/ck_tile/core/container/map.hpp @@ -75,11 +75,11 @@ struct map CK_TILE_HOST_DEVICE void clear() { size_ = 0; } - CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& key) const + CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& k) const { for(index_t i = 0; i < size(); i++) { - if(impl_[i].template at<0>() == key) + if(impl_[i].template at<0>() == k) { return i; } @@ -88,39 +88,39 @@ struct map return size_; } - CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& key) const + 
CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& k) const { - return const_iterator{impl_, find_position(key)}; + return const_iterator{impl_, find_position(k)}; } - CK_TILE_HOST_DEVICE constexpr iterator find(const key& key) + CK_TILE_HOST_DEVICE constexpr iterator find(const key& k) { - return iterator{impl_, find_position(key)}; + return iterator{impl_, find_position(k)}; } - CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& key) const + CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& k) const { - const auto it = find(key); + const auto it = find(k); // FIXME - assert(it.pos_ < size()); + // assert(it.pos_ < size()); return impl_[it.pos_].template at<1>(); } - CK_TILE_HOST_DEVICE constexpr data& operator()(const key& key) + CK_TILE_HOST_DEVICE constexpr data& operator()(const key& k) { - auto it = find(key); + auto it = find(k); // if entry not found if(it.pos_ == size()) { - impl_(it.pos_).template at<0>() = key; + impl_(it.pos_).template at<0>() = k; size_++; } // FIXME - assert(size_ <= max_size); + // assert(size_ <= max_size); return impl_(it.pos_).template at<1>(); } @@ -146,12 +146,12 @@ struct map // printf("impl_: ["); // - for(const auto& [key, data] : *this) + for(const auto& [k, d] : *this) { printf("{key: "); - print(key); + print(k); printf(", data: "); - print(data); + print(d); printf("}, "); } // diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index dc1971a330..15313b1b65 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -4,11 +4,12 @@ #pragma once #include "ck_tile/core/config.hpp" -#include "ck_tile/core/container/array.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/utility/to_sequence.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include 
"ck_tile/core/utility/functional.hpp" namespace ck_tile { @@ -55,14 +56,6 @@ struct sequence CK_TILE_HOST_DEVICE static constexpr index_t size() { return sizeof...(Is); } CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }; - CK_TILE_HOST_DEVICE static constexpr index_t at(index_t I) - { - // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 - static_assert(I < size(), "wrong! I too large"); - const index_t mData[mSize + 1] = {Is..., 0}; - return mData[I]; - } - template CK_TILE_HOST_DEVICE static constexpr auto get() { @@ -81,7 +74,7 @@ struct sequence { // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 static_assert(I < size(), "wrong! I too large"); - const index_t mData[mSize + 1] = {Is..., 0}; + const index_t mData[size() + 1] = {Is..., 0}; return mData[I]; } @@ -298,7 +291,7 @@ struct arithmetic_sequence_gen static constexpr bool kHasContent = (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd); - using type = typename conditional::type; + using type = typename std::conditional::type; }; template @@ -403,8 +396,8 @@ struct sequence_reverse> { }; -template -using sequence_reverse_t = typename sequence_reverse::type; +// template +// using sequence_reverse_t = typename sequence_reverse::type; #if 1 template @@ -449,16 +442,15 @@ struct sequence_sort_impl using new_merged_values = decltype(MergedValues::push_back(number{})); using new_merged_ids = decltype(MergedIds::push_back(number{})); - using new_left_values = - typename conditional::type; + using new_left_values = typename std:: + conditional::type; using new_left_ids = - typename conditional::type; + typename std::conditional::type; - using new_right_values = typename conditional::type; + using new_right_values = typename std:: + conditional::type; using new_right_ids = - typename conditional::type; + typename std::conditional::type; using merge = sorted_sequence_merge_impl, sequence, 
Compare> { static constexpr bool choose_x = Compare{}(ValueX, ValueY); - using sorted_values = - typename conditional, sequence>::type; - using sorted_ids = typename conditional, sequence>::type; + using sorted_values = typename std:: + conditional, sequence>::type; + using sorted_ids = + typename std::conditional, sequence>::type; }; template @@ -606,14 +599,15 @@ struct sequence_unique_sort using new_remain_ids = decltype(RemainIds::pop_front()); using new_uniquified_values = - typename conditional{})), - UniquifiedValues>::type; + typename std::conditional{})), + UniquifiedValues>::type; using new_uniquified_ids = - typename conditional{})), - UniquifiedIds>::type; + typename std::conditional{})), + UniquifiedIds>::type; using uniquify = sorted_sequence_uniquify_impl -struct is_valid_sequence_map : is_same::type, - typename sequence_sort>::type> +struct is_valid_sequence_map + : std::is_same::type, + typename sequence_sort>::type> { }; @@ -906,7 +901,7 @@ constexpr auto prefix_sum_sequence(Seq) { return typename sequence_exclusive_scan, typename sequence_merge>::type, - math::plus>::type{}; + plus>::type{}; } template @@ -920,9 +915,9 @@ namespace detail { template struct pick_sequence_elements_by_mask_impl { - using new_work_seq = typename conditional::type; + using new_work_seq = typename std::conditional::type; using type = typename pick_sequence_elements_by_mask_impl, sequence> }; } // namespace detail +template +struct array; // declare for later use (array->seq utility) + // SeqSortedSamples: <0, 2, 3, 5, 7>, SeqRange: <0, 3, 6, 9> -> SeqHistogram : <2, 2, 1> template -constexpr auto histogram_sorted_sequence(SeqSortedSamples, sequence) +CK_TILE_HOST_DEVICE constexpr auto histogram_sorted_sequence(SeqSortedSamples, sequence) { constexpr auto bins = sizeof...(rs); // or categories constexpr auto histogram = [&]() { diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index a93ff0f42f..f95ddb5435 100644 --- 
a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -100,8 +100,8 @@ struct tuple : impl::tuple_base, T...> { bool flag = true; - static_for<0, sizeof...(Xs), 1>{}([&flag](auto i) { - flag &= is_static_v>>; + static_for<0, sizeof...(T), 1>{}([&flag](auto i) { + flag &= is_static_v>>; }); return flag; @@ -262,11 +262,11 @@ CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& element) } template -CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple& tuple) +CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple& t) { if constexpr(Depth == MaxDepth) { - return tuple; + return t; } else { @@ -274,33 +274,33 @@ CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple& tuple [&](auto&&... ts) { return concat_tuple(unroll_nested_tuple(ts)...); }, - tuple); + t); } } template -CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple& tuple) +CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple& t) { return generate_tuple( [&](auto i) { - using Idx = number::size()() - i - 1>; - return tuple.at(Idx{}); + using Idx = number::size() - i - 1>; + return t.at(Idx{}); }, number::size()()>{}); } // Reduce tuple values in specific range using Function template -CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple& tuple) +CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple& t) { static_assert(Idx < End, "Wrong parameters for tuple_reduce"); if constexpr(Idx + 1 == End) { - return tuple.at(number{}); + return t.at(number{}); } else { - return f(tuple.at(number{}), tuple_reduce(f, tuple)); + return f(t.at(number{}), tuple_reduce(f, t)); } } @@ -322,7 +322,7 @@ CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&) template CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple&) { - return math::max(tuple_depth(Ts{})...); + return max(tuple_depth(Ts{})...); } template @@ -456,6 +456,7 @@ CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple& x, 
const tuple< } // namespace ck_tile +#include // WARNING: needed by compiler for C++ structured binding support only, don't use this namespace std { @@ -465,7 +466,7 @@ struct tuple_size> : std::integral_constant -struct tuple_element> : ck_tile::tuple_element> +struct tuple_element> : std::tuple_element> { }; @@ -476,7 +477,7 @@ struct tuple_size> : std::integral_constant struct tuple_element> - : ck_tile::tuple_element> + : std::tuple_element> { }; diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index d69024883e..b19aefc928 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/numeric/arithmetic.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/limits.hpp" #include #pragma once @@ -32,7 +33,7 @@ struct alignas(2) bfloat16_t raw_type data; CK_TILE_HOST_DEVICE - static bfloat16_t bit_cast(raw_type x) + static constexpr bfloat16_t bit_cast(raw_type x) { bfloat16_t y; y.data = x; diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index c9af63dd43..8ff6f06a19 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core/numeric/arithmetic.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/limits.hpp" #include #include @@ -62,7 +63,7 @@ struct alignas(1) float8_e4m3_t raw_type data; CK_TILE_HOST_DEVICE - static float8_e4m3_t bit_cast(raw_type x) + static constexpr float8_e4m3_t bit_cast(raw_type x) { float8_e4m3_t y; y.data = x; @@ -70,37 +71,40 @@ struct alignas(1) float8_e4m3_t } // constructor - float8_e4m3_t() = default; + constexpr float8_e4m3_t() : data() {} // construct from float CK_TILE_HOST_DEVICE - explicit float8_e4m3_t(const float& x) { data 
= float_to_fp8_raw(x); } + explicit constexpr float8_e4m3_t(const float& x) { data = float_to_fp8_raw(x); } // construct from int CK_TILE_HOST_DEVICE - explicit float8_e4m3_t(const int& x) { data = float_to_fp8_raw(static_cast(x)); } + explicit constexpr float8_e4m3_t(const int& x) + { + data = float_to_fp8_raw(static_cast(x)); + } // construct from unsigned int CK_TILE_HOST_DEVICE - explicit float8_e4m3_t(const unsigned int& x) + explicit constexpr float8_e4m3_t(const unsigned int& x) { data = float_to_fp8_raw(static_cast(x)); } // cast to float CK_TILE_HOST_DEVICE - explicit operator float() const { return fp8_to_float_raw(data); } + explicit constexpr operator float() const { return fp8_to_float_raw(data); } // cast to int CK_TILE_HOST_DEVICE - explicit operator int() const { return static_cast(fp8_to_float_raw(data)); } + explicit constexpr operator int() const { return static_cast(fp8_to_float_raw(data)); } // internal access CK_TILE_HOST_DEVICE - raw_type& get() { return data; } + constexpr raw_type& get() { return data; } CK_TILE_HOST_DEVICE - raw_type get() const { return data; } + constexpr raw_type get() const { return data; } }; struct alignas(1) float8_e5m2_t @@ -116,7 +120,7 @@ struct alignas(1) float8_e5m2_t raw_type data; CK_TILE_HOST_DEVICE - static float8_e5m2_t bit_cast(raw_type x) + static constexpr float8_e5m2_t bit_cast(raw_type x) { float8_e5m2_t y; y.data = x; @@ -124,37 +128,40 @@ struct alignas(1) float8_e5m2_t } // constructor - float8_e5m2_t() = default; + constexpr float8_e5m2_t() : data() {} // construct from float CK_TILE_HOST_DEVICE - explicit float8_e5m2_t(const float& x) { data = float_to_bf8_raw(x); } + explicit constexpr float8_e5m2_t(const float& x) { data = float_to_bf8_raw(x); } // construct from int CK_TILE_HOST_DEVICE - explicit float8_e5m2_t(const int& x) { data = float_to_bf8_raw(static_cast(x)); } + explicit constexpr float8_e5m2_t(const int& x) + { + data = float_to_bf8_raw(static_cast(x)); + } // construct from unsigned 
int CK_TILE_HOST_DEVICE - explicit float8_e5m2_t(const unsigned int& x) + explicit constexpr float8_e5m2_t(const unsigned int& x) { data = float_to_bf8_raw(static_cast(x)); } // cast to float CK_TILE_HOST_DEVICE - explicit operator float() const { return bf8_to_float_raw(data); } + explicit constexpr operator float() const { return bf8_to_float_raw(data); } // cast to int CK_TILE_HOST_DEVICE - explicit operator int() const { return static_cast(bf8_to_float_raw(data)); } + explicit constexpr operator int() const { return static_cast(bf8_to_float_raw(data)); } // internal access CK_TILE_HOST_DEVICE - raw_type& get() { return data; } + constexpr raw_type& get() { return data; } CK_TILE_HOST_DEVICE - raw_type get() const { return data; } + constexpr raw_type get() const { return data; } }; // below is sw fp8 conversion, not utilizing hw instruction diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index 4a6fc59c28..02cf05a7d1 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -3,26 +3,29 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/limits.hpp" #include #pragma once namespace ck_tile { -CK_TILE_HOST_DEVICE -float fp16_to_float_hip(const _Float16& x); +using fp16_hip_t = __half; // most of hip internal function use this type CK_TILE_HOST_DEVICE -_Float16 float_to_fp16_hip(const float& x); +float fp16_to_float_hip(const fp16_hip_t& x); -// HIP use _Float16 as interchangable data type for float16 +CK_TILE_HOST_DEVICE +fp16_hip_t float_to_fp16_hip(const float& x); + +// HIP use fp16_hip_t as interchangable data type for float16 struct alignas(2) half_t { using raw_type = uint16_t; raw_type data; CK_TILE_HOST_DEVICE - static half_t bit_cast(raw_type x) + static constexpr half_t bit_cast(raw_type x) { half_t y; y.data = x; @@ -30,56 +33,62 @@ struct alignas(2) half_t } CK_TILE_HOST_DEVICE - _Float16 to_fp16() 
const { return reinterpret_cast(data); } + constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast(data); } // constructor - half_t() = default; + constexpr half_t() : data() {} // construct from HIP half CK_TILE_HOST_DEVICE - explicit half_t(const _Float16& x) : data(reinterpret_cast(x)) {} + explicit constexpr half_t(const fp16_hip_t& x) : data(ck_tile::bit_cast(x)) {} // construct from float CK_TILE_HOST_DEVICE - explicit half_t(const float& x) : half_t(float_to_fp16_hip(x)) {} + explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {} // construct from int CK_TILE_HOST_DEVICE - explicit half_t(const int& x) : half_t(__int2half_rn(x)) {} + explicit constexpr half_t(const int& x) : half_t(static_cast(__int2half_rn(x))) {} // construct from unsigned int CK_TILE_HOST_DEVICE - explicit half_t(const unsigned int& x) : half_t(__uint2half_rn(x)) {} + explicit constexpr half_t(const unsigned int& x) + : half_t(static_cast(__uint2half_rn(x))) + { + } // cast to float CK_TILE_HOST_DEVICE - explicit operator float() const { return fp16_to_float_hip(to_fp16()); } + explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); } // cast to int CK_TILE_HOST_DEVICE - explicit operator int() const { return static_cast(fp16_to_float_hip(to_fp16())); } + explicit constexpr operator int() const + { + return static_cast(fp16_to_float_hip(to_fp16())); + } // internal access CK_TILE_HOST_DEVICE - raw_type& get() { return data; } + constexpr raw_type& get() { return data; } CK_TILE_HOST_DEVICE - raw_type get() const { return data; } + constexpr raw_type get() const { return data; } }; // conversions CK_TILE_HOST_DEVICE -float fp16_to_float_hip(const _Float16& x) +float fp16_to_float_hip(const fp16_hip_t& x) { // return __half2float(x); return static_cast(x); } CK_TILE_HOST_DEVICE -_Float16 float_to_fp16_hip(const float& x) +fp16_hip_t float_to_fp16_hip(const float& x) { // return __float2half(x); - return static_cast<_Float16>(x); + return 
static_cast(x); } CK_TILE_HOST_DEVICE diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 98b3775285..aa7c96b6e6 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" #include #include @@ -290,7 +291,7 @@ float abs(const float& x) CK_TILE_HOST_DEVICE bool isnan(const float& x) { - uint32_t xx = reinterpret_cast(x); + uint32_t xx = bit_cast(x); return (xx & 0x7fffffff) > 0x7F800000; } diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index bf75f9bffc..eca5efdffb 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -82,10 +82,10 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE constexpr auto get(index_t i, bool is_valid_element, bool_constant = {}) const { @@ -99,7 +99,7 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) { if constexpr(Op == InMemoryDataOperationEnum::set) @@ -144,9 +144,9 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T @@ -159,7 +159,7 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool 
oob_conditional_check = true, + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE constexpr auto get(index_t i, bool is_valid_element, bool_constant = {}) const { @@ -268,7 +268,7 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t& dst, index_t i, bool is_valid_element) const { @@ -348,9 +348,9 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE constexpr auto async_get(remove_cvref_t* smem, index_t i, bool /*is_valid_element*/) const { @@ -370,9 +370,9 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) { if constexpr(Op == InMemoryDataOperationEnum::set) @@ -399,10 +399,10 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T @@ -413,7 +413,7 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x) { // X contains multiple T @@ -463,9 +463,9 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = 
false> + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x) { using scalar_t = typename scalar_type>::type; @@ -480,14 +480,14 @@ struct buffer_view, int32_t> || is_same_v, float> || (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); -#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) +#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; -#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT +#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT bool constexpr use_amd_buffer_addressing = is_same_v, float> || (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); @@ -512,9 +512,9 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x) { // X contains multiple T @@ -527,7 +527,7 @@ struct buffer_view>::type; bool constexpr use_amd_buffer_addressing = is_same_v, double>; #else @@ -628,10 +628,10 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE constexpr auto get(index_t i, bool is_valid_element, bool_constant = {}) const { @@ -645,7 +645,7 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) { if constexpr(Op == InMemoryDataOperationEnum::set) @@ 
-690,9 +690,9 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T @@ -703,7 +703,7 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE constexpr auto get(index_t i, bool is_valid_element, bool_constant = {}) const { @@ -916,7 +916,7 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) { if constexpr(Op == InMemoryDataOperationEnum::set) @@ -961,9 +961,9 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T @@ -976,7 +976,7 @@ struct buffer_view{p, buffer_size}; } -template < - address_space_enum BufferAddressSpace, - amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default, - typename T, - typename BufferSizeType, - typename X, - typename enable_if, remove_cvref_t>::value, bool>::type = false> +template , remove_cvref_t>::value, + bool>::type = false> CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value) { @@ -1038,4 +1038,4 @@ make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value) p, buffer_size, invalid_element_value}; } -} // namespace ck_tile \ No newline at end of file +} // namespace ck_tile diff --git 
a/include/ck_tile/core/tensor/slice_tile.hpp b/include/ck_tile/core/tensor/slice_tile.hpp index 54d937a8d0..35ef4ac405 100644 --- a/include/ck_tile/core/tensor/slice_tile.hpp +++ b/include/ck_tile/core/tensor/slice_tile.hpp @@ -14,7 +14,6 @@ #include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { -namespace tile_program { template {}, number{}); constexpr auto up_dim_numbers_scan = merge_sequences( - Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus{}, number<0>{})); + sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus{}, number<0>{})); constexpr auto up_dim_hidden_idss = generate_tuple( [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { @@ -510,7 +511,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a // shift constexpr index_t adaptor0_max_hidden_id = [&]() { - index_t adaptor0_max_hidden_id_ = NumericLimits::Min(); + index_t adaptor0_max_hidden_id_ = numeric_limits::min(); static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) { constexpr index_t ndim_low = @@ -536,7 +537,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }(); constexpr index_t adaptor1_min_hidden_id = [&]() { - index_t adaptor1_min_hidden_id_ = NumericLimits::Max(); + index_t adaptor1_min_hidden_id_ = numeric_limits::max(); static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) { constexpr index_t ndim_low = @@ -680,7 +681,9 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a remove_cvref_t>{all_transforms}; } -template = 2, bool>::type = false> +template = 2, bool>::type = false> CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... 
xs) { return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); diff --git a/include/ck_tile/core/tensor/tensor_descriptor.hpp b/include/ck_tile/core/tensor/tensor_descriptor.hpp index 697988de10..0ff9210e5f 100644 --- a/include/ck_tile/core/tensor/tensor_descriptor.hpp +++ b/include/ck_tile/core/tensor/tensor_descriptor.hpp @@ -363,7 +363,7 @@ make_naive_tensor_descriptor_packed(const tuple& lengths, constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; - const auto element_space_size = container_reduce(lengths, math::multiplies{}, long_number<1>{}); + const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{}); using GuaranteedVectorLengths = typename sequence_merge::type, @@ -392,8 +392,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offs number = number<-1>{}) { const auto desc_0 = [&]() { - const auto element_space_size = - container_reduce(lengths, math::multiplies{}, long_number<1>{}); + const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{}); const auto transforms = make_tuple(make_offset_transform(element_space_size, offset)); @@ -442,7 +441,7 @@ make_naive_tensor_descriptor_aligned(const tuple& lengths, Align ali constexpr index_t N = sizeof...(Lengths); - const auto stride_n_minus_2 = math::integer_least_multiple(lengths[number{}], align); + const auto stride_n_minus_2 = integer_least_multiple(lengths[number{}], align); auto strides = generate_tuple( [&](auto i) { @@ -456,12 +455,8 @@ make_naive_tensor_descriptor_aligned(const tuple& lengths, Align ali } else { - return container_reduce(lengths, - math::multiplies{}, - number{}, - i + I1, - number{}, - I1); + return container_reduce( + lengths, multiplies{}, number{}, i + I1, number{}, I1); } }, number{}); diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 3309b4b442..9a10ca3af3 100644 --- 
a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -58,11 +58,12 @@ struct tensor_view #endif // X is vector of DataType. // "coord" is coordinate of DataType, not X. "coord" should be aligned to X - template >::type, - typename scalar_type>::type>, - bool>::type = false> + template < + typename X, + bool oob_conditional_check = true, + typename std::enable_if>::type, + typename scalar_type>::type>, + bool>::type = false> CK_TILE_HOST_DEVICE constexpr remove_cvref_t get_vectorized_elements(const TensorCoord& coord, bool_constant = {}) const @@ -75,11 +76,12 @@ struct tensor_view // X is vector of DataType. // "coord" is coordinate of DataType, not X. "coord" should be aligned to X - template >::type, - typename scalar_type>::type>, - bool>::type = false> + template < + typename X, + bool oob_conditional_check = true, + typename std::enable_if>::type, + typename scalar_type>::type>, + bool>::type = false> CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t& dst, const TensorCoord& coord, @@ -91,10 +93,11 @@ struct tensor_view coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); } - template >::type, - typename scalar_type>::type>, - bool>::type = false> + template < + typename X, + typename std::enable_if>::type, + typename scalar_type>::type>, + bool>::type = false> CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t* smem, const TensorCoord& coord) const { @@ -103,11 +106,12 @@ struct tensor_view // X is vector of DataType. // "coord" is coordinate of DataType, not X. 
"coord" should be aligned to X - template >::type, - typename scalar_type>::type>, - bool>::type = false> + template < + typename X, + bool oob_conditional_check = true, + typename std::enable_if>::type, + typename scalar_type>::type>, + bool>::type = false> CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements( const TensorCoord& coord, const X& x, bool_constant = {}) { @@ -117,11 +121,12 @@ struct tensor_view x); } - template >::type, - typename scalar_type>::type>, - bool>::type = false> + template < + typename X, + bool oob_conditional_check = true, + typename std::enable_if>::type, + typename scalar_type>::type>, + bool>::type = false> CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw( const TensorCoord& coord, const X& x, bool_constant = {}) { @@ -172,9 +177,9 @@ template ::type = false> + index_t GuaranteedLastDimensionVectorLength = -1, + index_t GuaranteedLastDimensionVectorStride = -1, + typename std::enable_if::type = false> CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType* p, const tuple& lengths, @@ -245,8 +250,7 @@ pad_tensor_view(const tensor_view& tensor_view, const TileLengths& tile_lengths, const auto tile_length = tile_lengths[idim]; - const auto new_length = - math::integer_divide_ceil(old_length, tile_length) * tile_length; + const auto new_length = integer_divide_ceil(old_length, tile_length) * tile_length; const auto pad_length = new_length - old_length; diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index c891e9a608..8a1a2e4810 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -89,7 +89,7 @@ struct tile_distribution return generate_tuple( [&](auto i) { constexpr index_t x_length = - container_reduce(typename DstrEncode::HsLengthss{}[i], math::multiplies{}, 1); + container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1); return number{}; }, @@ -530,7 
+530,7 @@ struct reverse_slice_sequence_impl, static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value; static constexpr auto slice_length = - std::conditional_t, number>::value; + std::conditional_t, number>::value; using dim_lengths = typename sequence_merge, typename old_scan::dim_lengths>::type; @@ -557,7 +557,7 @@ struct reverse_slice_sequence_impl, sequence, sequence, Slice { static constexpr auto slice_size = SliceSize; static constexpr auto slice_length = - std::conditional_t, number>::value; + std::conditional_t, number>::value; using dim_lengths = sequence; using dim_slices = sequence; diff --git a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp index f2f3707e6c..7b1e952025 100644 --- a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp +++ b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp @@ -71,7 +71,7 @@ struct tile_distribution_encoding // max_ndim_rh_minor_ static constexpr index_t max_ndim_rh_minor_ = - container_reduce(ndims_rhs_minor_, math::maximize{}, 0); + container_reduce(ndims_rhs_minor_, maximize{}, 0); // rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_] static constexpr auto rhs_lengthss_ = @@ -122,7 +122,7 @@ struct tile_distribution_encoding // max_ndim_span_minor_ static constexpr index_t max_ndim_span_minor_ = - container_reduce(ndims_span_minor_, math::maximize{}, 0); + container_reduce(ndims_span_minor_, maximize{}, 0); // rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_] static constexpr auto rhs_major_minor_to_span_minor_ = [] { @@ -293,8 +293,7 @@ struct tile_distribution_encoding template CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq) { - using sorted_idx = - sequence_unique_sort, math::equal>; + using sorted_idx = sequence_unique_sort, equal>; constexpr auto sorted_dims = typename sorted_idx::type{}; constexpr auto sorted_maps = typename 
sorted_idx::sorted2unsorted_map{}; diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index 7bbc61cef1..c246c9c456 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -4,9 +4,9 @@ #pragma once #include "ck_tile/core/config.hpp" -#include "ck_tile/core/container/sequence.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/container/sequence.hpp" #include #include diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index 9e1a7aa4c9..7cdc3d2a28 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -43,4 +43,30 @@ using is_known_at_compile_time = is_static; // , this helper will also return false, which is not good(?) // do we need something like is_constexpr()? +namespace detail { +template class Op, class... Args> +struct detector +{ + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> +{ + using value_t = std::true_type; + using type = Op; +}; +} // namespace detail + +struct nonesuch +{ + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + void operator=(nonesuch const&) = delete; +}; + +template