This commit is contained in:
carlushuang
2024-03-03 23:48:31 +00:00
parent fbd25cea35
commit 112d521b09
66 changed files with 1720 additions and 1498 deletions

View File

@@ -72,7 +72,7 @@ auto get_elimit(int /*init_method*/)
}
template <>
auto get_elimit<ck_tile::bhalf_t>(int init_method)
auto get_elimit<ck_tile::bf16_t>(int init_method)
{
if(init_method == 0)
{
@@ -510,7 +510,7 @@ int main(int argc, char* argv[])
}
else if(data_type == "bf16")
{
return run<ck_tile::bhalf_t>(arg_parser) ? 0 : -2;
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp8")
{

View File

@@ -7,7 +7,6 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/common.hpp"
#include "mask.hpp"
template <typename DataType>
@@ -29,18 +28,18 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
};
template <>
struct FmhaFwdTypeConfig<ck_tile::bhalf_t>
struct FmhaFwdTypeConfig<ck_tile::bf16_t>
{
using QDataType = ck_tile::bhalf_t;
using KDataType = ck_tile::bhalf_t;
using VDataType = ck_tile::bhalf_t;
using BiasDataType = ck_tile::bhalf_t;
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::bhalf_t; // data type for A matrix of second gemm
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bhalf_t;
using ODataType = ck_tile::bf16_t;
};
template <>

View File

@@ -11,7 +11,7 @@ import copy
DTYPE_MAP = {
"fp16": "ck_tile::half_t",
"bf16": "ck_tile::bhalf_t",
"bf16": "ck_tile::bf16_t",
"fp8" : "ck_tile::fp8_t"
}

View File

@@ -6,8 +6,9 @@
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_address_space.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/container_helper.hpp"
@@ -51,6 +52,6 @@
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/type_convert.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

View File

@@ -1,20 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
namespace ck_tile {
enum struct address_space_enum
{
generic,
global,
lds,
sgpr,
vgpr,
};
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,61 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
namespace ck_tile {
// Address spaces distinguished by ck_tile when building buffer/LDS accesses.
// Meanings follow the LLVM AMDGPUUsage address-space table linked above.
enum struct address_space_enum
{
generic,
global, // off-chip global memory
lds,    // local data share (workgroup-shared memory)
sgpr,   // scalar registers
vgpr,   // vector registers
};
// Update operation a store applies at its destination address.
// NOTE(review): the distinction between 'atomic_add' and plain 'add'
// (presumably non-atomic accumulate) is not visible here — confirm at
// the use sites before relying on it.
enum struct memory_operation_enum
{
set,
atomic_add,
atomic_max,
add
};
// Number of lanes in a wavefront. The value is target-dependent
// (per HIP documentation: 64 on CDNA/GCN, 32 on RDNA), so callers
// must not hard-code it.
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
// warpSize is defined by HIP
return warpSize;
}
// 1D launch-geometry queries — only the x dimension is consulted,
// so these assume a 1D grid/block configuration.
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }   // blocks per grid
CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; } // threads per block
// TODO: deprecate these
CK_TILE_DEVICE index_t get_thread_local_1d_id() { return threadIdx.x; }
CK_TILE_DEVICE index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }
CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
// Use these instead
// Lane index of the calling thread within its wavefront.
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
// Wavefront index within the block. readfirstlane broadcasts lane 0's
// value, making the result wavefront-uniform (held in a scalar register).
CK_TILE_DEVICE index_t get_warp_id()
{
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
}
CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } // thread index within block
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }   // block index within grid
} // namespace ck_tile

View File

@@ -0,0 +1,27 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
namespace ck_tile {
// TODO: we have "memory" clobber here because this inline asm is used for async copy
// Writes v into the M0 hardware register. The "s" constraint requires v to
// live in an SGPR, i.e. v must be wavefront-uniform. The "memory" clobber
// prevents the compiler from reordering memory accesses across this write.
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
{
asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory");
}
// NOTE: this is an immediate value
// Adds v to the M0 register. The "n" constraint forces an immediate
// operand, so v must be a compile-time constant; passing a runtime value
// will fail to compile.
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
{
asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
}
} // namespace ck_tile

View File

@@ -120,3 +120,15 @@
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
// CK_TILE_BUFFER_RESOURCE_3RD_DWORD: target-specific value for the 3rd
// dword of the buffer resource descriptor used by buffer addressing.
// -1 is the host-side placeholder (host code never builds a descriptor);
// device compilation selects the per-architecture constant below.
// NOTE(review): the magic constants encode descriptor fields per the ISA
// docs for each GPU family — verify against the AMD ISA reference when
// adding a new target.
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD -1
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif

View File

@@ -44,6 +44,19 @@ struct array
data[i] = vlast;
}
}
CK_TILE_HOST_DEVICE explicit constexpr array(value_type c)
{
for(auto i = 0; i < size(); i++)
data[i] = c;
}
template <typename ArrayType>
CK_TILE_HOST_DEVICE constexpr array(const ArrayType& o)
{
static_assert(ArrayType::size() == size(), "wrong! size not the same");
for(auto i = 0; i < size(); i++)
data[i] = o.data[i];
}
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
@@ -67,18 +80,18 @@ struct array
CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return data[i]; }
CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return data[i]; }
CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return data[i]; } // TODO: compatible
template <typename T>
CK_TILE_HOST_DEVICE constexpr auto operator=(const T& a)
#if 0
template <typename ArrayType>
CK_TILE_HOST_DEVICE constexpr auto operator=(const ArrayType& a)
{
static_assert(T::size() == size(), "wrong! size not the same");
static_assert(ArrayType::size() == size(), "wrong! size not the same");
for(index_t i = 0; i < size(); ++i)
{
data[i] = a[i];
}
return *this;
}
#endif
// type punning (strict aliasing) member functions for read/write
// aliasing this array of type "T", "N" elements
// as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements
@@ -122,6 +135,17 @@ struct array<T, 0>
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
template <typename>
struct vector_traits;
// specialization for array
template <typename T, index_t N>
struct vector_traits<array<T, N>>
{
using scalar_type = T;
static constexpr index_t vector_size = N;
};
template <typename T, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs)
{

View File

@@ -468,6 +468,7 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
number<Seq::size()>{});
}
#if 0
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, a_size, bs_sizes] { \
return ck_tile::generate_tuple( \
@@ -479,5 +480,21 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
}, \
ck_tile::number<a_size>{}); \
}()
#else
// constexpr index_t can't be captured "-Wunused-lambda-capture"
// TODO: this is ugly
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
constexpr index_t b_size = bs_sizes[i]; \
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
return b; \
}, \
ck_tile::number<a_size>{}); \
}()
#endif
} // namespace ck_tile

View File

@@ -67,13 +67,12 @@ struct sequence
CK_TILE_HOST_DEVICE static constexpr auto get(number<I>)
{
static_assert(I < size(), "wrong! I too large");
return number<get(I)>{};
return number<get<I>()>{};
}
CK_TILE_HOST_DEVICE static constexpr index_t at(index_t I)
{
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
static_assert(I < size(), "wrong! I too large");
const index_t mData[size() + 1] = {Is..., 0};
return mData[I];
}
@@ -89,7 +88,7 @@ struct sequence
CK_TILE_HOST_DEVICE static constexpr auto at(number<I>)
{
static_assert(I < size(), "wrong! I too large");
return number<get(I)>{};
return number<get<I>()>{};
}
template <typename I>

View File

@@ -16,40 +16,48 @@ namespace ck_tile {
namespace impl {
// the place where content is stored
template <index_t idx, typename T, bool is_empty = std::is_empty_v<T>>
struct tuple_element
struct tuple_object
{
};
template <index_t idx, typename T>
struct tuple_element<idx, T, true>
struct tuple_object<idx, T, true>
{
CK_TILE_HOST_DEVICE constexpr tuple_element() {}
CK_TILE_HOST_DEVICE constexpr tuple_element(const T&) {}
CK_TILE_HOST_DEVICE constexpr tuple_object() {}
CK_TILE_HOST_DEVICE constexpr tuple_object(const T&) {}
};
template <index_t idx, typename T>
struct tuple_element<idx, T, false>
struct tuple_object<idx, T, false>
{
CK_TILE_HOST_DEVICE constexpr tuple_element() {}
CK_TILE_HOST_DEVICE constexpr tuple_element(const T& e) : element(e) {}
CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {}
CK_TILE_HOST_DEVICE constexpr tuple_object(const T& e) : element(e) {}
T element;
};
// NOTE: we return a instance(not a reference) if content is empty
template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr T const& getv(tuple_element<I, T, false> const& x)
CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object<I, T, true>&)
{
return {};
}
template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object<I, T, false>& x)
{
return x.element;
}
template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_element<I, T, false>& x)
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object<I, T, false>& x)
{
return x.element;
}
template <std::size_t I, class T>
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_element<I, T, false>&& x)
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object<I, T, false>&& x)
{
return static_cast<T&&>(x.element);
}
@@ -58,18 +66,18 @@ template <typename index_seq, typename... T>
struct tuple_base;
template <index_t... I, typename... T>
struct tuple_base<sequence<I...>, T...> : public tuple_element<I, T>...
struct tuple_base<sequence<I...>, T...> : tuple_object<I, T>...
{
CK_TILE_HOST_DEVICE constexpr tuple_base() {}
template <class... U>
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U const&... u) : tuple_element<I, T>(u)...
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object<I, T>(u)...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...> const& u)
: tuple_element<I, T>(getv(static_cast<tuple_element<I, U> const&>(u)))...
CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base<sequence<I...>, U...>& u)
: tuple_object<I, T>(getv(static_cast<const tuple_object<I, U>&>(u)))...
{
}
};
@@ -84,15 +92,13 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
CK_TILE_HOST_DEVICE constexpr tuple() {}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(U const&... u)
: impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>(u...)
CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...)
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...> const& u)
: impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>(
static_cast<impl::tuple_base<U...> const&>(u))
CK_TILE_HOST_DEVICE constexpr tuple(const tuple<U...>& u)
: base(static_cast<const impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
{
}
@@ -109,19 +115,19 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
// clang-format off
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & get() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & get(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & get() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & get(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & at() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & at(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & at() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & at(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & operator[](number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & operator[](number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
// clang-format on
#undef TP_COM_
};
@@ -250,15 +256,15 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
// By default unroll to the flatten
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& element)
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
{
return element;
return t;
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& element)
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& t)
{
return make_tuple(element);
return make_tuple(t);
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
@@ -334,7 +340,7 @@ CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
index_t max_n1_ = 0;
static_for<0, n0, 1>{}([&](auto i0) {
constexpr index_t n1 = t_of_s[i0].size()();
constexpr index_t n1 = t_of_s[i0].size();
max_n1_ = max_n1_ < n1 ? n1 : max_n1_;
});
@@ -345,7 +351,7 @@ CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
array<array<index_t, max_n1>, n0> a_of_a{{-1}};
static_for<0, n0, 1>{}([&](auto i0) {
constexpr index_t n1 = t_of_s[i0].size()();
constexpr index_t n1 = t_of_s[i0].size();
static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; });
});
@@ -482,3 +488,60 @@ struct tuple_element<I, const ck_tile::tuple<Ts...>>
};
} // namespace std
#if 1
#define TO_TUPLE_OF_NUMBER(a, n) \
_Pragma("clang diagnostic push") \
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
} \
(ck_tile::make_index_sequence<n>{}) \
_Pragma("clang diagnostic pop")
#else
#define TO_TUPLE_OF_NUMBER(arr, n_) \
[&arr, n_] { \
static_assert(arr.size() >= n_, "wrong! out of bound"); \
\
static_assert(n_ < 7, "not implemented"); \
\
if constexpr(n_ == 0) \
{ \
return ck_tile::tuple<>{}; \
} \
else if constexpr(n_ == 1) \
{ \
return ck_tile::tuple<number<arr[0]>>{}; \
} \
else if constexpr(n_ == 2) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
} \
else if constexpr(n_ == 3) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
} \
else if constexpr(n_ == 4) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
} \
else if constexpr(n_ == 5) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>>{}; \
} \
else if constexpr(n_ == 6) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>, \
number<arr[5]>>{}; \
} \
}()
#endif

View File

@@ -20,7 +20,8 @@ enum class bf16_rounding_mode
truncate,
};
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
CK_TILE_HOST_DEVICE
@@ -41,36 +42,42 @@ struct alignas(2) bfloat16_t
}
// constructor
bfloat16_t() = default;
constexpr bfloat16_t() : data() {}
// construct from float
CK_TILE_HOST_DEVICE
explicit bfloat16_t(const float& x) { data = float_to_bf16_raw(x); }
explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit bfloat16_t(const int& x) { data = float_to_bf16_raw(static_cast<float>(x)); }
explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit bfloat16_t(const unsigned int& x) { data = float_to_bf16_raw(static_cast<float>(x)); }
explicit constexpr bfloat16_t(const unsigned int& x)
: data(float_to_bf16_raw(static_cast<float>(x)))
{
}
// cast to float
CK_TILE_HOST_DEVICE
explicit operator float() const { return bf16_to_float_raw(data); }
explicit constexpr operator float() const { return bf16_to_float_raw(data); }
// cast to int
CK_TILE_HOST_DEVICE
explicit operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
// internal access
CK_TILE_HOST_DEVICE
raw_type& get() { return data; }
constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE
raw_type get() const { return data; }
constexpr raw_type get() const { return data; }
};
using bf16_t = bfloat16_t;
using bf16_raw_t = typename bf16_t::raw_type;
// round to nearest
CK_TILE_HOST_DEVICE
uint16_t float_to_bf16_rtn_raw(float f)
@@ -139,8 +146,8 @@ uint16_t float_to_bf16_truc_raw(float f)
return uint16_t(u.int32 >> 16);
}
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {})
template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
{
if constexpr(rounding == bf16_rounding_mode::standard)
return float_to_bf16_rtn_raw(f);
@@ -161,8 +168,9 @@ float bf16_to_float_raw(uint16_t x)
return u.fp32;
}
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding> = {})
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
{
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
}
@@ -170,14 +178,15 @@ CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding> = {})
CK_TILE_HOST_DEVICE
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
{
return bfloat16_t::bit_cast(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
float bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast<float>(x)); }
half_t bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast<float>(x)); }
template <class T>
struct numeric_limits;
@@ -259,6 +268,4 @@ bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast
CK_TILE_DEVICE
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
using bf16_t = bfloat16_t;
} // namespace ck_tile

View File

@@ -7,6 +7,7 @@
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/limits.hpp"
#include <stdint.h>
@@ -75,20 +76,19 @@ struct alignas(1) float8_e4m3_t
// construct from float
CK_TILE_HOST_DEVICE
explicit constexpr float8_e4m3_t(const float& x) { data = float_to_fp8_raw(x); }
explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr float8_e4m3_t(const int& x)
explicit constexpr float8_e4m3_t(const int& x) : data(float_to_fp8_raw(static_cast<float>(x)))
{
data = float_to_fp8_raw(static_cast<float>(x));
}
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit constexpr float8_e4m3_t(const unsigned int& x)
: data(float_to_fp8_raw(static_cast<float>(x)))
{
data = float_to_fp8_raw(static_cast<float>(x));
}
// cast to float
@@ -106,6 +106,8 @@ struct alignas(1) float8_e4m3_t
CK_TILE_HOST_DEVICE
constexpr raw_type get() const { return data; }
};
using fp8_t = float8_e4m3_t;
using fp8_raw_t = typename fp8_t::raw_type;
struct alignas(1) float8_e5m2_t
{
@@ -132,25 +134,24 @@ struct alignas(1) float8_e5m2_t
// construct from float
CK_TILE_HOST_DEVICE
explicit constexpr float8_e5m2_t(const float& x) { data = float_to_bf8_raw(x); }
explicit constexpr float8_e5m2_t(const float& x) : data(float_to_bf8_raw(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr float8_e5m2_t(const int& x)
explicit constexpr float8_e5m2_t(const int& x) : data(float_to_bf8_raw(static_cast<float>(x)))
{
data = float_to_bf8_raw(static_cast<float>(x));
}
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit constexpr float8_e5m2_t(const unsigned int& x)
: data(float_to_bf8_raw(static_cast<float>(x)))
{
data = float_to_bf8_raw(static_cast<float>(x));
}
// cast to float
CK_TILE_HOST_DEVICE
explicit constexpr constexpr operator float() const { return bf8_to_float_raw(data); }
explicit constexpr operator float() const { return bf8_to_float_raw(data); }
// cast to int
CK_TILE_HOST_DEVICE
@@ -163,6 +164,8 @@ struct alignas(1) float8_e5m2_t
CK_TILE_HOST_DEVICE
constexpr raw_type get() const { return data; }
};
using bf8_t = float8_e5m2_t;
using bf8_raw_t = typename bf8_t::raw_type;
// below is sw fp8 conversion, not utilizing hw instruction
namespace impl {
@@ -431,10 +434,10 @@ CK_TILE_HOST_DEVICE Y cast_from_f8(X x)
}
} // namespace impl
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_sr_raw(float x)
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
@@ -453,16 +456,18 @@ CK_TILE_HOST_DEVICE uint8_t float_to_fp8_sr_raw(float x)
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
return impl::
cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
x, rng);
return bit_cast<fp8_raw_t>(impl::cast_to_f8<float,
fp8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_sr_raw(float x)
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
@@ -479,13 +484,15 @@ CK_TILE_HOST_DEVICE uint8_t float_to_bf8_sr_raw(float x)
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
return impl::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
x, rng);
return bit_cast<bf8_raw_t>(impl::cast_to_f8<float,
bf8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_rtn_raw(float x)
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
@@ -506,12 +513,14 @@ CK_TILE_HOST_DEVICE uint8_t float_to_fp8_rtn_raw(float x)
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return impl::
cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
x, rng);
return bit_cast<fp8_raw_t>(impl::cast_to_f8<float,
fp8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_rtn_raw(float x)
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
@@ -530,30 +539,32 @@ CK_TILE_HOST_DEVICE uint8_t float_to_bf8_rtn_raw(float x)
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return impl::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
x, rng);
return bit_cast<bf8_raw_t>(impl::cast_to_f8<float,
bf8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}
// clang-format off
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float x, constant<rounding> = {})
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
{
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x);
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
else return uint8_t{0};
else return fp8_raw_t{0};
}
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float x, constant<rounding> = {})
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
{
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x);
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
else return uint8_t{0};
else return bf8_raw_t{0};
}
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t x)
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
@@ -563,11 +574,11 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t x)
return fval;
#else
constexpr bool negative_zero_nan = true;
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(x);
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(fp8_t::bit_cast(x));
#endif
}
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t x)
CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
@@ -577,18 +588,18 @@ CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t x)
return fval;
#else
constexpr bool negative_zero_nan = true;
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(bf8_t::bit_cast(x));
#endif
}
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant<rounding> = {})
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant<rounding>)
{
return float8_e4m3_t::bit_cast(float_to_fp8_raw(x, constant<rounding>{}));
}
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant<rounding> = {})
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant<rounding>)
{
return float8_e5m2_t::bit_cast(float_to_bf8_raw(x, constant<rounding>{}));
}
@@ -604,8 +615,6 @@ CK_TILE_HOST_DEVICE float bf8_to_float(float8_e5m2_t x)
}
// clang-format on
using fp8_t = float8_e4m3_t;
using bf8_t = float8_e5m2_t;
template <typename T>
struct numeric_utils;

View File

@@ -76,6 +76,9 @@ struct alignas(2) half_t
constexpr raw_type get() const { return data; }
};
using fp16_t = half_t;
using fp16_raw_t = typename fp16_t::raw_type;
// conversions
CK_TILE_HOST_DEVICE
float fp16_to_float_hip(const fp16_hip_t& x)
@@ -282,6 +285,4 @@ half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x)))
CK_TILE_DEVICE
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
using fp16_t = half_t;
} // namespace ck_tile

View File

@@ -147,6 +147,9 @@ CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T&
return min(max(x, lowerbound), upperbound);
}
CK_TILE_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
CK_TILE_DEVICE inline int clz(uint32_t x) { return __clz(x); }
// greatest common divisor, aka highest common factor
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
{
@@ -222,7 +225,7 @@ struct less
CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x)
{
// TODO: x need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
return 1 << (32 - __builtin_clz(x - 1));
return 1 << (32 - clz(x - 1));
}
template <index_t X>
@@ -243,7 +246,7 @@ CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
{
// TODO: x need to be 1 ~ 0x7fffffff
// __builtin_clz will produce unexpected result if x is 0;
return 31 - __builtin_clz(x);
return 31 - clz(x);
}
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)

View File

@@ -3,10 +3,22 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include <stdint.h>
#include <tuple>
#include <type_traits>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
namespace ck_tile {
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<Y> type_convert(const X& x)
{
return static_cast<Y>(x);
}
#if 0
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
@@ -15,11 +27,9 @@ template <typename Y,
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
// TODO: const version never called, we may never need
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
@@ -28,18 +38,29 @@ CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
using NonConstY = std::remove_const_t<Y>;
using NonConstX = std::remove_const_t<X>;
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
using non_const_y = std::remove_const_t<Y>;
using non_const_x = std::remove_const_t<X>;
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
}
#else
// compatible way to call conversion operator and constructor of each custom data type
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
template <> \
inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return stype_##_to_##dtype_(x); \
}
CK_TILE_TYPE_CONVERT(float, fp16_t)
CK_TILE_TYPE_CONVERT(float, bf16_t)
CK_TILE_TYPE_CONVERT(float, fp8_t)
CK_TILE_TYPE_CONVERT(float, bf8_t)
CK_TILE_TYPE_CONVERT(fp16_t, float)
CK_TILE_TYPE_CONVERT(bf16_t, float)
CK_TILE_TYPE_CONVERT(fp8_t, float)
CK_TILE_TYPE_CONVERT(bf8_t, float)
#undef CK_TILE_TYPE_CONVERT
#endif
} // namespace ck_tile

View File

@@ -10,295 +10,112 @@
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// TODO: the whole content of this file should consider deprecated!
// we name this as ext_vector purposely, because clang ext_vector_type extention only accept literay
// basic type to construct a ext_vector_type you must be very careful using this, or will have lot
// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
// have compiler error
namespace impl {
template <typename T_, index_t N_>
struct vector_type
struct ext_vector
{
static constexpr index_t N = N_;
using value_type = T_;
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
CK_HOST_DEVICE constexpr vector_type()
{
for(auto i = 0; i < N; i++)
data[i] = static_cast<value_type>(0);
}
CK_HOST_DEVICE constexpr vector_type(type v)
{
auto& r = reinterpret_cast<const array<value_type, N>&>(v);
for(auto i = 0; i < N; i++)
data[i] = r.get(i);
}
value_type data[N];
CK_HOST_DEVICE static constexpr auto size() { return N; }
CK_HOST_DEVICE auto& get() { return data; }
CK_HOST_DEVICE const auto& get() const { return data; }
CK_HOST_DEVICE auto& get(index_t i) { return data[i]; }
CK_HOST_DEVICE const auto& get(index_t i) const { return data[i]; }
template <index_t I>
CK_HOST_DEVICE auto& operator[](number<I>)
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE const auto& operator[](number<I>) const
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE auto& operator()(number<I>)
{
return data[I];
}
CK_HOST_DEVICE auto& at(index_t i) { return data[i]; }
CK_HOST_DEVICE const auto& at(index_t i) const { return data[i]; }
template <index_t I>
CK_HOST_DEVICE auto& at()
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE const auto& at() const
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE auto& at(number<I>)
{
return data[I];
}
template <index_t I>
CK_HOST_DEVICE const auto& at(number<I>) const
{
return data[I];
}
#define _VT_COMMON_AS() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
template <typename Tx>
CK_HOST_DEVICE auto& get_as()
{
_VT_COMMON_AS();
return reinterpret_cast<array<Tx, vx>&>(data);
}
template <typename Tx>
CK_HOST_DEVICE const auto& get_as() const
{
_VT_COMMON_AS();
return reinterpret_cast<const array<Tx, vx>&>(data);
}
template <typename Tx>
CK_HOST_DEVICE auto& get_as(index_t i)
{
_VT_COMMON_AS();
return reinterpret_cast<array<Tx, vx>&>(data).get(i);
}
template <typename Tx>
CK_HOST_DEVICE const auto& get_as(index_t i) const
{
_VT_COMMON_AS();
return reinterpret_cast<const array<Tx, vx>&>(data).get(i);
}
#undef _VT_COMMON_AS
};
} // namespace impl
template <typename T, index_t N>
struct vector_type_maker
using ext_vector_t = typename impl::ext_vector<T, N>::type;
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
// ... unless we have other vector_traits specialization
template <typename T>
struct vector_traits
{
using type = vector_type<T, N>;
};
template <typename T, index_t N0, index_t N1>
struct vector_type_maker<T __attribute__((ext_vector_type(N1))), N0>
{
using type = vector_type<T, N0 * N1>;
};
template <typename T, index_t N0, index_t N1>
struct vector_type_maker<vector_type<T, N1>, N0>
{
using type = vector_type<T, N0 * N1>;
using scalar_type = remove_cvref_t<T>;
static constexpr index_t vector_size = 1;
};
// specialization for ext_vector_type()
template <typename T, index_t N>
using vector_type_maker_t = typename vector_type_maker<T, N>::type;
template <typename T, index_t N>
CK_HOST_DEVICE constexpr auto make_vector_type(number<N>)
struct vector_traits<T __attribute__((ext_vector_type(N)))>
{
return typename vector_type_maker<T, N>::type{};
}
// scalar_type
template <typename TV>
struct scalar_type;
// is_scalar_type
template <typename TV>
struct is_scalar_type
{
static constexpr bool value = (scalar_type<remove_cvref_t<TV>>::vector_size == 1);
};
// has_same_scalar_type
template <typename X, typename Y>
using has_same_scalar_type = is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<Y>>::type>;
template <typename T, index_t N>
struct scalar_type<T __attribute__((ext_vector_type(N)))>
{
using type = T;
using scalar_type = T;
static constexpr index_t vector_size = N;
};
template <typename T, index_t N>
struct scalar_type<vector_type<T, N>>
{
using type = T;
static constexpr index_t vector_size = N;
};
//
template <>
struct scalar_type<double>
{
using type = double;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<float>
{
using type = float;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<half_t>
{
using type = half_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bhalf_t>
{
using type = bhalf_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int64_t>
{
using type = int64_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int32_t>
{
using type = int32_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<int8_t>
{
using type = int8_t;
static constexpr index_t vector_size = 1;
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct scalar_type<int4_t>
{
using type = int4_t;
static constexpr index_t vector_size = 1;
};
#endif
template <>
struct scalar_type<fp8_t>
{
using type = fp8_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf8_t>
{
using type = bf8_t;
static constexpr index_t vector_size = 1;
};
// below are some pre-defines of ext_vector_type
// attention! 2 vector type could be just the same type
// fp64
using double2_t = typename vector_type<double, 2>::type;
using double4_t = typename vector_type<double, 4>::type;
using fp64x2_t = double __attribute__((ext_vector_type(2)));
using fp64x4_t = double __attribute__((ext_vector_type(4)));
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp32x4_t = float __attribute__((ext_vector_type(4)));
using fp32x8_t = float __attribute__((ext_vector_type(8)));
using fp32x16_t = float __attribute__((ext_vector_type(16)));
using fp32x32_t = float __attribute__((ext_vector_type(32)));
using fp32x64_t = float __attribute__((ext_vector_type(64)));
// fp16
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
using fp16x2_t = fp16_raw_t __attribute__((ext_vector_type(2)));
using fp16x4_t = fp16_raw_t __attribute__((ext_vector_type(4)));
using fp16x8_t = fp16_raw_t __attribute__((ext_vector_type(8)));
using fp16x16_t = fp16_raw_t __attribute__((ext_vector_type(16)));
using fp16x32_t = fp16_raw_t __attribute__((ext_vector_type(32)));
using fp16x64_t = fp16_raw_t __attribute__((ext_vector_type(64)));
// bfp16
using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
using bhalf8_t = typename vector_type<bhalf_t, 8>::type;
using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
using bhalf64_t = typename vector_type<bhalf_t, 64>::type;
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// i16
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
using int16x4_t = int16_t __attribute__((ext_vector_type(4)));
using int16x8_t = int16_t __attribute__((ext_vector_type(8)));
using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
// i8
using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
using int8x8_t = int8_t __attribute((ext_vector_type(8)));
using int8x16_t = int8_t __attribute((ext_vector_type(16)));
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
// f8
using fp8x2_t = typename vector_type<fp8_t, 2>::type;
using fp8x4_t = typename vector_type<fp8_t, 4>::type;
using fp8x8_t = typename vector_type<fp8_t, 8>::type;
using fp8x16_t = typename vector_type<fp8_t, 16>::type;
using fp8x32_t = typename vector_type<fp8_t, 32>::type;
using fp8x64_t = typename vector_type<fp8_t, 64>::type;
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
// bf8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
} // namespace ck_tile

View File

@@ -4,7 +4,7 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/amd_address_space.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
@@ -12,6 +12,7 @@
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
@@ -22,13 +23,13 @@ namespace ck_tile {
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
// transforms of tensor_view/Tensor
// FIXME: amd_buffer_coherence_enum is only meaningful for buffer addressing. Need to split
// BufferView definition for different memory address space (Global/GenericLds/Vgpr)
// buffer_view definition for different memory address space (Global/GenericLds/Vgpr)
template <address_space_enum BufferAddressSpace,
typename T,
typename BufferSizeType,
bool InvalidElementUseNumericalZeroValue,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default>
struct BufferView;
struct buffer_view;
// Address Space: generic
// T may be scalar or vector
@@ -82,17 +83,18 @@ struct buffer_view<address_space_enum::generic,
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -123,19 +125,20 @@ struct buffer_view<address_space_enum::generic,
}
// i is offset of T, not X. i should be aligned to X
template <InMemoryDataOperationEnum Op,
template <memory_operation_enum Op,
typename X,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
{
if constexpr(Op == InMemoryDataOperationEnum::set)
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, is_valid_element, x);
}
// FIXME: remove InMemoryDataOperationEnum::Add
else if constexpr(Op == InMemoryDataOperationEnum::Add)
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, is_valid_element);
this->template set<X>(i, is_valid_element, x + tmp);
@@ -144,15 +147,16 @@ struct buffer_view<address_space_enum::generic,
// i is offset of T, not X. i should be aligned to X
template <typename X,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -253,17 +257,18 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -326,16 +331,17 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
{
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -348,15 +354,16 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template <typename X,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
{
// X is vector of T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -368,27 +375,28 @@ struct buffer_view<address_space_enum::global,
}
// i is offset of T, not X. i should be aligned to X
template <InMemoryDataOperationEnum Op,
template <memory_operation_enum Op,
typename X,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
{
if constexpr(Op == InMemoryDataOperationEnum::set)
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, is_valid_element, x);
}
else if constexpr(Op == InMemoryDataOperationEnum::atomic_add)
else if constexpr(Op == memory_operation_enum::atomic_add)
{
this->template atomic_add<X>(i, is_valid_element, x);
}
else if constexpr(Op == InMemoryDataOperationEnum::atomic_max)
else if constexpr(Op == memory_operation_enum::atomic_max)
{
this->template atomic_max<X>(i, is_valid_element, x);
}
// FIXME: remove InMemoryDataOperationEnum::Add
else if constexpr(Op == InMemoryDataOperationEnum::Add)
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, is_valid_element);
this->template set<X>(i, is_valid_element, x + tmp);
@@ -399,16 +407,17 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -443,16 +452,17 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -463,17 +473,18 @@ struct buffer_view<address_space_enum::global,
}
template <typename X,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x)
{
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -482,15 +493,16 @@ struct buffer_view<address_space_enum::global,
#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
is_same_v<remove_cvref_t<scalar_t>, float> ||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
is_same_v<remove_cvref_t<scalar_t>, float> ||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
@@ -512,15 +524,16 @@ struct buffer_view<address_space_enum::global,
}
template <typename X,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -528,8 +541,8 @@ struct buffer_view<address_space_enum::global,
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
#if CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, double>;
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
bool constexpr use_amd_buffer_addressing = std::is_same_v<remove_cvref_t<scalar_t>, double>;
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
@@ -628,17 +641,18 @@ struct buffer_view<address_space_enum::lds,
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -669,19 +683,20 @@ struct buffer_view<address_space_enum::lds,
}
// i is offset of T, not X. i should be aligned to X
template <InMemoryDataOperationEnum Op,
template <memory_operation_enum Op,
typename X,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
{
if constexpr(Op == InMemoryDataOperationEnum::set)
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, is_valid_element, x);
}
// FIXME: remove InMemoryDataOperationEnum::Add
else if constexpr(Op == InMemoryDataOperationEnum::Add)
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, is_valid_element);
this->template set<X>(i, is_valid_element, x + tmp);
@@ -690,15 +705,16 @@ struct buffer_view<address_space_enum::lds,
// i is offset of T, not X. i should be aligned to X
template <typename X,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -709,7 +725,8 @@ struct buffer_view<address_space_enum::lds,
bool constexpr workaround_int8_ds_write_issue = false;
#endif
if constexpr(is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
if constexpr(std::is_same<typename vector_traits<remove_cvref_t<T>>::scalar_type,
int8_t>::value &&
workaround_int8_ds_write_issue)
{
if(is_valid_element)
@@ -718,83 +735,83 @@ struct buffer_view<address_space_enum::lds,
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128
// TODO: remove this after compiler fix
static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8_t>::value) ||
(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x2_t>::value) ||
(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x16_t>::value) ||
(is_same<remove_cvref_t<T>, int8x4_t>::value &&
is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(is_same<remove_cvref_t<T>, int8x8_t>::value &&
is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(is_same<remove_cvref_t<T>, int8x16_t>::value &&
is_same<remove_cvref_t<X>, int8x16_t>::value),
static_assert((std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value),
"wrong! not implemented for this combination, please add "
"implementation");
if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8_t>::value)
if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int8_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int8_t*>(&x);
}
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x2_t>::value)
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x2_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int16_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int16_t*>(&x);
}
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x4_t>::value)
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x);
}
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x8_t>::value)
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x);
}
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x16_t>::value)
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x4_t*>(&x);
}
else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
is_same<remove_cvref_t<X>, int8x4_t>::value)
else if constexpr(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x);
}
else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
is_same<remove_cvref_t<X>, int8x8_t>::value)
else if constexpr(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x);
}
else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
is_same<remove_cvref_t<X>, int8x16_t>::value)
else if constexpr(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
@@ -899,17 +916,18 @@ struct buffer_view<address_space_enum::vgpr,
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -940,19 +958,20 @@ struct buffer_view<address_space_enum::vgpr,
}
// i is offset of T, not X. i should be aligned to X
template <InMemoryDataOperationEnum Op,
template <memory_operation_enum Op,
typename X,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
{
if constexpr(Op == InMemoryDataOperationEnum::set)
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, is_valid_element, x);
}
// FIXME: remove InMemoryDataOperationEnum::Add
else if constexpr(Op == InMemoryDataOperationEnum::Add)
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, is_valid_element);
this->template set<X>(i, is_valid_element, x + tmp);
@@ -961,15 +980,16 @@ struct buffer_view<address_space_enum::vgpr,
// i is offset of T, not X. i should be aligned to X
template <typename X,
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -1029,7 +1049,7 @@ template <address_space_enum BufferAddressSpace,
typename T,
typename BufferSizeType,
typename X,
typename std::enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
typename std::enable_if<std::is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)

View File

@@ -11,6 +11,9 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
namespace ck_tile {
@@ -65,13 +68,13 @@ CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
}
template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const NullTileWindow<WindowLengths>&)
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
{
return NullTensor{};
return null_tensor{};
}
template <typename T, typename WindowLengths>
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const NullTileWindow<WindowLengths>&)
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&)
{
}

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"
namespace ck_tile {

View File

@@ -8,11 +8,13 @@
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
namespace ck_tile {
namespace detail {
@@ -74,8 +76,8 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
constexpr index_t num_vec_in = vec_length_out;
constexpr index_t num_vec_out = vec_length_in;
using InVec = vector_type<DataType, vec_length_in>;
using OutVec = vector_type<DataType, vec_length_out>;
using InVec = array<DataType, vec_length_in>;
using OutVec = array<DataType, vec_length_out>;
using InVecType = typename InVec::type;
using OutVecType = typename OutVec::type;
@@ -114,7 +116,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
in_vectors(i).template AsType<InVecType>()(I0) =
in_vectors(i).template get_as<InVecType>()(I0) =
in_tensor.get_thread_buffer().template get_as<InVecType>(number<in_offset>{});
});
@@ -134,7 +136,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
out_tensor.get_thread_buffer().template set_as<OutVecType>(
number<out_offset / sizeof(OutVecType)>{},
out_vectors[i].template AsType<OutVecType>()[I0]);
out_vectors[i].template get_as<OutVecType>()[I0]);
});
});
}

View File

@@ -84,7 +84,7 @@ set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
static_assert(is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
}

View File

@@ -11,6 +11,7 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
namespace ck_tile {

View File

@@ -25,7 +25,7 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
constexpr auto tile_dstr = TileDstr{};
@@ -48,7 +48,7 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
constexpr auto tile_dstr = TileDstr{};

View File

@@ -717,7 +717,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms, &num_transform]() { \
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
@@ -725,7 +725,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
\
STATIC_ASSERT(name == cood_transform_enum::PassThrough || \
STATIC_ASSERT(name == cood_transform_enum::pass_through || \
name == cood_transform_enum::pad || \
name == cood_transform_enum::embed || \
name == cood_transform_enum::merge || \
@@ -733,7 +733,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
name == cood_transform_enum::replicate, \
""); \
\
if constexpr(name == cood_transform_enum::PassThrough) \
if constexpr(name == cood_transform_enum::pass_through) \
{ \
index_t pos = 0; \
auto low_len = meta_data.template pop<index_t>(pos); \
@@ -841,7 +841,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms, &num_transform]() { \
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
@@ -849,7 +849,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
\
STATIC_ASSERT(name == cood_transform_enum::PassThrough || \
STATIC_ASSERT(name == cood_transform_enum::pass_through || \
name == cood_transform_enum::pad || \
name == cood_transform_enum::embed || \
name == cood_transform_enum::merge || \
@@ -857,7 +857,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
name == cood_transform_enum::replicate, \
""); \
\
if constexpr(name == cood_transform_enum::PassThrough) \
if constexpr(name == cood_transform_enum::pass_through) \
{ \
constexpr index_t low_len = meta_data.template get<index_t>(0); \
\
@@ -912,7 +912,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
number<num_transform>{}); \
}(); \
\
constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() { \
constexpr auto low_dim_idss = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
@@ -923,7 +923,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
number<num_transform>()); \
}(); \
\
constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] { \
constexpr auto up_dim_idss = [&encoded_transforms] { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \

View File

@@ -299,7 +299,7 @@ template <typename... Lengths,
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
const tuple<Strides...>& strides,
const offset& offset,
const offset& os,
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
@@ -307,7 +307,7 @@ make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
const auto element_space_size = detail::calculate_element_space_size_impl(
lengths, strides, number<0>{}, long_number<1>{});
const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));
const auto transforms = make_tuple(make_offset_transform(element_space_size, os));
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
@@ -383,12 +383,12 @@ make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
template <typename... Lengths,
typename... Strides,
typename offset,
typename Offset,
index_t GuaranteedLastDimensionVectorLength = -1,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset(
const tuple<Lengths...>& lengths,
const offset& offset,
const Offset& offset,
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
const auto desc_0 = [&]() {

View File

@@ -3,12 +3,15 @@
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
@@ -16,15 +19,16 @@ namespace ck_tile {
template <typename BufferView_, typename TensorDesc_>
struct tensor_view
{
using BufferView = remove_reference_t<BufferView_>;
using DataType = typename BufferView::type;
using buffer_view = remove_reference_t<BufferView_>;
using DataType = typename buffer_view::type;
using TensorDesc = remove_cvref_t<TensorDesc_>;
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
CK_TILE_HOST_DEVICE constexpr tensor_view(const BufferView& buffer_view, const TensorDesc& desc)
CK_TILE_HOST_DEVICE constexpr tensor_view(const buffer_view& buffer_view,
const TensorDesc& desc)
: buf_{buffer_view}, desc_{desc}
{
}
@@ -58,12 +62,12 @@ struct tensor_view
#endif
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <
typename X,
bool oob_conditional_check = true,
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<DataType>>::type>,
bool>::type = false>
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_vectorized_elements(const TensorCoord& coord,
bool_constant<oob_conditional_check> = {}) const
@@ -76,12 +80,12 @@ struct tensor_view
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <
typename X,
bool oob_conditional_check = true,
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<DataType>>::type>,
bool>::type = false>
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE void
get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
@@ -93,11 +97,11 @@ struct tensor_view
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}
template <
typename X,
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<DataType>>::type>,
bool>::type = false>
template <typename X,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
const TensorCoord& coord) const
{
@@ -106,12 +110,12 @@ struct tensor_view
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <
typename X,
bool oob_conditional_check = true,
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<DataType>>::type>,
bool>::type = false>
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
{
@@ -121,12 +125,12 @@ struct tensor_view
x);
}
template <
typename X,
bool oob_conditional_check = true,
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<DataType>>::type>,
bool>::type = false>
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
{
@@ -153,7 +157,7 @@ struct tensor_view
}
// member
BufferView buf_;
buffer_view buf_;
TensorDesc desc_;
};
@@ -162,7 +166,7 @@ struct null_tensor_view
{
};
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
typename DataType,
typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
@@ -173,7 +177,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
typename DataType,
typename... Lengths,
typename... Strides,
@@ -197,7 +201,7 @@ make_naive_tensor_view(DataType* p,
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
typename DataType,
typename... Lengths,
index_t GuaranteedLastDimensionVectorLength = -1>
@@ -228,19 +232,19 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& ol
NewLowerDimensionOldVisibleIdss{},
NewUpperDimensionNewVisibleIdss{});
return tensor_view<typename OldTensorView::BufferView, remove_cvref_t<decltype(new_desc)>>{
return tensor_view<typename OldTensorView::buffer_view, remove_cvref_t<decltype(new_desc)>>{
old_tensor_view.buf_, new_desc};
}
template <typename tensor_view,
template <typename TensorView,
typename TileLengths, // tuple<...>
typename DoPads> // sequence<bool, bool, ...>
CK_TILE_HOST_DEVICE constexpr auto
pad_tensor_view(const tensor_view& tensor_view, const TileLengths& tile_lengths, DoPads)
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
{
constexpr index_t num_dim = DoPads::size();
static_assert(num_dim == TileLengths::size() && num_dim == tensor_view::get_num_of_dimension(),
static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),
"wrong! inconsistent # of dimensions");
// transforms

View File

@@ -3,12 +3,14 @@
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
@@ -293,7 +295,9 @@ CK_TILE_HOST_DEVICE constexpr auto
&hidden_dim_cnt,
&rh_major_minor_to_hidden_ids,
&rh_major_minor_to_hidden_lengths](auto idim_x) {
constexpr auto h_minor_lengths = tuple_element_t<idim_x, HsLengthss>{};
// typename HsLengthss::base{}.foo();
constexpr auto h_minor_lengths = HsLengthss{}.get(idim_x); //std::tuple_element_t<idim_x, HsLengthss>{};
// constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
constexpr index_t ndim_h_minor = h_minor_lengths.size();

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
@@ -18,7 +19,7 @@ namespace ck_tile {
template <typename InOutElementFunc,
typename... InOutDstrTensors,
typename = std::enable_if_t<std::conjunction_v<
std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, NullTensor>>...>>>
std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
InOutDstrTensors&... inout_dstr_tensors)
{
@@ -26,7 +27,7 @@ CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element
// static_assert(xxx);
constexpr index_t thread_buffer_size =
type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();
__type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();
static_for<0, thread_buffer_size, 1>{}(
[&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
@@ -35,7 +36,7 @@ CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element
template <typename InElementFunc,
typename... InDstrTensors,
typename = std::enable_if_t<
std::conjunction_v<std::negation<std::is_same<InDstrTensors, NullTensor>>...>>>
std::conjunction_v<std::negation<std::is_same<InDstrTensors, null_tensor>>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
const InDstrTensors&... in_dstr_tensors)
{
@@ -43,10 +44,10 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
// TODO: make sure all distributed tensors have same lengths and distribution
// static_assert(xxx);
constexpr auto in_tile_dstr = type_pack_element<0, InDstrTensors...>::get_tile_distribution();
constexpr auto in_tile_dstr = __type_pack_element<0, InDstrTensors...>::get_tile_distribution();
constexpr index_t thread_buffer_size =
type_pack_element<0, InDstrTensors...>::get_thread_buffer_size();
__type_pack_element<0, InDstrTensors...>::get_thread_buffer_size();
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
@@ -69,7 +70,7 @@ CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
}
template <typename T>
CK_TILE_DEVICE void set_tile(NullTensor&, const T&)
CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
{
}
@@ -82,7 +83,7 @@ CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
if constexpr(v == 0 && tensor_bytes % 4 == 0)
{
using dvec_t = static_buffer_c<index_t, tensor_bytes / 4>;
using dvec_t = array<index_t, tensor_bytes / 4>;
auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
for(auto i = 0; i < tensor.size(); i++)
tensor.get(i) = v;
@@ -96,7 +97,7 @@ CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
}
template <index_t v>
CK_TILE_DEVICE void set_tile(NullTensor&, number<v>)
CK_TILE_DEVICE void set_tile(null_tensor&, number<v>)
{
}
@@ -139,7 +140,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors)
false); // false -> WORD0
constexpr int32_t m0 = 0x05040100;
using vec_t = typename vector_type<OutDataType, 4>::type;
using vec_t = array<OutDataType, 4>;
vec_t d = bit_cast<vec_t>(__builtin_amdgcn_perm(y, x, m0));
out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
@@ -157,9 +158,9 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors)
template <typename DstType, typename SrcDstrTensors>
CK_TILE_DEVICE auto cast_tile(const SrcDstrTensors& src_tensor)
{
if constexpr((ck_tile::is_same_v<DstType, fp8_t> ||
ck_tile::is_same_v<DstType, bf8_t>)&&ck_tile::
is_same_v<typename SrcDstrTensors::DataType, float> &&
if constexpr((std::is_same_v<DstType, fp8_t> ||
std::is_same_v<DstType, bf8_t>)&&std::is_same_v<typename SrcDstrTensors::DataType,
float> &&
(SrcDstrTensors::get_thread_buffer_size() % 4 == 0))
{
return cast_tile_pk_fp8x4<DstType, SrcDstrTensors>(src_tensor);
@@ -169,23 +170,23 @@ CK_TILE_DEVICE auto cast_tile(const SrcDstrTensors& src_tensor)
src_tensor);
}
// no-op function for NullTensor arguments
// no-op function for null_tensor arguments
template <typename InOutElementFunc,
typename... MaybeNullTensor,
typename = std::enable_if_t<
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, NullTensor>...>>>
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
{
}
// no-op function for NullTensor arguments
// no-op function for null_tensor arguments
template <typename InElementFunc,
typename... MaybeNullTensor,
typename = std::enable_if_t<
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, NullTensor>...>>>
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
{
return NullTensor{};
return null_tensor{};
}
} // namespace ck_tile

View File

@@ -3,12 +3,15 @@
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
@@ -85,8 +88,9 @@ struct tile_window_with_static_distribution
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
// using vector_t = typename vector_type_t::type;
using vector_t = array<DataType, ScalarPerVector>;
private:
static constexpr auto scalars_per_access_ = [] {
@@ -275,9 +279,8 @@ struct tile_window_with_static_distribution
{
using Traits = load_store_traits;
using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename vector_type_t::type;
using SFC_Ys = typename Traits::SFC_Ys;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
@@ -300,8 +303,6 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
const vector_type_t vec{vec_value};
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
@@ -315,7 +316,7 @@ struct tile_window_with_static_distribution
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec.template AsType<DataType>()[j];
vec_value.template get_as<DataType>()[j];
});
// move thread coordinate
@@ -341,16 +342,17 @@ struct tile_window_with_static_distribution
{
using Traits = load_store_traits;
using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename vector_type_t::type;
using SFC_Ys = typename Traits::SFC_Ys;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
static constexpr index_t YElementSize =
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
static_assert(YElementSize % Traits::ScalarPerVector == 0);
using vectorized_tbuf = StaticBuffer<address_space_enum::vgpr,
vector_t,
YElementSize / Traits::ScalarPerVector,
true>;
using vectorized_tbuf = array<vector_t, YElementSize / Traits::ScalarPerVector>;
// StaticBuffer<address_space_enum::vgpr,
// vector_t,
// YElementSize / Traits::ScalarPerVector,
// true>;
constexpr auto tile_dstr = TileDstr{};
@@ -426,9 +428,9 @@ struct tile_window_with_static_distribution
using Traits = load_store_traits;
using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename vector_type_t::type;
using SFC_Ys = typename Traits::SFC_Ys;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
@@ -468,9 +470,9 @@ struct tile_window_with_static_distribution
{
using Traits = load_store_traits;
using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename vector_type_t::type;
using SFC_Ys = typename Traits::SFC_Ys;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
@@ -487,7 +489,8 @@ struct tile_window_with_static_distribution
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor
vector_type_t vec;
// vector_type_t vec;
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
@@ -500,11 +503,11 @@ struct tile_window_with_static_distribution
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec.template AsType<DataType>()(j) =
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
const vector_t vec_value = vec.template AsType<vector_t>().template at<0>();
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
@@ -530,9 +533,9 @@ struct tile_window_with_static_distribution
{
using Traits = load_store_traits;
using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename vector_type_t::type;
using SFC_Ys = typename Traits::SFC_Ys;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
static constexpr bool oob_conditional_check = true;
@@ -550,7 +553,8 @@ struct tile_window_with_static_distribution
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor
vector_type_t vec;
// vector_type_t vec;
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
@@ -563,11 +567,11 @@ struct tile_window_with_static_distribution
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec.template AsType<DataType>()(j) =
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
const vector_t vec_value = vec.template AsType<vector_t>().template at<0>();
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// write into bottom tensor
get_bottom_tensor_view()

View File

@@ -191,4 +191,18 @@ CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
}
// z = predicate ? x : y
template <bool predicate, typename X, typename Y>
constexpr auto conditional_expr(X&& x, Y&& y)
{
if constexpr(predicate)
{
return std::forward<X>(x);
}
else
{
return std::forward<Y>(y);
}
}
} // namespace ck_tile

View File

@@ -9,12 +9,13 @@
// clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
#define TO_SEQUENCE(a, n) \
_Pragma("clang diagnostic push") \
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... Is>( \
ck_tile::sequence<Is...>) \
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::sequence<a.at(ck_tile::number<Is>{})...>{}; \
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
} \
(make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
(ck_tile::make_index_sequence<n>{}); \
_Pragma("clang diagnostic pop")
#else
// Macro function

View File

@@ -0,0 +1,123 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// S: scalar type (or it can be non-scalar type)
// NX: # of vector before transpose
// NY: # of vector after transpose
// we got [NX, NY] amount of S data to be transposed into [NY, NX] amount of S data
template <typename S_, index_t NX, index_t NY>
struct transpose_vectors
{
static constexpr index_t s_per_x = NY;
static constexpr index_t s_per_y = NX;
using S = remove_cvref_t<S_>;
using VX = array<S, s_per_x>;
using VY = array<S, s_per_y>;
CK_TILE_DEVICE void operator()(const array<VX, NX>& vx_tuple, array<VY, NY>& vy_tuple)
{
constexpr auto I1 = number<1>{};
constexpr auto I2 = number<2>{};
constexpr auto I3 = number<3>{};
constexpr auto I4 = number<4>{};
if constexpr(sizeof(S) == 2)
{
static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
using S2 = array<S, 2>; // typename array<S, 2>::type;
// loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 2>{}([&](auto iy) {
static_for<0, NX, 2>{}([&](auto ix) {
// 2 16bitx2 data from vx_tuple to be transposed
const int32_t x_s2_0 =
bit_cast<int32_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
const int32_t x_s2_1 =
bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);
constexpr int32_t m0 = 0x05040100;
constexpr int32_t m1 = 0x07060302;
// transpose 2x2 16bit
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
// -- -- -- -- -- -- -- -- - - - -
// index 7 6 5 4 3 2 1 0 33 77 44 88
// index is reversed because of little endianness (least significant bits first)
const int32_t y_s2_0 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m0);
const int32_t y_s2_1 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m1);
// 2 16bitx2 data after transposed
vy_tuple(iy).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_0);
vy_tuple(iy + I1).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_1);
});
});
}
else if constexpr(sizeof(S) == 1)
{
static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");
using S4 = array<S, 4>; // typename array<S, 4>::type;
// loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 4>{}([&](auto iy) {
static_for<0, NX, 4>{}([&](auto ix) {
// 4 int8x4 data from vx_tuple
const int32_t x_s4_0 =
bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
const int32_t x_s4_1 =
bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
const int32_t x_s4_2 =
bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
const int32_t x_s4_3 =
bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);
// transpose
int32_t t_s4_0, t_s4_1;
int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;
constexpr int32_t m0 = 0x05010400;
constexpr int32_t m1 = 0x05040100;
constexpr int32_t m2 = 0x07060302;
constexpr int32_t m3 = 0x07030602;
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
// -- -- -- -- -- -- -- -- - - - -
// index 7 6 5 4 3 2 1 0 33 77 44 88
// index is reversed because of little endianness (least significant bits first)
t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
// 4 int8x4 data from vy_tuple
vy_tuple(iy).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_0);
vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
});
});
}
else
{
static_assert(false, "not implemented");
}
}
};
} // namespace ck_tile

View File

@@ -1,57 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
#include <tuple>
#include <type_traits>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
namespace ck_tile {
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
using non_const_y = std::remove_const_t<Y>;
using non_const_x = std::remove_const_t<X>;
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
}
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
template <> \
inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return stype_##_to_##dtype_(x); \
}
CK_TILE_TYPE_CONVERT(float, fp16_t)
CK_TILE_TYPE_CONVERT(float, bf16_t)
CK_TILE_TYPE_CONVERT(float, fp8_t)
CK_TILE_TYPE_CONVERT(float, bf8_t)
CK_TILE_TYPE_CONVERT(fp16_t, float)
CK_TILE_TYPE_CONVERT(bf16_t, float)
CK_TILE_TYPE_CONVERT(fp8_t, float)
CK_TILE_TYPE_CONVERT(bf8_t, float)
} // namespace ck_tile

View File

@@ -69,4 +69,18 @@ struct nonesuch
template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
// FIXME: do we need this anymore?
template <
typename PY,
typename PX,
typename std::enable_if<std::is_pointer_v<PY> && std::is_pointer_v<PX>, bool>::type = false>
CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wcast-align"
return (PY)p_x; // NOLINT(old-style-cast, cast-align)
#pragma clang diagnostic pop
}
} // namespace ck_tile

View File

@@ -97,7 +97,7 @@ check_err(const Range& out,
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
std::is_same_v<ranges::range_value_t<Range>, bf16_t>,
bool>::type
check_err(const Range& out,
const RefRange& ref,
@@ -123,7 +123,7 @@ check_err(const Range& out,
bool res{true};
int err_count = 0;
double err = 0;
// TODO: This is a hack. We should have proper specialization for bhalf_t data type.
// TODO: This is a hack. We should have proper specialization for bf16_t data type.
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
@@ -214,7 +214,7 @@ check_err(const Range& out,
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_integral_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
!std::is_same_v<ranges::range_value_t<Range>, bf16_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<ranges::range_value_t<Range>, int4_t>
#endif

View File

@@ -147,8 +147,8 @@ float launch_and_time_kernel_with_preprocess(const stream_config& s,
#endif
}
template <int MaxThreadPerBlock = CK_MAX_THREAD_PER_BLOCK,
int MinBlockPerCu = CK_MIN_BLOCK_PER_CU,
template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
typename KernelImpl,
typename... Args>
float launch_kernel(const stream_config& s,

View File

@@ -0,0 +1,4 @@
## common
this folder is designed not to be included directly by use, e.g. if use include `ck_tile/ops/fmha.hpp`, then everything under `common` should also be included.
to achieve this we will duplicate the header include path under `common` to other module under `ops/*` inside remod.py. for internal developer, you can also include `ck_tile/ops/common.hpp` for convenience. (and so does external users...)

View File

@@ -4,4 +4,5 @@
#pragma once
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -17,4 +17,5 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <string>
#include <type_traits>
@@ -50,7 +51,7 @@ struct FmhaFwdKernel
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::half_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bhalf_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
@@ -79,7 +80,7 @@ struct FmhaFwdKernel
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" );
#undef _SS_
#undef _TS_
@@ -407,7 +408,7 @@ struct FmhaFwdKernel
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
}
@@ -519,7 +520,7 @@ struct FmhaFwdKernel
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
const auto v_dram = [&]() {
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,

View File

@@ -49,12 +49,12 @@ struct BlockFmhaPipelineProblem
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kIsFp8 =
(is_same_v<QDataType, fp8_t> || is_same_v<QDataType, bf8_t>)&&(
is_same_v<KDataType, fp8_t> ||
is_same_v<KDataType, bf8_t>)&&(is_same_v<VDataType, fp8_t> ||
is_same_v<VDataType, bf8_t>)&&is_same_v<SaccDataType,
float> &&
is_same_v<OaccDataType, float>;
(std::is_same_v<QDataType, fp8_t> || std::is_same_v<QDataType, bf8_t>)&&(
std::is_same_v<KDataType, fp8_t> ||
std::is_same_v<KDataType, bf8_t>)&&(std::is_same_v<VDataType, fp8_t> ||
std::is_same_v<VDataType, bf8_t>)&&std::
is_same_v<SaccDataType, float> &&
std::is_same_v<OaccDataType, float>;
};
} // namespace ck_tile

View File

@@ -56,7 +56,7 @@ struct BlockFmhaPipelineQRKSVS
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
@@ -127,9 +127,9 @@ struct BlockFmhaPipelineQRKSVS
void* smem_ptr) const
{
static_assert(
is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
@@ -442,7 +442,7 @@ struct BlockFmhaPipelineQRKSVS
});
block_sync_lds();
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
@@ -471,8 +471,7 @@ struct BlockFmhaPipelineQRKSVS
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
block_sync_lds();
if constexpr(ck_tile::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());

View File

@@ -4,7 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
namespace ck_tile {
@@ -59,7 +59,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
@@ -138,9 +138,9 @@ struct BlockFmhaPipelineQRKSVSAsync
void* smem_ptr) const
{
static_assert(
is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
@@ -415,7 +415,7 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
@@ -539,8 +539,7 @@ struct BlockFmhaPipelineQRKSVSAsync
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(ck_tile::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());

View File

@@ -56,7 +56,7 @@ struct BlockFmhaPipelineQRKSVSFp8
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
@@ -119,9 +119,9 @@ struct BlockFmhaPipelineQRKSVSFp8
void* smem_ptr) const
{
static_assert(
is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
@@ -425,7 +425,7 @@ struct BlockFmhaPipelineQRKSVSFp8
});
block_sync_lds();
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
@@ -453,8 +453,7 @@ struct BlockFmhaPipelineQRKSVSFp8
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
block_sync_lds();
if constexpr(ck_tile::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());

View File

@@ -114,9 +114,9 @@ struct BlockFmhaPipelineQSKSVS
void* smem_ptr) const
{
static_assert(
is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
@@ -434,7 +434,7 @@ struct BlockFmhaPipelineQSKSVS
});
block_sync_lds();
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
@@ -463,8 +463,7 @@ struct BlockFmhaPipelineQSKSVS
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
block_sync_lds();
if constexpr(ck_tile::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());

View File

@@ -4,7 +4,11 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
// TODO: remove this
#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
@@ -76,24 +80,24 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
Problem::BlockFmhaShape::kK0>>;
constexpr auto warp_gemm = []() {
if constexpr(is_same_v<typename Problem::QDataType, half_t> &&
is_same_v<typename Problem::KDataType, half_t> &&
is_same_v<typename Problem::SaccDataType, float>)
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{};
return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{};
}
else if constexpr(is_same_v<typename Problem::QDataType, bhalf_t> &&
is_same_v<typename Problem::KDataType, bhalf_t> &&
is_same_v<typename Problem::SaccDataType, float>)
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
}
else if constexpr(Problem::kIsFp8)
{
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
return warp::WarpGemmImpl<
warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
typename Problem::QDataType,
typename Problem::KDataType>,
2,
@@ -201,24 +205,24 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
Problem::BlockFmhaShape::kK0>>;
constexpr auto warp_gemm = []() {
if constexpr(is_same_v<typename Problem::QDataType, half_t> &&
is_same_v<typename Problem::KDataType, half_t> &&
is_same_v<typename Problem::SaccDataType, float>)
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{};
return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{};
}
else if constexpr(is_same_v<typename Problem::QDataType, bhalf_t> &&
is_same_v<typename Problem::KDataType, bhalf_t> &&
is_same_v<typename Problem::SaccDataType, float>)
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
}
else if constexpr(Problem::kIsFp8)
{
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
return warp::WarpGemmImpl<
warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
typename Problem::QDataType,
typename Problem::KDataType>,
2,
@@ -337,7 +341,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
@@ -762,7 +766,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
@@ -857,7 +861,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{
// This descriptor only used when V layout is seqlen * hdim
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
static_assert(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
@@ -914,15 +918,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
auto warp_gemm = [&]() {
if constexpr(Problem::kIsFp8)
{
return warp::WarpGemmImpl<
warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
return WarpGemmImpl<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
typename Problem::PDataType,
typename Problem::VDataType>,
2>>{};
// return
// warp::WarpGemmImpl<warp::WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
// warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
// Problem::PDataType, typename Problem::VDataType>>>{};
}
else

View File

@@ -28,3 +28,4 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -44,9 +44,9 @@ struct BlockGemmARegBGmemCRegV1
void* smem_ptr) const
{
static_assert(
is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
is_same_v<BDataType, remove_cv_t<typename BBlockGmemWindowTmp::DataType>> &&
is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockGmemWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensor{}.get_lengths()[number<0>{}];
@@ -90,9 +90,10 @@ struct BlockGemmARegBGmemCRegV1
const BBlockGmemWindowTmp& b_block_gmem_window_tmp,
void* smem_ptr) const
{
static_assert(is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
is_same_v<BDataType, remove_cv_t<typename BBlockGmemWindowTmp::DataType>>,
"wrong!");
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockGmemWindowTmp::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensor{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockGmemWindowTmp{}.get_window_lengths()[number<0>{}];

View File

@@ -28,10 +28,11 @@ struct BlockGemmARegBSmemCRegV1
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
@@ -126,9 +127,9 @@ struct BlockGemmARegBSmemCRegV1
// check C-block-distribution
static_assert(
is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
@@ -184,9 +185,10 @@ struct BlockGemmARegBSmemCRegV1
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
"wrong!");
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];

View File

@@ -14,9 +14,9 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
if constexpr(is_same_v<typename Problem::ADataType, half_t> &&
is_same_v<typename Problem::BDataType, half_t> &&
is_same_v<typename Problem::CDataType, float>)
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
@@ -43,9 +43,9 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
#endif
}
else if constexpr(is_same_v<typename Problem::ADataType, bhalf_t> &&
is_same_v<typename Problem::BDataType, bhalf_t> &&
is_same_v<typename Problem::CDataType, float>)
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
}

View File

@@ -28,10 +28,11 @@ struct BlockGemmARegBSmemCRegV2
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
@@ -126,9 +127,9 @@ struct BlockGemmARegBSmemCRegV2
// check C-block-distribution
static_assert(
is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;

View File

@@ -28,9 +28,9 @@ struct BlockGemmASmemBSmemCRegV1
const ABlockWindowTmp& a_block_window_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
is_same_v<CDataType, typename CBlockTensor::DataType>,
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"wrong!");
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];

View File

@@ -14,9 +14,9 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
if constexpr(is_same_v<typename Problem::ADataType, half_t> &&
is_same_v<typename Problem::BDataType, half_t> &&
is_same_v<typename Problem::CDataType, float>)
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
@@ -42,9 +42,9 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif
}
else if constexpr(is_same_v<typename Problem::ADataType, bhalf_t> &&
is_same_v<typename Problem::BDataType, bhalf_t> &&
is_same_v<typename Problem::CDataType, float>)
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}

View File

@@ -47,8 +47,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
void* p_smem) const
{
static_assert(
is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&

View File

@@ -47,8 +47,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV2
void* p_smem) const
{
static_assert(
is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&

View File

@@ -4,6 +4,8 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
namespace ck_tile {

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
namespace ck_tile {
@@ -74,8 +75,8 @@ struct WarpGemmAtrributeMfmaIterateK
using BDataType = typename Impl::BDataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename vector_type_maker<typename Impl::AVecType, kKIter>::type::type;
using BVecType = typename vector_type_maker<typename Impl::BVecType, kKIter>::type::type;
using AVecType = array<ADataType, Impl::AVecType::size() * kKIter>;
using BVecType = array<BDataType, Impl::BVecType::size() * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
@@ -111,33 +112,27 @@ struct WarpGemmAtrributeMfmaIterateK
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
a_vector.template AsType<typename Impl::AVecType>()[iKIter],
b_vector.template AsType<typename Impl::BVecType>()[iKIter]);
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
constexpr auto I0 = number<0>{};
// c = a * b
auto c_vec = Impl{}(a_vector.template AsType<typename Impl::AVecType>()[I0],
b_vector.template AsType<typename Impl::BVecType>()[I0]);
auto c_vec = Impl{}(a_vec.template get_as<typename Impl::AVecType>()[I0],
b_vec.template get_as<typename Impl::BVecType>()[I0]);
// c += a * b
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
a_vector.template AsType<typename Impl::AVecType>()[iKIter],
b_vector.template AsType<typename Impl::BVecType>()[iKIter]);
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
});
return c_vec;
@@ -274,8 +269,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
using BDataType = typename Impl::ADataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename vector_type_maker<typename Impl::BVecType, kKIter>::type::type;
using BVecType = typename vector_type_maker<typename Impl::AVecType, kKIter>::type::type;
using AVecType = array<ADataType, Impl::AVecType::size() * kKIter>;
using BVecType = array<BDataType, Impl::BVecType::size() * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
@@ -311,34 +306,27 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
// swap A and B, value and type
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
b_vector.template AsType<typename Impl::AVecType>()[iKIter],
a_vector.template AsType<typename Impl::BVecType>()[iKIter]);
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
constexpr auto I0 = number<0>{};
// swap A and B, value and type
auto c_vec = Impl{}(b_vector.template AsType<typename Impl::AVecType>()[I0],
a_vector.template AsType<typename Impl::BVecType>()[I0]);
auto c_vec = Impl{}(a_vec.template get_as<typename Impl::AVecType>()[I0],
b_vec.template get_as<typename Impl::BVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
b_vector.template AsType<typename Impl::AVecType>()[iKIter],
a_vector.template AsType<typename Impl::BVecType>()[iKIter]);
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
});
return c_vec;
@@ -355,8 +343,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
using BDataType = typename Impl::ADataType;
using CDataType = typename Impl::CDataType;
using AVecType = typename vector_type_maker<typename Impl::BVecType, kKIter>::type::type;
using BVecType = typename vector_type_maker<typename Impl::AVecType, kKIter>::type::type;
using AVecType = array<ADataType, Impl::AVecType::size() * kKIter>;
using BVecType = array<BDataType, Impl::BVecType::size() * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
@@ -418,34 +406,27 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
// swap A and B, value and type
static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
b_vector.template AsType<typename Impl::AVecType>()[iKIter],
a_vector.template AsType<typename Impl::BVecType>()[iKIter]);
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
constexpr auto I0 = number<0>{};
// swap A and B, value and type
auto c_vec = Impl{}(b_vector.template AsType<typename Impl::AVecType>()[I0],
a_vector.template AsType<typename Impl::BVecType>()[I0]);
auto c_vec = Impl{}(b_vec.template get_as<typename Impl::AVecType>()[I0],
a_vec.template get_as<typename Impl::BVecType>()[I0]);
static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
b_vector.template AsType<typename Impl::AVecType>()[iKIter],
a_vector.template AsType<typename Impl::BVecType>()[iKIter]);
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
});
return c_vec;

View File

@@ -10,13 +10,13 @@ namespace ck_tile {
// FP16
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
{
using ADataType = half_t;
using BDataType = half_t;
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
using AVecType = typename vector_type<half_t, 4>::type;
using BVecType = typename vector_type<half_t, 4>::type;
using CVecType = typename vector_type<float, 16>::type;
using AVecType = array<fp16_t, 4>;
using BVecType = array<fp16_t, 4>;
using CVecType = array<float, 16>;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
@@ -36,25 +36,37 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
b_vec.template get_as<fp16x4_t>()[number<0>{}],
c_vec.template get_as<fp32x16_t>()[number<0>{}],
0,
0,
0);
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, CVecType{0.f}, 0, 0, 0);
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
b_vec.template get_as<fp16x4_t>()[number<0>{}],
fp32x16_t{0.f},
0,
0,
0));
}
};
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
{
using ADataType = half_t;
using BDataType = half_t;
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
using AVecType = typename vector_type<half_t, 4>::type;
using BVecType = typename vector_type<half_t, 4>::type;
using CVecType = typename vector_type<float, 4>::type;
using AVecType = array<fp16_t, 4>;
using BVecType = array<fp16_t, 4>;
using CVecType = array<float, 4>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
@@ -74,26 +86,38 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec.template get_as<fp32x4_t>()[number<0>{}] =
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
b_vec.template get_as<fp16x4_t>()[number<0>{}],
c_vec.template get_as<fp32x4_t>()[number<0>{}],
0,
0,
0);
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, CVecType{0.f}, 0, 0, 0);
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
b_vec.template get_as<fp16x4_t>()[number<0>{}],
fp32x4_t{0.f},
0,
0,
0));
}
};
// Bf16
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
{
using ADataType = bhalf_t;
using BDataType = bhalf_t;
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
using AVecType = typename vector_type<bhalf_t, 4>::type;
using BVecType = typename vector_type<bhalf_t, 4>::type;
using CVecType = typename vector_type<float, 16>::type;
using AVecType = array<bf16_t, 4>;
using BVecType = array<bf16_t, 4>;
using CVecType = array<float, 16>;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
@@ -113,25 +137,37 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec.template get_as<fp32x16_t>()[number<0>{}] = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
a_vec.template get_as<bf16x4_t>()[number<0>{}],
b_vec.template get_as<bf16x4_t>()[number<0>{}],
c_vec.template get_as<fp32x16_t>()[number<0>{}],
0,
0,
0);
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, CVecType{0.f}, 0, 0, 0);
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec.template get_as<bf16x4_t>()[number<0>{}],
b_vec.template get_as<bf16x4_t>()[number<0>{}],
fp32x16_t{0.f},
0,
0,
0));
}
};
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
{
using ADataType = bhalf_t;
using BDataType = bhalf_t;
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
using AVecType = typename vector_type<bhalf_t, 4>::type;
using BVecType = typename vector_type<bhalf_t, 4>::type;
using CVecType = typename vector_type<float, 4>::type;
using AVecType = array<bf16_t, 4>;
using BVecType = array<bf16_t, 4>;
using CVecType = array<float, 4>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
@@ -151,13 +187,25 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec.template get_as<fp32x4_t>()[number<0>{}] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
a_vec.template get_as<bf16x4_t>()[number<0>{}],
b_vec.template get_as<bf16x4_t>()[number<0>{}],
c_vec.template get_as<fp32x4_t>()[number<0>{}],
0,
0,
0);
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, CVecType{0.f}, 0, 0, 0);
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
a_vec.template get_as<bf16x4_t>()[number<0>{}],
b_vec.template get_as<bf16x4_t>()[number<0>{}],
fp32x4_t{0.f},
0,
0,
0));
}
};
@@ -169,9 +217,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
using BDataType = BType_;
using CDataType = float;
using AVecType = typename vector_type<ADataType, 8>::type;
using BVecType = typename vector_type<BDataType, 8>::type;
using CVecType = typename vector_type<CDataType, 16>::type;
using AVecType = array<ADataType, 8>;
using BVecType = array<BDataType, 8>;
using CVecType = array<CDataType, 16>;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
@@ -192,27 +240,49 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
if constexpr(is_same_v<ADataType, fp8_t> && is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(is_same_v<ADataType, fp8_t> && is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(is_same_v<ADataType, bf8_t> && is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(is_same_v<ADataType, bf8_t> && is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec),
bit_cast<long>(b_vec),
c_vec.template get_as<fp32x16_t>()[number<0>{}],
0,
0,
0);
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec),
bit_cast<long>(b_vec),
c_vec.template get_as<fp32x16_t>()[number<0>{}],
0,
0,
0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec),
bit_cast<long>(b_vec),
c_vec.template get_as<fp32x16_t>()[number<0>{}],
0,
0,
0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec),
bit_cast<long>(b_vec),
c_vec.template get_as<fp32x16_t>()[number<0>{}],
0,
0,
0);
#else
vector_type<ADataType, 8> a_(a_vec);
vector_type<BDataType, 8> b_(b_vec);
static_for<0, 8, 1>{}([&](auto k) {
float a_f32 = type_convert<float>(a_.template AsType<ADataType>()[number<k>{}]);
float b_f32 = type_convert<float>(b_.template AsType<BDataType>()[number<k>{}]);
float a_f32 = type_convert<float>(a_vec.template get_as<ADataType>()[number<k>{}]);
float b_f32 = type_convert<float>(b_vec.template get_as<BDataType>()[number<k>{}]);
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
c_vec.template get_as<fp32x16_t>()[number<0>{}] = __builtin_amdgcn_mfma_f32_32x32x2f32(
a_f32, b_f32, c_vec.template get_as<fp32x16_t>()[number<0>{}], 0, 0, 0);
});
#endif
}
@@ -220,18 +290,18 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
if constexpr(is_same_v<ADataType, fp8_t> && is_same_v<BDataType, fp8_t>)
return __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0);
else if constexpr(is_same_v<ADataType, fp8_t> && is_same_v<BDataType, bf8_t>)
return __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0);
else if constexpr(is_same_v<ADataType, bf8_t> && is_same_v<BDataType, fp8_t>)
return __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0);
else if constexpr(is_same_v<ADataType, bf8_t> && is_same_v<BDataType, bf8_t>)
return __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0);
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
}
};

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {
@@ -29,14 +30,14 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
// bf16
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
// fp8
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };

View File

@@ -33,9 +33,9 @@ struct WarpGemmImpl
CK_TILE_DEVICE void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const
{
using AVec = typename vector_type<ADataType, AWarpTensor::get_thread_buffer_size()>::type;
using BVec = typename vector_type<BDataType, BWarpTensor::get_thread_buffer_size()>::type;
using CVec = typename vector_type<CDataType, CWarpTensor::get_thread_buffer_size()>::type;
using AVec = array<ADataType, AWarpTensor::get_thread_buffer_size()>;
using BVec = array<BDataType, BWarpTensor::get_thread_buffer_size()>;
using CVec = array<CDataType, CWarpTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
@@ -53,9 +53,9 @@ struct WarpGemmImpl
{
CWarpTensor c;
using AVec = typename vector_type<ADataType, AWarpTensor::get_thread_buffer_size()>::type;
using BVec = typename vector_type<BDataType, BWarpTensor::get_thread_buffer_size()>::type;
using CVec = typename vector_type<CDataType, CWarpTensor::get_thread_buffer_size()>::type;
using AVec = array<ADataType, AWarpTensor::get_thread_buffer_size()>;
using BVec = array<BDataType, BWarpTensor::get_thread_buffer_size()>;
using CVec = array<CDataType, CWarpTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};

View File

@@ -4,3 +4,4 @@
#pragma once
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -188,7 +188,7 @@ CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
using InDataType = typename InDistributedTensor_::DataType;
using AccDataType = remove_cvref_t<AccDataType_>;
static_assert(is_same_v<InDataType, remove_cvref_t<InDataType_>>, "wrong!");
static_assert(std::is_same_v<InDataType, remove_cvref_t<InDataType_>>, "wrong!");
// declare acc_tensor
constexpr auto acc_dstr =

View File

@@ -2,9 +2,11 @@ import pathlib
from pathlib import Path
import subprocess
import os
import copy
NS = 'ck_tile'
OPS = 'ops'
OPS_COMMON = 'common' # common header will be duplicated into ops/* other module
HEADER_COMMON = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
@@ -54,6 +56,15 @@ class submodule_t:
f.write(f'#include \"{header_path}\"\n')
f.write('\n')
# print(self.m)
# restructure common
for k, v in self.m.items():
if k == OPS and OPS_COMMON in v.keys():
common_list = copy.deepcopy(v[OPS_COMMON])
# v.pop(OPS_COMMON)
for km in v.keys():
if km != OPS_COMMON:
v[km].extend(common_list)
for k, v in self.m.items():
if k == OPS:
for km, kv in v.items():