mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
fix xx
This commit is contained in:
@@ -72,7 +72,7 @@ auto get_elimit(int /*init_method*/)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bhalf_t>(int init_method)
|
||||
auto get_elimit<ck_tile::bf16_t>(int init_method)
|
||||
{
|
||||
if(init_method == 0)
|
||||
{
|
||||
@@ -510,7 +510,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bhalf_t>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
template <typename DataType>
|
||||
@@ -29,18 +28,18 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<ck_tile::bhalf_t>
|
||||
struct FmhaFwdTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using QDataType = ck_tile::bhalf_t;
|
||||
using KDataType = ck_tile::bhalf_t;
|
||||
using VDataType = ck_tile::bhalf_t;
|
||||
using BiasDataType = ck_tile::bhalf_t;
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bhalf_t; // data type for A matrix of second gemm
|
||||
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bhalf_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
@@ -11,7 +11,7 @@ import copy
|
||||
|
||||
DTYPE_MAP = {
|
||||
"fp16": "ck_tile::half_t",
|
||||
"bf16": "ck_tile::bhalf_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp8" : "ck_tile::fp8_t"
|
||||
}
|
||||
|
||||
|
||||
@@ -6,8 +6,9 @@
|
||||
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/arch/amd_address_space.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
@@ -51,6 +52,6 @@
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/type_convert.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Address Space for AMDGCN
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#address-space
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct address_space_enum
|
||||
{
|
||||
generic,
|
||||
global,
|
||||
lds,
|
||||
sgpr,
|
||||
vgpr,
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
61
include/ck_tile/core/arch/arch.hpp
Normal file
61
include/ck_tile/core/arch/arch.hpp
Normal file
@@ -0,0 +1,61 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Address Space for AMDGCN
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#address-space
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct address_space_enum
|
||||
{
|
||||
generic,
|
||||
global,
|
||||
lds,
|
||||
sgpr,
|
||||
vgpr,
|
||||
};
|
||||
|
||||
enum struct memory_operation_enum
|
||||
{
|
||||
set,
|
||||
atomic_add,
|
||||
atomic_max,
|
||||
add
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
|
||||
{
|
||||
// warpSize is defined by HIP
|
||||
return warpSize;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
|
||||
|
||||
CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }
|
||||
|
||||
// TODO: deprecate these
|
||||
CK_TILE_DEVICE index_t get_thread_local_1d_id() { return threadIdx.x; }
|
||||
|
||||
CK_TILE_DEVICE index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }
|
||||
|
||||
CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
|
||||
|
||||
// Use these instead
|
||||
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
|
||||
|
||||
CK_TILE_DEVICE index_t get_warp_id()
|
||||
{
|
||||
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
|
||||
|
||||
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
|
||||
|
||||
} // namespace ck_tile
|
||||
27
include/ck_tile/core/arch/utility.hpp
Normal file
27
include/ck_tile/core/arch/utility.hpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Address Space for AMDGCN
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#address-space
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: we have "memory" clobber here because this inline asm is used for async copy
|
||||
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
|
||||
{
|
||||
asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory");
|
||||
}
|
||||
|
||||
// NOTE: this is an immediate value
|
||||
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
|
||||
{
|
||||
asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -120,3 +120,15 @@
|
||||
#ifndef CK_TILE_DEBUG_LOG
|
||||
#define CK_TILE_DEBUG_LOG 0
|
||||
#endif
|
||||
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD -1
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
||||
#elif defined(__gfx1030__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
||||
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
|
||||
#endif
|
||||
|
||||
@@ -44,6 +44,19 @@ struct array
|
||||
data[i] = vlast;
|
||||
}
|
||||
}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr array(value_type c)
|
||||
{
|
||||
for(auto i = 0; i < size(); i++)
|
||||
data[i] = c;
|
||||
}
|
||||
template <typename ArrayType>
|
||||
CK_TILE_HOST_DEVICE constexpr array(const ArrayType& o)
|
||||
{
|
||||
static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
for(auto i = 0; i < size(); i++)
|
||||
data[i] = o.data[i];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
|
||||
|
||||
@@ -67,18 +80,18 @@ struct array
|
||||
CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return data[i]; }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return data[i]; }
|
||||
CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return data[i]; } // TODO: compatible
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator=(const T& a)
|
||||
#if 0
|
||||
template <typename ArrayType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator=(const ArrayType& a)
|
||||
{
|
||||
static_assert(T::size() == size(), "wrong! size not the same");
|
||||
static_assert(ArrayType::size() == size(), "wrong! size not the same");
|
||||
for(index_t i = 0; i < size(); ++i)
|
||||
{
|
||||
data[i] = a[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
#endif
|
||||
// type punning (strict aliasing) member functions for read/write
|
||||
// aliasing this array of type "T", "N" elements
|
||||
// as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements
|
||||
@@ -122,6 +135,17 @@ struct array<T, 0>
|
||||
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
|
||||
};
|
||||
|
||||
template <typename>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for array
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<array<T, N>>
|
||||
{
|
||||
using scalar_type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs)
|
||||
{
|
||||
|
||||
@@ -468,6 +468,7 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
|
||||
number<Seq::size()>{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
|
||||
[a_of_b_impl, a_size, bs_sizes] { \
|
||||
return ck_tile::generate_tuple( \
|
||||
@@ -479,5 +480,21 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
|
||||
}, \
|
||||
ck_tile::number<a_size>{}); \
|
||||
}()
|
||||
#else
|
||||
// constexpr index_t can't be captured "-Wunused-lambda-capture"
|
||||
// TODO: this is ugly
|
||||
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
|
||||
[a_of_b_impl, bs_sizes] { \
|
||||
return ck_tile::generate_tuple( \
|
||||
[=](auto i) { \
|
||||
constexpr auto b_impl = a_of_b_impl[i]; \
|
||||
constexpr index_t b_size = bs_sizes[i]; \
|
||||
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
|
||||
return b; \
|
||||
}, \
|
||||
ck_tile::number<a_size>{}); \
|
||||
}()
|
||||
#endif
|
||||
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -67,13 +67,12 @@ struct sequence
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get(number<I>)
|
||||
{
|
||||
static_assert(I < size(), "wrong! I too large");
|
||||
return number<get(I)>{};
|
||||
return number<get<I>()>{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t at(index_t I)
|
||||
{
|
||||
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
|
||||
static_assert(I < size(), "wrong! I too large");
|
||||
const index_t mData[size() + 1] = {Is..., 0};
|
||||
return mData[I];
|
||||
}
|
||||
@@ -89,7 +88,7 @@ struct sequence
|
||||
CK_TILE_HOST_DEVICE static constexpr auto at(number<I>)
|
||||
{
|
||||
static_assert(I < size(), "wrong! I too large");
|
||||
return number<get(I)>{};
|
||||
return number<get<I>()>{};
|
||||
}
|
||||
|
||||
template <typename I>
|
||||
|
||||
@@ -16,40 +16,48 @@ namespace ck_tile {
|
||||
|
||||
namespace impl {
|
||||
|
||||
// the place where content is stored
|
||||
template <index_t idx, typename T, bool is_empty = std::is_empty_v<T>>
|
||||
struct tuple_element
|
||||
struct tuple_object
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t idx, typename T>
|
||||
struct tuple_element<idx, T, true>
|
||||
struct tuple_object<idx, T, true>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_element() {}
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_element(const T&) {}
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object() {}
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(const T&) {}
|
||||
};
|
||||
|
||||
template <index_t idx, typename T>
|
||||
struct tuple_element<idx, T, false>
|
||||
struct tuple_object<idx, T, false>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_element() {}
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_element(const T& e) : element(e) {}
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {}
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_object(const T& e) : element(e) {}
|
||||
T element;
|
||||
};
|
||||
|
||||
// NOTE: we return a instance(not a reference) if content is empty
|
||||
template <std::size_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr T const& getv(tuple_element<I, T, false> const& x)
|
||||
CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object<I, T, true>&)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
template <std::size_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object<I, T, false>& x)
|
||||
{
|
||||
return x.element;
|
||||
}
|
||||
|
||||
template <std::size_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_element<I, T, false>& x)
|
||||
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object<I, T, false>& x)
|
||||
{
|
||||
return x.element;
|
||||
}
|
||||
|
||||
template <std::size_t I, class T>
|
||||
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_element<I, T, false>&& x)
|
||||
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object<I, T, false>&& x)
|
||||
{
|
||||
return static_cast<T&&>(x.element);
|
||||
}
|
||||
@@ -58,18 +66,18 @@ template <typename index_seq, typename... T>
|
||||
struct tuple_base;
|
||||
|
||||
template <index_t... I, typename... T>
|
||||
struct tuple_base<sequence<I...>, T...> : public tuple_element<I, T>...
|
||||
struct tuple_base<sequence<I...>, T...> : tuple_object<I, T>...
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base() {}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U const&... u) : tuple_element<I, T>(u)...
|
||||
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object<I, T>(u)...
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...> const& u)
|
||||
: tuple_element<I, T>(getv(static_cast<tuple_element<I, U> const&>(u)))...
|
||||
CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base<sequence<I...>, U...>& u)
|
||||
: tuple_object<I, T>(getv(static_cast<const tuple_object<I, U>&>(u)))...
|
||||
{
|
||||
}
|
||||
};
|
||||
@@ -84,15 +92,13 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple() {}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(U const&... u)
|
||||
: impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>(u...)
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...)
|
||||
{
|
||||
}
|
||||
|
||||
template <class... U>
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...> const& u)
|
||||
: impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>(
|
||||
static_cast<impl::tuple_base<U...> const&>(u))
|
||||
CK_TILE_HOST_DEVICE constexpr tuple(const tuple<U...>& u)
|
||||
: base(static_cast<const impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -109,19 +115,19 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
|
||||
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
|
||||
// clang-format off
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & get() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & get(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & get() { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & get(number<I>) { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }
|
||||
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & at() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & at(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & at() { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & at(number<I>) { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) { TP_COM_(); return get<I>(); }
|
||||
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & operator[](number<I>) { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr const auto & operator[](number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr auto & operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: compatible
|
||||
// clang-format on
|
||||
#undef TP_COM_
|
||||
};
|
||||
@@ -250,15 +256,15 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
|
||||
|
||||
// By default unroll to the flatten
|
||||
template <index_t Depth = 0, index_t MaxDepth = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& element)
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
|
||||
{
|
||||
return element;
|
||||
return t;
|
||||
}
|
||||
|
||||
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& element)
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& t)
|
||||
{
|
||||
return make_tuple(element);
|
||||
return make_tuple(t);
|
||||
}
|
||||
|
||||
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
|
||||
@@ -334,7 +340,7 @@ CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
|
||||
index_t max_n1_ = 0;
|
||||
|
||||
static_for<0, n0, 1>{}([&](auto i0) {
|
||||
constexpr index_t n1 = t_of_s[i0].size()();
|
||||
constexpr index_t n1 = t_of_s[i0].size();
|
||||
|
||||
max_n1_ = max_n1_ < n1 ? n1 : max_n1_;
|
||||
});
|
||||
@@ -345,7 +351,7 @@ CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
|
||||
array<array<index_t, max_n1>, n0> a_of_a{{-1}};
|
||||
|
||||
static_for<0, n0, 1>{}([&](auto i0) {
|
||||
constexpr index_t n1 = t_of_s[i0].size()();
|
||||
constexpr index_t n1 = t_of_s[i0].size();
|
||||
|
||||
static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; });
|
||||
});
|
||||
@@ -482,3 +488,60 @@ struct tuple_element<I, const ck_tile::tuple<Ts...>>
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
#if 1
|
||||
#define TO_TUPLE_OF_NUMBER(a, n) \
|
||||
_Pragma("clang diagnostic push") \
|
||||
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
|
||||
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
|
||||
} \
|
||||
(ck_tile::make_index_sequence<n>{}) \
|
||||
_Pragma("clang diagnostic pop")
|
||||
#else
|
||||
#define TO_TUPLE_OF_NUMBER(arr, n_) \
|
||||
[&arr, n_] { \
|
||||
static_assert(arr.size() >= n_, "wrong! out of bound"); \
|
||||
\
|
||||
static_assert(n_ < 7, "not implemented"); \
|
||||
\
|
||||
if constexpr(n_ == 0) \
|
||||
{ \
|
||||
return ck_tile::tuple<>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 1) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 2) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 3) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 4) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 5) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>>{}; \
|
||||
} \
|
||||
else if constexpr(n_ == 6) \
|
||||
{ \
|
||||
return ck_tile::tuple<number<arr[0]>, \
|
||||
number<arr[1]>, \
|
||||
number<arr[2]>, \
|
||||
number<arr[3]>, \
|
||||
number<arr[4]>, \
|
||||
number<arr[5]>>{}; \
|
||||
} \
|
||||
}()
|
||||
#endif
|
||||
|
||||
@@ -20,7 +20,8 @@ enum class bf16_rounding_mode
|
||||
truncate,
|
||||
};
|
||||
|
||||
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
@@ -41,36 +42,42 @@ struct alignas(2) bfloat16_t
|
||||
}
|
||||
|
||||
// constructor
|
||||
bfloat16_t() = default;
|
||||
constexpr bfloat16_t() : data() {}
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit bfloat16_t(const float& x) { data = float_to_bf16_raw(x); }
|
||||
explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit bfloat16_t(const int& x) { data = float_to_bf16_raw(static_cast<float>(x)); }
|
||||
explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit bfloat16_t(const unsigned int& x) { data = float_to_bf16_raw(static_cast<float>(x)); }
|
||||
explicit constexpr bfloat16_t(const unsigned int& x)
|
||||
: data(float_to_bf16_raw(static_cast<float>(x)))
|
||||
{
|
||||
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator float() const { return bf16_to_float_raw(data); }
|
||||
explicit constexpr operator float() const { return bf16_to_float_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
|
||||
explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type& get() { return data; }
|
||||
constexpr raw_type& get() { return data; }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
raw_type get() const { return data; }
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = typename bf16_t::raw_type;
|
||||
|
||||
// round to nearest
|
||||
CK_TILE_HOST_DEVICE
|
||||
uint16_t float_to_bf16_rtn_raw(float f)
|
||||
@@ -139,8 +146,8 @@ uint16_t float_to_bf16_truc_raw(float f)
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
|
||||
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {})
|
||||
template <bf16_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
|
||||
{
|
||||
if constexpr(rounding == bf16_rounding_mode::standard)
|
||||
return float_to_bf16_rtn_raw(f);
|
||||
@@ -161,8 +168,9 @@ float bf16_to_float_raw(uint16_t x)
|
||||
return u.fp32;
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
|
||||
{
|
||||
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
@@ -170,14 +178,15 @@ CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
|
||||
|
||||
template <bf16_rounding_mode rounding = CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT>
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
|
||||
{
|
||||
return bfloat16_t::bit_cast(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast<float>(x)); }
|
||||
half_t bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast<float>(x)); }
|
||||
|
||||
template <class T>
|
||||
struct numeric_limits;
|
||||
@@ -259,6 +268,4 @@ bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast
|
||||
CK_TILE_DEVICE
|
||||
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
|
||||
|
||||
using bf16_t = bfloat16_t;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include <stdint.h>
|
||||
@@ -75,20 +76,19 @@ struct alignas(1) float8_e4m3_t
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr float8_e4m3_t(const float& x) { data = float_to_fp8_raw(x); }
|
||||
explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr float8_e4m3_t(const int& x)
|
||||
explicit constexpr float8_e4m3_t(const int& x) : data(float_to_fp8_raw(static_cast<float>(x)))
|
||||
{
|
||||
data = float_to_fp8_raw(static_cast<float>(x));
|
||||
}
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr float8_e4m3_t(const unsigned int& x)
|
||||
: data(float_to_fp8_raw(static_cast<float>(x)))
|
||||
{
|
||||
data = float_to_fp8_raw(static_cast<float>(x));
|
||||
}
|
||||
|
||||
// cast to float
|
||||
@@ -106,6 +106,8 @@ struct alignas(1) float8_e4m3_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
using fp8_t = float8_e4m3_t;
|
||||
using fp8_raw_t = typename fp8_t::raw_type;
|
||||
|
||||
struct alignas(1) float8_e5m2_t
|
||||
{
|
||||
@@ -132,25 +134,24 @@ struct alignas(1) float8_e5m2_t
|
||||
|
||||
// construct from float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr float8_e5m2_t(const float& x) { data = float_to_bf8_raw(x); }
|
||||
explicit constexpr float8_e5m2_t(const float& x) : data(float_to_bf8_raw(x)) {}
|
||||
|
||||
// construct from int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr float8_e5m2_t(const int& x)
|
||||
explicit constexpr float8_e5m2_t(const int& x) : data(float_to_bf8_raw(static_cast<float>(x)))
|
||||
{
|
||||
data = float_to_bf8_raw(static_cast<float>(x));
|
||||
}
|
||||
|
||||
// construct from unsigned int
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr float8_e5m2_t(const unsigned int& x)
|
||||
: data(float_to_bf8_raw(static_cast<float>(x)))
|
||||
{
|
||||
data = float_to_bf8_raw(static_cast<float>(x));
|
||||
}
|
||||
|
||||
// cast to float
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr constexpr operator float() const { return bf8_to_float_raw(data); }
|
||||
explicit constexpr operator float() const { return bf8_to_float_raw(data); }
|
||||
|
||||
// cast to int
|
||||
CK_TILE_HOST_DEVICE
|
||||
@@ -163,6 +164,8 @@ struct alignas(1) float8_e5m2_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
using bf8_t = float8_e5m2_t;
|
||||
using bf8_raw_t = typename bf8_t::raw_type;
|
||||
|
||||
// below is sw fp8 conversion, not utilizing hw instruction
|
||||
namespace impl {
|
||||
@@ -431,10 +434,10 @@ CK_TILE_HOST_DEVICE Y cast_from_f8(X x)
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_sr_raw(float x)
|
||||
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x)
|
||||
{
|
||||
constexpr int seed = 42;
|
||||
uint32_t rng = prand_generator<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
|
||||
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
@@ -453,16 +456,18 @@ CK_TILE_HOST_DEVICE uint8_t float_to_fp8_sr_raw(float x)
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
|
||||
return impl::
|
||||
cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return bit_cast<fp8_raw_t>(impl::cast_to_f8<float,
|
||||
fp8_t,
|
||||
negative_zero_nan,
|
||||
clip,
|
||||
(rm == fp8_rounding_mode::stochastic)>(x, rng));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_sr_raw(float x)
|
||||
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
|
||||
{
|
||||
constexpr int seed = 42;
|
||||
uint32_t rng = prand_generator<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
|
||||
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
union
|
||||
{
|
||||
@@ -479,13 +484,15 @@ CK_TILE_HOST_DEVICE uint8_t float_to_bf8_sr_raw(float x)
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
|
||||
return impl::
|
||||
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return bit_cast<bf8_raw_t>(impl::cast_to_f8<float,
|
||||
bf8_t,
|
||||
negative_zero_nan,
|
||||
clip,
|
||||
(rm == fp8_rounding_mode::stochastic)>(x, rng));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_rtn_raw(float x)
|
||||
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float max_fp8 = 240.0f;
|
||||
@@ -506,12 +513,14 @@ CK_TILE_HOST_DEVICE uint8_t float_to_fp8_rtn_raw(float x)
|
||||
constexpr bool clip = true;
|
||||
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return impl::
|
||||
cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return bit_cast<fp8_raw_t>(impl::cast_to_f8<float,
|
||||
fp8_t,
|
||||
negative_zero_nan,
|
||||
clip,
|
||||
(rm == fp8_rounding_mode::stochastic)>(x, rng));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_rtn_raw(float x)
|
||||
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
union
|
||||
@@ -530,30 +539,32 @@ CK_TILE_HOST_DEVICE uint8_t float_to_bf8_rtn_raw(float x)
|
||||
constexpr bool clip = true;
|
||||
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
|
||||
constexpr uint32_t rng = 0;
|
||||
return impl::
|
||||
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
return bit_cast<bf8_raw_t>(impl::cast_to_f8<float,
|
||||
bf8_t,
|
||||
negative_zero_nan,
|
||||
clip,
|
||||
(rm == fp8_rounding_mode::stochastic)>(x, rng));
|
||||
#endif
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float x, constant<rounding> = {})
|
||||
template<fp8_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
|
||||
{
|
||||
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x);
|
||||
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
|
||||
else return uint8_t{0};
|
||||
else return fp8_raw_t{0};
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float x, constant<rounding> = {})
|
||||
template<fp8_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
|
||||
{
|
||||
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x);
|
||||
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
|
||||
else return uint8_t{0};
|
||||
else return bf8_raw_t{0};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t x)
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
@@ -563,11 +574,11 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t x)
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(x);
|
||||
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(fp8_t::bit_cast(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t x)
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
@@ -577,18 +588,18 @@ CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t x)
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
|
||||
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(bf8_t::bit_cast(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant<rounding> = {})
|
||||
template<fp8_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant<rounding>)
|
||||
{
|
||||
return float8_e4m3_t::bit_cast(float_to_fp8_raw(x, constant<rounding>{}));
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant<rounding> = {})
|
||||
template<fp8_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant<rounding>)
|
||||
{
|
||||
return float8_e5m2_t::bit_cast(float_to_bf8_raw(x, constant<rounding>{}));
|
||||
}
|
||||
@@ -604,8 +615,6 @@ CK_TILE_HOST_DEVICE float bf8_to_float(float8_e5m2_t x)
|
||||
}
|
||||
|
||||
// clang-format on
|
||||
using fp8_t = float8_e4m3_t;
|
||||
using bf8_t = float8_e5m2_t;
|
||||
|
||||
template <typename T>
|
||||
struct numeric_utils;
|
||||
|
||||
@@ -76,6 +76,9 @@ struct alignas(2) half_t
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
using fp16_t = half_t;
|
||||
using fp16_raw_t = typename fp16_t::raw_type;
|
||||
|
||||
// conversions
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float_hip(const fp16_hip_t& x)
|
||||
@@ -282,6 +285,4 @@ half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x)))
|
||||
CK_TILE_DEVICE
|
||||
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
|
||||
|
||||
using fp16_t = half_t;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -147,6 +147,9 @@ CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T&
|
||||
return min(max(x, lowerbound), upperbound);
|
||||
}
|
||||
|
||||
// Count the leading zero bits of a 32-bit value (host overload).
// __builtin_clz has an undefined result for an input of 0, so that case is
// handled explicitly and reports 32 (every bit is a leading zero).
// NOTE(review): 32 is believed to match what the device __clz intrinsic
// returns for 0 - confirm against the HIP/CUDA intrinsic documentation.
CK_TILE_HOST inline int clz(uint32_t x) { return x == 0 ? 32 : __builtin_clz(x); }
|
||||
// Count the leading zero bits of a 32-bit value (device overload).
// Forwards directly to the __clz device intrinsic.
CK_TILE_DEVICE inline int clz(uint32_t x)
{
    const int leading_zeros = __clz(x);
    return leading_zeros;
}
|
||||
|
||||
// greatest common divisor, aka highest common factor
|
||||
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
|
||||
{
|
||||
@@ -222,7 +225,7 @@ struct less
|
||||
CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x)
|
||||
{
|
||||
// TODO: x needs to be in [2, 0x7fffffff]; 0, 1, or anything larger than 0x7fffffff will fail to compile
|
||||
return 1 << (32 - __builtin_clz(x - 1));
|
||||
return 1 << (32 - clz(x - 1));
|
||||
}
|
||||
|
||||
template <index_t X>
|
||||
@@ -243,7 +246,7 @@ CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
|
||||
{
|
||||
// TODO: x needs to be in [1, 0x7fffffff]
|
||||
// __builtin_clz will produce unexpected result if x is 0;
|
||||
return 31 - __builtin_clz(x);
|
||||
return 31 - clz(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
|
||||
|
||||
@@ -3,10 +3,22 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <stdint.h>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<Y> type_convert(const X& x)
|
||||
{
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
#if 0
|
||||
// Convert X to Y, both X and Y are non-const data types.
|
||||
template <typename Y,
|
||||
@@ -15,11 +27,9 @@ template <typename Y,
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
// TODO: const version never called, we may never need
|
||||
// Convert X to Y, either X or Y is a const data type.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
@@ -28,18 +38,29 @@ CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
using NonConstY = std::remove_const_t<Y>;
|
||||
using NonConstX = std::remove_const_t<X>;
|
||||
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
|
||||
using non_const_y = std::remove_const_t<Y>;
|
||||
using non_const_x = std::remove_const_t<X>;
|
||||
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
|
||||
}
|
||||
#else
|
||||
// compatible way to call conversion operator and constructor of each custom data type
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
|
||||
template <> \
|
||||
inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return stype_##_to_##dtype_(x); \
|
||||
}
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, fp16_t)
|
||||
CK_TILE_TYPE_CONVERT(float, bf16_t)
|
||||
CK_TILE_TYPE_CONVERT(float, fp8_t)
|
||||
CK_TILE_TYPE_CONVERT(float, bf8_t)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(fp16_t, float)
|
||||
CK_TILE_TYPE_CONVERT(bf16_t, float)
|
||||
CK_TILE_TYPE_CONVERT(fp8_t, float)
|
||||
CK_TILE_TYPE_CONVERT(bf8_t, float)
|
||||
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -10,295 +10,112 @@
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: the whole content of this file should be considered deprecated!
|
||||
// we name this ext_vector on purpose, because the clang ext_vector_type extension only accepts literal
|
||||
// basic types to construct an ext_vector_type. You must be very careful using this, or you will have a lot
|
||||
// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
|
||||
// have compiler error
|
||||
namespace impl {
|
||||
template <typename T_, index_t N_>
|
||||
struct vector_type
|
||||
struct ext_vector
|
||||
{
|
||||
static constexpr index_t N = N_;
|
||||
using value_type = T_;
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
|
||||
CK_HOST_DEVICE constexpr vector_type()
|
||||
{
|
||||
for(auto i = 0; i < N; i++)
|
||||
data[i] = static_cast<value_type>(0);
|
||||
}
|
||||
CK_HOST_DEVICE constexpr vector_type(type v)
|
||||
{
|
||||
auto& r = reinterpret_cast<const array<value_type, N>&>(v);
|
||||
for(auto i = 0; i < N; i++)
|
||||
data[i] = r.get(i);
|
||||
}
|
||||
|
||||
value_type data[N];
|
||||
CK_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_HOST_DEVICE auto& get() { return data; }
|
||||
CK_HOST_DEVICE const auto& get() const { return data; }
|
||||
CK_HOST_DEVICE auto& get(index_t i) { return data[i]; }
|
||||
CK_HOST_DEVICE const auto& get(index_t i) const { return data[i]; }
|
||||
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE auto& operator[](number<I>)
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE const auto& operator[](number<I>) const
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE auto& operator()(number<I>)
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
|
||||
CK_HOST_DEVICE auto& at(index_t i) { return data[i]; }
|
||||
CK_HOST_DEVICE const auto& at(index_t i) const { return data[i]; }
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE auto& at()
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE const auto& at() const
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE auto& at(number<I>)
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
template <index_t I>
|
||||
CK_HOST_DEVICE const auto& at(number<I>) const
|
||||
{
|
||||
return data[I];
|
||||
}
|
||||
|
||||
#define _VT_COMMON_AS() \
|
||||
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
|
||||
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
|
||||
|
||||
template <typename Tx>
|
||||
CK_HOST_DEVICE auto& get_as()
|
||||
{
|
||||
_VT_COMMON_AS();
|
||||
return reinterpret_cast<array<Tx, vx>&>(data);
|
||||
}
|
||||
template <typename Tx>
|
||||
CK_HOST_DEVICE const auto& get_as() const
|
||||
{
|
||||
_VT_COMMON_AS();
|
||||
return reinterpret_cast<const array<Tx, vx>&>(data);
|
||||
}
|
||||
template <typename Tx>
|
||||
CK_HOST_DEVICE auto& get_as(index_t i)
|
||||
{
|
||||
_VT_COMMON_AS();
|
||||
return reinterpret_cast<array<Tx, vx>&>(data).get(i);
|
||||
}
|
||||
template <typename Tx>
|
||||
CK_HOST_DEVICE const auto& get_as(index_t i) const
|
||||
{
|
||||
_VT_COMMON_AS();
|
||||
return reinterpret_cast<const array<Tx, vx>&>(data).get(i);
|
||||
}
|
||||
#undef _VT_COMMON_AS
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct vector_type_maker
|
||||
using ext_vector_t = typename impl::ext_vector<T, N>::type;
|
||||
|
||||
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
|
||||
// ... unless we have other vector_traits specialization
|
||||
template <typename T>
|
||||
struct vector_traits
|
||||
{
|
||||
using type = vector_type<T, N>;
|
||||
};
|
||||
|
||||
template <typename T, index_t N0, index_t N1>
|
||||
struct vector_type_maker<T __attribute__((ext_vector_type(N1))), N0>
|
||||
{
|
||||
using type = vector_type<T, N0 * N1>;
|
||||
};
|
||||
|
||||
template <typename T, index_t N0, index_t N1>
|
||||
struct vector_type_maker<vector_type<T, N1>, N0>
|
||||
{
|
||||
using type = vector_type<T, N0 * N1>;
|
||||
using scalar_type = remove_cvref_t<T>;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
// specialization for ext_vector_type()
|
||||
template <typename T, index_t N>
|
||||
using vector_type_maker_t = typename vector_type_maker<T, N>::type;
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_HOST_DEVICE constexpr auto make_vector_type(number<N>)
|
||||
struct vector_traits<T __attribute__((ext_vector_type(N)))>
|
||||
{
|
||||
return typename vector_type_maker<T, N>::type{};
|
||||
}
|
||||
|
||||
// scalar_type
|
||||
template <typename TV>
|
||||
struct scalar_type;
|
||||
|
||||
// is_scalar_type
|
||||
template <typename TV>
|
||||
struct is_scalar_type
|
||||
{
|
||||
static constexpr bool value = (scalar_type<remove_cvref_t<TV>>::vector_size == 1);
|
||||
};
|
||||
|
||||
// has_same_scalar_type
|
||||
template <typename X, typename Y>
|
||||
using has_same_scalar_type = is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<Y>>::type>;
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<T __attribute__((ext_vector_type(N)))>
|
||||
{
|
||||
using type = T;
|
||||
using scalar_type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<vector_type<T, N>>
|
||||
{
|
||||
using type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
//
|
||||
template <>
|
||||
struct scalar_type<double>
|
||||
{
|
||||
using type = double;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<float>
|
||||
{
|
||||
using type = float;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<half_t>
|
||||
{
|
||||
using type = half_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bhalf_t>
|
||||
{
|
||||
using type = bhalf_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<int64_t>
|
||||
{
|
||||
using type = int64_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<int32_t>
|
||||
{
|
||||
using type = int32_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<int8_t>
|
||||
{
|
||||
using type = int8_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
struct scalar_type<int4_t>
|
||||
{
|
||||
using type = int4_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct scalar_type<fp8_t>
|
||||
{
|
||||
using type = fp8_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bf8_t>
|
||||
{
|
||||
using type = bf8_t;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
// below are some pre-defines of ext_vector_type
|
||||
// attention! two vector types could be exactly the same type
|
||||
// fp64
|
||||
using double2_t = typename vector_type<double, 2>::type;
|
||||
using double4_t = typename vector_type<double, 4>::type;
|
||||
using fp64x2_t = double __attribute__((ext_vector_type(2)));
|
||||
using fp64x4_t = double __attribute__((ext_vector_type(4)));
|
||||
|
||||
// fp32
|
||||
using float2_t = typename vector_type<float, 2>::type;
|
||||
using float4_t = typename vector_type<float, 4>::type;
|
||||
using float8_t = typename vector_type<float, 8>::type;
|
||||
using float16_t = typename vector_type<float, 16>::type;
|
||||
using float32_t = typename vector_type<float, 32>::type;
|
||||
using float64_t = typename vector_type<float, 64>::type;
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp32x4_t = float __attribute__((ext_vector_type(4)));
|
||||
using fp32x8_t = float __attribute__((ext_vector_type(8)));
|
||||
using fp32x16_t = float __attribute__((ext_vector_type(16)));
|
||||
using fp32x32_t = float __attribute__((ext_vector_type(32)));
|
||||
using fp32x64_t = float __attribute__((ext_vector_type(64)));
|
||||
|
||||
// fp16
|
||||
using half2_t = typename vector_type<half_t, 2>::type;
|
||||
using half4_t = typename vector_type<half_t, 4>::type;
|
||||
using half8_t = typename vector_type<half_t, 8>::type;
|
||||
using half16_t = typename vector_type<half_t, 16>::type;
|
||||
using half32_t = typename vector_type<half_t, 32>::type;
|
||||
using half64_t = typename vector_type<half_t, 64>::type;
|
||||
using fp16x2_t = fp16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using fp16x4_t = fp16_raw_t __attribute__((ext_vector_type(4)));
|
||||
using fp16x8_t = fp16_raw_t __attribute__((ext_vector_type(8)));
|
||||
using fp16x16_t = fp16_raw_t __attribute__((ext_vector_type(16)));
|
||||
using fp16x32_t = fp16_raw_t __attribute__((ext_vector_type(32)));
|
||||
using fp16x64_t = fp16_raw_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bfp16
|
||||
using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
|
||||
using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
|
||||
using bhalf8_t = typename vector_type<bhalf_t, 8>::type;
|
||||
using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
|
||||
using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
|
||||
using bhalf64_t = typename vector_type<bhalf_t, 64>::type;
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
|
||||
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
|
||||
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
|
||||
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i32
|
||||
using int32x2_t = typename vector_type<int32_t, 2>::type;
|
||||
using int32x4_t = typename vector_type<int32_t, 4>::type;
|
||||
using int32x8_t = typename vector_type<int32_t, 8>::type;
|
||||
using int32x16_t = typename vector_type<int32_t, 16>::type;
|
||||
using int32x32_t = typename vector_type<int32_t, 32>::type;
|
||||
using int32x64_t = typename vector_type<int32_t, 64>::type;
|
||||
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
|
||||
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
|
||||
using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
|
||||
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
|
||||
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
|
||||
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i16
|
||||
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
|
||||
using int16x4_t = int16_t __attribute__((ext_vector_type(4)));
|
||||
using int16x8_t = int16_t __attribute__((ext_vector_type(8)));
|
||||
using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
|
||||
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
|
||||
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i8
|
||||
using int8x2_t = typename vector_type<int8_t, 2>::type;
|
||||
using int8x4_t = typename vector_type<int8_t, 4>::type;
|
||||
using int8x8_t = typename vector_type<int8_t, 8>::type;
|
||||
using int8x16_t = typename vector_type<int8_t, 16>::type;
|
||||
using int8x32_t = typename vector_type<int8_t, 32>::type;
|
||||
using int8x64_t = typename vector_type<int8_t, 64>::type;
|
||||
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
|
||||
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
|
||||
using int8x8_t = int8_t __attribute((ext_vector_type(8)));
|
||||
using int8x16_t = int8_t __attribute((ext_vector_type(16)));
|
||||
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
|
||||
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
|
||||
|
||||
// f8
|
||||
using fp8x2_t = typename vector_type<fp8_t, 2>::type;
|
||||
using fp8x4_t = typename vector_type<fp8_t, 4>::type;
|
||||
using fp8x8_t = typename vector_type<fp8_t, 8>::type;
|
||||
using fp8x16_t = typename vector_type<fp8_t, 16>::type;
|
||||
using fp8x32_t = typename vector_type<fp8_t, 32>::type;
|
||||
using fp8x64_t = typename vector_type<fp8_t, 64>::type;
|
||||
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
|
||||
using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
|
||||
using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));
|
||||
using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16)));
|
||||
using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32)));
|
||||
using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
|
||||
|
||||
// bf8
|
||||
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
|
||||
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
|
||||
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
|
||||
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
|
||||
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
|
||||
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
|
||||
using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
|
||||
using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
|
||||
using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
|
||||
using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16)));
|
||||
using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/amd_address_space.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -22,13 +23,13 @@ namespace ck_tile {
|
||||
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
|
||||
// transforms of tensor_view/Tensor
|
||||
// FIXME: amd_buffer_coherence_enum is only meaningful for buffer addressing. Need to split
|
||||
// BufferView definition for different memory address space (Global/GenericLds/Vgpr)
|
||||
// buffer_view definition for different memory address space (Global/GenericLds/Vgpr)
|
||||
template <address_space_enum BufferAddressSpace,
|
||||
typename T,
|
||||
typename BufferSizeType,
|
||||
bool InvalidElementUseNumericalZeroValue,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default>
|
||||
struct BufferView;
|
||||
struct buffer_view;
|
||||
|
||||
// Address Space: generic
|
||||
// T may be scalar or vector
|
||||
@@ -82,17 +83,18 @@ struct buffer_view<address_space_enum::generic,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -123,19 +125,20 @@ struct buffer_view<address_space_enum::generic,
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <InMemoryDataOperationEnum Op,
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == InMemoryDataOperationEnum::set)
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove InMemoryDataOperationEnum::Add
|
||||
else if constexpr(Op == InMemoryDataOperationEnum::Add)
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, is_valid_element);
|
||||
this->template set<X>(i, is_valid_element, x + tmp);
|
||||
@@ -144,15 +147,16 @@ struct buffer_view<address_space_enum::generic,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -253,17 +257,18 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -326,16 +331,17 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
|
||||
{
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -348,15 +354,16 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
|
||||
{
|
||||
// X is vector of T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -368,27 +375,28 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <InMemoryDataOperationEnum Op,
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == InMemoryDataOperationEnum::set)
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == InMemoryDataOperationEnum::atomic_add)
|
||||
else if constexpr(Op == memory_operation_enum::atomic_add)
|
||||
{
|
||||
this->template atomic_add<X>(i, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == InMemoryDataOperationEnum::atomic_max)
|
||||
else if constexpr(Op == memory_operation_enum::atomic_max)
|
||||
{
|
||||
this->template atomic_max<X>(i, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove InMemoryDataOperationEnum::Add
|
||||
else if constexpr(Op == InMemoryDataOperationEnum::Add)
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, is_valid_element);
|
||||
this->template set<X>(i, is_valid_element, x + tmp);
|
||||
@@ -399,16 +407,17 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -443,16 +452,17 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -463,17 +473,18 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
|
||||
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -482,15 +493,16 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
|
||||
is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
|
||||
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
|
||||
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
|
||||
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
@@ -512,15 +524,16 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -528,8 +541,8 @@ struct buffer_view<address_space_enum::global,
|
||||
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
|
||||
|
||||
#if CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, double>;
|
||||
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
|
||||
bool constexpr use_amd_buffer_addressing = std::is_same_v<remove_cvref_t<scalar_t>, double>;
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
@@ -628,17 +641,18 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -669,19 +683,20 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <InMemoryDataOperationEnum Op,
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == InMemoryDataOperationEnum::set)
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove InMemoryDataOperationEnum::Add
|
||||
else if constexpr(Op == InMemoryDataOperationEnum::Add)
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, is_valid_element);
|
||||
this->template set<X>(i, is_valid_element, x + tmp);
|
||||
@@ -690,15 +705,16 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -709,7 +725,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
bool constexpr workaround_int8_ds_write_issue = false;
|
||||
#endif
|
||||
|
||||
if constexpr(is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
|
||||
if constexpr(std::is_same<typename vector_traits<remove_cvref_t<T>>::scalar_type,
|
||||
int8_t>::value &&
|
||||
workaround_int8_ds_write_issue)
|
||||
{
|
||||
if(is_valid_element)
|
||||
@@ -718,83 +735,83 @@ struct buffer_view<address_space_enum::lds,
|
||||
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
|
||||
// ds_write_b128
|
||||
// TODO: remove this after compiler fix
|
||||
static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x2_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x4_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x8_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x16_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8x4_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x4_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8x8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x8_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8x16_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x16_t>::value),
|
||||
static_assert((std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x16_t>::value),
|
||||
"wrong! not implemented for this combination, please add "
|
||||
"implementation");
|
||||
|
||||
if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8_t>::value)
|
||||
if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int8_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int8_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x2_t>::value)
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x2_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int16_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int16_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x4_t>::value)
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x4_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x8_t>::value)
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x8_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x2_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x16_t>::value)
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x16_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x4_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x4_t>::value)
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x4_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x8_t>::value)
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x8_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x2_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x16_t>::value)
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x16_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
@@ -899,17 +916,18 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -940,19 +958,20 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <InMemoryDataOperationEnum Op,
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == InMemoryDataOperationEnum::set)
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove InMemoryDataOperationEnum::Add
|
||||
else if constexpr(Op == InMemoryDataOperationEnum::Add)
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, is_valid_element);
|
||||
this->template set<X>(i, is_valid_element, x + tmp);
|
||||
@@ -961,15 +980,16 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
typename std::enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -1029,7 +1049,7 @@ template <address_space_enum BufferAddressSpace,
|
||||
typename T,
|
||||
typename BufferSizeType,
|
||||
typename X,
|
||||
typename std::enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
|
||||
typename std::enable_if<std::is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/null_tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -65,13 +68,13 @@ CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
|
||||
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
CK_TILE_DEVICE auto load_tile(const NullTileWindow<WindowLengths>&)
|
||||
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
|
||||
{
|
||||
return NullTensor{};
|
||||
return null_tensor{};
|
||||
}
|
||||
|
||||
template <typename T, typename WindowLengths>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const NullTileWindow<WindowLengths>&)
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_view.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -8,11 +8,13 @@
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace detail {
|
||||
@@ -74,8 +76,8 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
constexpr index_t num_vec_in = vec_length_out;
|
||||
constexpr index_t num_vec_out = vec_length_in;
|
||||
|
||||
using InVec = vector_type<DataType, vec_length_in>;
|
||||
using OutVec = vector_type<DataType, vec_length_out>;
|
||||
using InVec = array<DataType, vec_length_in>;
|
||||
using OutVec = array<DataType, vec_length_out>;
|
||||
|
||||
using InVecType = typename InVec::type;
|
||||
using OutVecType = typename OutVec::type;
|
||||
@@ -114,7 +116,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
|
||||
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
|
||||
|
||||
in_vectors(i).template AsType<InVecType>()(I0) =
|
||||
in_vectors(i).template get_as<InVecType>()(I0) =
|
||||
in_tensor.get_thread_buffer().template get_as<InVecType>(number<in_offset>{});
|
||||
});
|
||||
|
||||
@@ -134,7 +136,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
|
||||
out_tensor.get_thread_buffer().template set_as<OutVecType>(
|
||||
number<out_offset / sizeof(OutVecType)>{},
|
||||
out_vectors[i].template AsType<OutVecType>()[I0]);
|
||||
out_vectors[i].template get_as<OutVecType>()[I0]);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution
|
||||
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
|
||||
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
|
||||
|
||||
static_assert(is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
|
||||
static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
|
||||
|
||||
dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
@@ -48,7 +48,7 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
|
||||
@@ -717,7 +717,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
|
||||
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
|
||||
\
|
||||
constexpr auto trans = [&encoded_transforms, &num_transform]() { \
|
||||
constexpr auto trans = [&encoded_transforms]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) constexpr { \
|
||||
constexpr auto name = encoded_transforms[i].template at<0>(); \
|
||||
@@ -725,7 +725,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
|
||||
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
|
||||
\
|
||||
STATIC_ASSERT(name == cood_transform_enum::PassThrough || \
|
||||
STATIC_ASSERT(name == cood_transform_enum::pass_through || \
|
||||
name == cood_transform_enum::pad || \
|
||||
name == cood_transform_enum::embed || \
|
||||
name == cood_transform_enum::merge || \
|
||||
@@ -733,7 +733,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
name == cood_transform_enum::replicate, \
|
||||
""); \
|
||||
\
|
||||
if constexpr(name == cood_transform_enum::PassThrough) \
|
||||
if constexpr(name == cood_transform_enum::pass_through) \
|
||||
{ \
|
||||
index_t pos = 0; \
|
||||
auto low_len = meta_data.template pop<index_t>(pos); \
|
||||
@@ -841,7 +841,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
|
||||
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
|
||||
\
|
||||
constexpr auto trans = [&encoded_transforms, &num_transform]() { \
|
||||
constexpr auto trans = [&encoded_transforms]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) constexpr { \
|
||||
constexpr auto name = encoded_transforms[i].template at<0>(); \
|
||||
@@ -849,7 +849,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
|
||||
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
|
||||
\
|
||||
STATIC_ASSERT(name == cood_transform_enum::PassThrough || \
|
||||
STATIC_ASSERT(name == cood_transform_enum::pass_through || \
|
||||
name == cood_transform_enum::pad || \
|
||||
name == cood_transform_enum::embed || \
|
||||
name == cood_transform_enum::merge || \
|
||||
@@ -857,7 +857,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
name == cood_transform_enum::replicate, \
|
||||
""); \
|
||||
\
|
||||
if constexpr(name == cood_transform_enum::PassThrough) \
|
||||
if constexpr(name == cood_transform_enum::pass_through) \
|
||||
{ \
|
||||
constexpr index_t low_len = meta_data.template get<index_t>(0); \
|
||||
\
|
||||
@@ -912,7 +912,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
number<num_transform>{}); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() { \
|
||||
constexpr auto low_dim_idss = [&encoded_transforms]() { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) { \
|
||||
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
|
||||
@@ -923,7 +923,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&..
|
||||
number<num_transform>()); \
|
||||
}(); \
|
||||
\
|
||||
constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] { \
|
||||
constexpr auto up_dim_idss = [&encoded_transforms] { \
|
||||
return generate_tuple( \
|
||||
[&encoded_transforms](auto i) { \
|
||||
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
|
||||
|
||||
@@ -299,7 +299,7 @@ template <typename... Lengths,
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
|
||||
const tuple<Strides...>& strides,
|
||||
const offset& offset,
|
||||
const offset& os,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
|
||||
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
|
||||
{
|
||||
@@ -307,7 +307,7 @@ make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
|
||||
const auto element_space_size = detail::calculate_element_space_size_impl(
|
||||
lengths, strides, number<0>{}, long_number<1>{});
|
||||
|
||||
const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));
|
||||
const auto transforms = make_tuple(make_offset_transform(element_space_size, os));
|
||||
|
||||
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
|
||||
|
||||
@@ -383,12 +383,12 @@ make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
|
||||
|
||||
template <typename... Lengths,
|
||||
typename... Strides,
|
||||
typename offset,
|
||||
typename Offset,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1,
|
||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset(
|
||||
const tuple<Lengths...>& lengths,
|
||||
const offset& offset,
|
||||
const Offset& offset,
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
|
||||
{
|
||||
const auto desc_0 = [&]() {
|
||||
|
||||
@@ -3,12 +3,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -16,15 +19,16 @@ namespace ck_tile {
|
||||
template <typename BufferView_, typename TensorDesc_>
|
||||
struct tensor_view
|
||||
{
|
||||
using BufferView = remove_reference_t<BufferView_>;
|
||||
using DataType = typename BufferView::type;
|
||||
using buffer_view = remove_reference_t<BufferView_>;
|
||||
using DataType = typename buffer_view::type;
|
||||
using TensorDesc = remove_cvref_t<TensorDesc_>;
|
||||
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
|
||||
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_view(const BufferView& buffer_view, const TensorDesc& desc)
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_view(const buffer_view& buffer_view,
|
||||
const TensorDesc& desc)
|
||||
: buf_{buffer_view}, desc_{desc}
|
||||
{
|
||||
}
|
||||
@@ -58,12 +62,12 @@ struct tensor_view
|
||||
#endif
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <
|
||||
typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_vectorized_elements(const TensorCoord& coord,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
@@ -76,12 +80,12 @@ struct tensor_view
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <
|
||||
typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
@@ -93,11 +97,11 @@ struct tensor_view
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
|
||||
}
|
||||
|
||||
template <
|
||||
typename X,
|
||||
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord) const
|
||||
{
|
||||
@@ -106,12 +110,12 @@ struct tensor_view
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <
|
||||
typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
|
||||
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
@@ -121,12 +125,12 @@ struct tensor_view
|
||||
x);
|
||||
}
|
||||
|
||||
template <
|
||||
typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<is_same_v<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<DataType>>::type>,
|
||||
bool>::type = false>
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(
|
||||
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
@@ -153,7 +157,7 @@ struct tensor_view
|
||||
}
|
||||
|
||||
// member
|
||||
BufferView buf_;
|
||||
buffer_view buf_;
|
||||
TensorDesc desc_;
|
||||
};
|
||||
|
||||
@@ -162,7 +166,7 @@ struct null_tensor_view
|
||||
{
|
||||
};
|
||||
|
||||
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
typename DataType,
|
||||
typename... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
|
||||
@@ -173,7 +177,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
typename... Strides,
|
||||
@@ -197,7 +201,7 @@ make_naive_tensor_view(DataType* p,
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <AddressSpaceEnum BufferAddressSpace = AddressSpaceEnum::Generic,
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
typename DataType,
|
||||
typename... Lengths,
|
||||
index_t GuaranteedLastDimensionVectorLength = -1>
|
||||
@@ -228,19 +232,19 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& ol
|
||||
NewLowerDimensionOldVisibleIdss{},
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
return tensor_view<typename OldTensorView::BufferView, remove_cvref_t<decltype(new_desc)>>{
|
||||
return tensor_view<typename OldTensorView::buffer_view, remove_cvref_t<decltype(new_desc)>>{
|
||||
old_tensor_view.buf_, new_desc};
|
||||
}
|
||||
|
||||
template <typename tensor_view,
|
||||
template <typename TensorView,
|
||||
typename TileLengths, // tuple<...>
|
||||
typename DoPads> // sequence<bool, bool, ...>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
pad_tensor_view(const tensor_view& tensor_view, const TileLengths& tile_lengths, DoPads)
|
||||
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
|
||||
{
|
||||
constexpr index_t num_dim = DoPads::size();
|
||||
|
||||
static_assert(num_dim == TileLengths::size() && num_dim == tensor_view::get_num_of_dimension(),
|
||||
static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),
|
||||
"wrong! inconsistent # of dimensions");
|
||||
|
||||
// transforms
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
@@ -293,7 +295,9 @@ CK_TILE_HOST_DEVICE constexpr auto
|
||||
&hidden_dim_cnt,
|
||||
&rh_major_minor_to_hidden_ids,
|
||||
&rh_major_minor_to_hidden_lengths](auto idim_x) {
|
||||
constexpr auto h_minor_lengths = tuple_element_t<idim_x, HsLengthss>{};
|
||||
// typename HsLengthss::base{}.foo();
|
||||
constexpr auto h_minor_lengths = HsLengthss{}.get(idim_x); //std::tuple_element_t<idim_x, HsLengthss>{};
|
||||
// constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
|
||||
|
||||
constexpr index_t ndim_h_minor = h_minor_lengths.size();
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
@@ -18,7 +19,7 @@ namespace ck_tile {
|
||||
template <typename InOutElementFunc,
|
||||
typename... InOutDstrTensors,
|
||||
typename = std::enable_if_t<std::conjunction_v<
|
||||
std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, NullTensor>>...>>>
|
||||
std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
|
||||
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
|
||||
InOutDstrTensors&... inout_dstr_tensors)
|
||||
{
|
||||
@@ -26,7 +27,7 @@ CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element
|
||||
// static_assert(xxx);
|
||||
|
||||
constexpr index_t thread_buffer_size =
|
||||
type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();
|
||||
__type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();
|
||||
|
||||
static_for<0, thread_buffer_size, 1>{}(
|
||||
[&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
|
||||
@@ -35,7 +36,7 @@ CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element
|
||||
template <typename InElementFunc,
|
||||
typename... InDstrTensors,
|
||||
typename = std::enable_if_t<
|
||||
std::conjunction_v<std::negation<std::is_same<InDstrTensors, NullTensor>>...>>>
|
||||
std::conjunction_v<std::negation<std::is_same<InDstrTensors, null_tensor>>...>>>
|
||||
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
|
||||
const InDstrTensors&... in_dstr_tensors)
|
||||
{
|
||||
@@ -43,10 +44,10 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
|
||||
|
||||
// TODO: make sure all distributed tensors have same lengths and distribution
|
||||
// static_assert(xxx);
|
||||
constexpr auto in_tile_dstr = type_pack_element<0, InDstrTensors...>::get_tile_distribution();
|
||||
constexpr auto in_tile_dstr = __type_pack_element<0, InDstrTensors...>::get_tile_distribution();
|
||||
|
||||
constexpr index_t thread_buffer_size =
|
||||
type_pack_element<0, InDstrTensors...>::get_thread_buffer_size();
|
||||
__type_pack_element<0, InDstrTensors...>::get_thread_buffer_size();
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
|
||||
@@ -69,7 +70,7 @@ CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void set_tile(NullTensor&, const T&)
|
||||
CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -82,7 +83,7 @@ CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
|
||||
DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
|
||||
if constexpr(v == 0 && tensor_bytes % 4 == 0)
|
||||
{
|
||||
using dvec_t = static_buffer_c<index_t, tensor_bytes / 4>;
|
||||
using dvec_t = array<index_t, tensor_bytes / 4>;
|
||||
auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
|
||||
for(auto i = 0; i < tensor.size(); i++)
|
||||
tensor.get(i) = v;
|
||||
@@ -96,7 +97,7 @@ CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
|
||||
}
|
||||
|
||||
template <index_t v>
|
||||
CK_TILE_DEVICE void set_tile(NullTensor&, number<v>)
|
||||
CK_TILE_DEVICE void set_tile(null_tensor&, number<v>)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -139,7 +140,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors)
|
||||
false); // false -> WORD0
|
||||
|
||||
constexpr int32_t m0 = 0x05040100;
|
||||
using vec_t = typename vector_type<OutDataType, 4>::type;
|
||||
using vec_t = array<OutDataType, 4>;
|
||||
|
||||
vec_t d = bit_cast<vec_t>(__builtin_amdgcn_perm(y, x, m0));
|
||||
out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
|
||||
@@ -157,9 +158,9 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors)
|
||||
template <typename DstType, typename SrcDstrTensors>
|
||||
CK_TILE_DEVICE auto cast_tile(const SrcDstrTensors& src_tensor)
|
||||
{
|
||||
if constexpr((ck_tile::is_same_v<DstType, fp8_t> ||
|
||||
ck_tile::is_same_v<DstType, bf8_t>)&&ck_tile::
|
||||
is_same_v<typename SrcDstrTensors::DataType, float> &&
|
||||
if constexpr((std::is_same_v<DstType, fp8_t> ||
|
||||
std::is_same_v<DstType, bf8_t>)&&std::is_same_v<typename SrcDstrTensors::DataType,
|
||||
float> &&
|
||||
(SrcDstrTensors::get_thread_buffer_size() % 4 == 0))
|
||||
{
|
||||
return cast_tile_pk_fp8x4<DstType, SrcDstrTensors>(src_tensor);
|
||||
@@ -169,23 +170,23 @@ CK_TILE_DEVICE auto cast_tile(const SrcDstrTensors& src_tensor)
|
||||
src_tensor);
|
||||
}
|
||||
|
||||
// no-op function for NullTensor arguments
|
||||
// no-op function for null_tensor arguments
|
||||
template <typename InOutElementFunc,
|
||||
typename... MaybeNullTensor,
|
||||
typename = std::enable_if_t<
|
||||
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, NullTensor>...>>>
|
||||
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
|
||||
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
|
||||
{
|
||||
}
|
||||
|
||||
// no-op function for NullTensor arguments
|
||||
// no-op function for null_tensor arguments
|
||||
template <typename InElementFunc,
|
||||
typename... MaybeNullTensor,
|
||||
typename = std::enable_if_t<
|
||||
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, NullTensor>...>>>
|
||||
std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
|
||||
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
|
||||
{
|
||||
return NullTensor{};
|
||||
return null_tensor{};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -3,12 +3,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
@@ -85,8 +88,9 @@ struct tile_window_with_static_distribution
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
|
||||
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
|
||||
using vector_t = typename vector_type_t::type;
|
||||
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
|
||||
// using vector_t = typename vector_type_t::type;
|
||||
using vector_t = array<DataType, ScalarPerVector>;
|
||||
|
||||
private:
|
||||
static constexpr auto scalars_per_access_ = [] {
|
||||
@@ -275,9 +279,8 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename vector_type_t::type;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
@@ -300,8 +303,6 @@ struct tile_window_with_static_distribution
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
|
||||
|
||||
const vector_type_t vec{vec_value};
|
||||
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
@@ -315,7 +316,7 @@ struct tile_window_with_static_distribution
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec.template AsType<DataType>()[j];
|
||||
vec_value.template get_as<DataType>()[j];
|
||||
});
|
||||
|
||||
// move thread coordinate
|
||||
@@ -341,16 +342,17 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename vector_type_t::type;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
static constexpr index_t YElementSize =
|
||||
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
|
||||
static_assert(YElementSize % Traits::ScalarPerVector == 0);
|
||||
using vectorized_tbuf = StaticBuffer<address_space_enum::vgpr,
|
||||
vector_t,
|
||||
YElementSize / Traits::ScalarPerVector,
|
||||
true>;
|
||||
using vectorized_tbuf = array<vector_t, YElementSize / Traits::ScalarPerVector>;
|
||||
// StaticBuffer<address_space_enum::vgpr,
|
||||
// vector_t,
|
||||
// YElementSize / Traits::ScalarPerVector,
|
||||
// true>;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
@@ -426,9 +428,9 @@ struct tile_window_with_static_distribution
|
||||
|
||||
using Traits = load_store_traits;
|
||||
|
||||
using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename vector_type_t::type;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
|
||||
|
||||
@@ -468,9 +470,9 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename vector_type_t::type;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
@@ -487,7 +489,8 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// read from distributed tensor
|
||||
vector_type_t vec;
|
||||
// vector_type_t vec;
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
@@ -500,11 +503,11 @@ struct tile_window_with_static_distribution
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
|
||||
vec.template AsType<DataType>()(j) =
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
const vector_t vec_value = vec.template AsType<vector_t>().template at<0>();
|
||||
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
@@ -530,9 +533,9 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename vector_type_t::type;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
static constexpr bool oob_conditional_check = true;
|
||||
@@ -550,7 +553,8 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// read from distributed tensor
|
||||
vector_type_t vec;
|
||||
// vector_type_t vec;
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
@@ -563,11 +567,11 @@ struct tile_window_with_static_distribution
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
|
||||
vec.template AsType<DataType>()(j) =
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
const vector_t vec_value = vec.template AsType<vector_t>().template at<0>();
|
||||
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view()
|
||||
|
||||
@@ -191,4 +191,18 @@ CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
|
||||
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
|
||||
}
|
||||
|
||||
// z = predicate ? x : y
|
||||
template <bool predicate, typename X, typename Y>
|
||||
constexpr auto conditional_expr(X&& x, Y&& y)
|
||||
{
|
||||
if constexpr(predicate)
|
||||
{
|
||||
return std::forward<X>(x);
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::forward<Y>(y);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -9,12 +9,13 @@
|
||||
// clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
|
||||
#define TO_SEQUENCE(a, n) \
|
||||
_Pragma("clang diagnostic push") \
|
||||
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... Is>( \
|
||||
ck_tile::sequence<Is...>) \
|
||||
_Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \
|
||||
[a]<ck_tile::index_t... IDX_IDX_>(ck_tile::sequence<IDX_IDX_...>) \
|
||||
{ \
|
||||
return ck_tile::sequence<a.at(ck_tile::number<Is>{})...>{}; \
|
||||
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
|
||||
} \
|
||||
(make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
|
||||
(ck_tile::make_index_sequence<n>{}); \
|
||||
_Pragma("clang diagnostic pop")
|
||||
|
||||
#else
|
||||
// Macro function
|
||||
|
||||
123
include/ck_tile/core/utility/transpose_vectors.hpp
Normal file
123
include/ck_tile/core/utility/transpose_vectors.hpp
Normal file
@@ -0,0 +1,123 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// S: scalar type (or it can be non-scalar type)
|
||||
// NX: # of vector before transpose
|
||||
// NY: # of vector after transpose
|
||||
// we got [NX, NY] amount of S data to be transposed into [NY, NX] amount of S data
|
||||
template <typename S_, index_t NX, index_t NY>
|
||||
struct transpose_vectors
|
||||
{
|
||||
static constexpr index_t s_per_x = NY;
|
||||
static constexpr index_t s_per_y = NX;
|
||||
|
||||
using S = remove_cvref_t<S_>;
|
||||
|
||||
using VX = array<S, s_per_x>;
|
||||
using VY = array<S, s_per_y>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const array<VX, NX>& vx_tuple, array<VY, NY>& vy_tuple)
|
||||
{
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
constexpr auto I3 = number<3>{};
|
||||
constexpr auto I4 = number<4>{};
|
||||
|
||||
if constexpr(sizeof(S) == 2)
|
||||
{
|
||||
static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");
|
||||
|
||||
using S2 = array<S, 2>; // typename array<S, 2>::type;
|
||||
|
||||
// loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
|
||||
static_for<0, NY, 2>{}([&](auto iy) {
|
||||
static_for<0, NX, 2>{}([&](auto ix) {
|
||||
// 2 16bitx2 data from vx_tuple to be transposed
|
||||
const int32_t x_s2_0 =
|
||||
bit_cast<int32_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
|
||||
const int32_t x_s2_1 =
|
||||
bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);
|
||||
|
||||
constexpr int32_t m0 = 0x05040100;
|
||||
constexpr int32_t m1 = 0x07060302;
|
||||
|
||||
// transpose 2x2 16bit
|
||||
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
|
||||
// -- -- -- -- -- -- -- -- - - - -
|
||||
// index 7 6 5 4 3 2 1 0 33 77 44 88
|
||||
// index is reversed because of little endianness (least significant bits first)
|
||||
const int32_t y_s2_0 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m0);
|
||||
const int32_t y_s2_1 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m1);
|
||||
|
||||
// 2 16bitx2 data after transposed
|
||||
vy_tuple(iy).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_0);
|
||||
vy_tuple(iy + I1).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_1);
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(sizeof(S) == 1)
|
||||
{
|
||||
static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");
|
||||
|
||||
using S4 = array<S, 4>; // typename array<S, 4>::type;
|
||||
|
||||
// loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
|
||||
static_for<0, NY, 4>{}([&](auto iy) {
|
||||
static_for<0, NX, 4>{}([&](auto ix) {
|
||||
// 4 int8x4 data from vx_tuple
|
||||
const int32_t x_s4_0 =
|
||||
bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
|
||||
const int32_t x_s4_1 =
|
||||
bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
|
||||
const int32_t x_s4_2 =
|
||||
bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
|
||||
const int32_t x_s4_3 =
|
||||
bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);
|
||||
|
||||
// transpose
|
||||
int32_t t_s4_0, t_s4_1;
|
||||
int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;
|
||||
|
||||
constexpr int32_t m0 = 0x05010400;
|
||||
constexpr int32_t m1 = 0x05040100;
|
||||
constexpr int32_t m2 = 0x07060302;
|
||||
constexpr int32_t m3 = 0x07030602;
|
||||
|
||||
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
|
||||
// -- -- -- -- -- -- -- -- - - - -
|
||||
// index 7 6 5 4 3 2 1 0 33 77 44 88
|
||||
// index is reversed because of little endianness (least significant bits first)
|
||||
t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
|
||||
t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
|
||||
y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
|
||||
y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
|
||||
t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
|
||||
t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
|
||||
y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
|
||||
y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);
|
||||
|
||||
// 4 int8x4 data from vy_tuple
|
||||
vy_tuple(iy).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_0);
|
||||
vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
|
||||
vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
|
||||
vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "not implemented");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,57 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Convert X to Y, both X and Y are non-const data types.
// Generic fallback: the value is converted with a plain static_cast<Y>(x).
// The enable_if condition removes this overload from consideration when either
// Y or X is const-qualified; reference types are rejected by the static_assert.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
// Convert X to Y, either X or Y is a const data type.
// The const qualifier is stripped from both types, the conversion is delegated
// to type_convert for the unqualified types, and the result is then
// static_cast back to the (possibly const) Y. Reference types are rejected by
// the static_assert.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
|
||||
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
using non_const_y = std::remove_const_t<Y>;
|
||||
using non_const_x = std::remove_const_t<X>;
|
||||
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
|
||||
}
|
||||
|
||||
// Defines an explicit specialization type_convert<dtype_, stype_> (destination
// type first, source type second). The specialization forwards to a function
// whose name is built by token pasting: <stype_>_to_<dtype_>, e.g.
// CK_TILE_TYPE_CONVERT(float, fp16_t) calls fp16_t_to_float(x).
// NOTE(review): those *_to_* helpers are presumably declared by the numeric
// headers included by this file -- confirm.
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
|
||||
template <> \
|
||||
inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return stype_##_to_##dtype_(x); \
|
||||
}
|
||||
|
||||
// fp16_t / bf16_t / fp8_t / bf8_t -> float
CK_TILE_TYPE_CONVERT(float, fp16_t)
|
||||
CK_TILE_TYPE_CONVERT(float, bf16_t)
|
||||
CK_TILE_TYPE_CONVERT(float, fp8_t)
|
||||
CK_TILE_TYPE_CONVERT(float, bf8_t)
|
||||
|
||||
// float -> fp16_t / bf16_t / fp8_t / bf8_t
CK_TILE_TYPE_CONVERT(fp16_t, float)
|
||||
CK_TILE_TYPE_CONVERT(bf16_t, float)
|
||||
CK_TILE_TYPE_CONVERT(fp8_t, float)
|
||||
CK_TILE_TYPE_CONVERT(bf8_t, float)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -69,4 +69,18 @@ struct nonesuch
|
||||
template <template <class...> class Op, class... Args>
|
||||
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
|
||||
|
||||
// FIXME: do we need this anymore?
// Casts the pointer p_x (of pointer type PX) to the pointer type PY using a
// C-style cast. The enable_if condition restricts the template to the case
// where both PY and PX are pointer types. The clang diagnostics
// -Wold-style-cast and -Wcast-align are suppressed only around the cast
// (push/pop), so the warning state of the including code is not changed.
|
||||
template <
|
||||
typename PY,
|
||||
typename PX,
|
||||
typename std::enable_if<std::is_pointer_v<PY> && std::is_pointer_v<PX>, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
|
||||
{
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
#pragma clang diagnostic ignored "-Wcast-align"
|
||||
return (PY)p_x; // NOLINT(old-style-cast, cast-align)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -97,7 +97,7 @@ check_err(const Range& out,
|
||||
template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
|
||||
std::is_same_v<ranges::range_value_t<Range>, bf16_t>,
|
||||
bool>::type
|
||||
check_err(const Range& out,
|
||||
const RefRange& ref,
|
||||
@@ -123,7 +123,7 @@ check_err(const Range& out,
|
||||
bool res{true};
|
||||
int err_count = 0;
|
||||
double err = 0;
|
||||
// TODO: This is a hack. We should have proper specialization for bhalf_t data type.
|
||||
// TODO: This is a hack. We should have proper specialization for bf16_t data type.
|
||||
double max_err = std::numeric_limits<float>::min();
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
@@ -214,7 +214,7 @@ check_err(const Range& out,
|
||||
template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_integral_v<ranges::range_value_t<Range>> &&
|
||||
!std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
|
||||
!std::is_same_v<ranges::range_value_t<Range>, bf16_t>)
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|| std::is_same_v<ranges::range_value_t<Range>, int4_t>
|
||||
#endif
|
||||
|
||||
@@ -147,8 +147,8 @@ float launch_and_time_kernel_with_preprocess(const stream_config& s,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int MaxThreadPerBlock = CK_MAX_THREAD_PER_BLOCK,
|
||||
int MinBlockPerCu = CK_MIN_BLOCK_PER_CU,
|
||||
template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
|
||||
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
|
||||
typename KernelImpl,
|
||||
typename... Args>
|
||||
float launch_kernel(const stream_config& s,
|
||||
|
||||
4
include/ck_tile/ops/common/README.md
Normal file
4
include/ck_tile/ops/common/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
## common
|
||||
This folder is designed not to be included directly by users; e.g. if a user includes `ck_tile/ops/fmha.hpp`, then everything under `common` should be included along with it.
|
||||
|
||||
To achieve this, we duplicate the header include paths under `common` into the other modules under `ops/*` inside remod.py. Internal developers can also include `ck_tile/ops/common.hpp` directly for convenience (and so can external users...).
|
||||
@@ -4,4 +4,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
|
||||
@@ -17,4 +17,5 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -50,7 +51,7 @@ struct FmhaFwdKernel
|
||||
template <typename T> struct t2s;
|
||||
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
|
||||
template <> struct t2s<ck_tile::half_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct t2s<ck_tile::bhalf_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
// clang-format on
|
||||
@@ -79,7 +80,7 @@ struct FmhaFwdKernel
|
||||
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
|
||||
(kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
@@ -407,7 +408,7 @@ struct FmhaFwdKernel
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
@@ -519,7 +520,7 @@ struct FmhaFwdKernel
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
|
||||
@@ -49,12 +49,12 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static constexpr bool kIsFp8 =
|
||||
(is_same_v<QDataType, fp8_t> || is_same_v<QDataType, bf8_t>)&&(
|
||||
is_same_v<KDataType, fp8_t> ||
|
||||
is_same_v<KDataType, bf8_t>)&&(is_same_v<VDataType, fp8_t> ||
|
||||
is_same_v<VDataType, bf8_t>)&&is_same_v<SaccDataType,
|
||||
float> &&
|
||||
is_same_v<OaccDataType, float>;
|
||||
(std::is_same_v<QDataType, fp8_t> || std::is_same_v<QDataType, bf8_t>)&&(
|
||||
std::is_same_v<KDataType, fp8_t> ||
|
||||
std::is_same_v<KDataType, bf8_t>)&&(std::is_same_v<VDataType, fp8_t> ||
|
||||
std::is_same_v<VDataType, bf8_t>)&&std::
|
||||
is_same_v<SaccDataType, float> &&
|
||||
std::is_same_v<OaccDataType, float>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -56,7 +56,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
@@ -127,9 +127,9 @@ struct BlockFmhaPipelineQRKSVS
|
||||
void* smem_ptr) const
|
||||
{
|
||||
static_assert(
|
||||
is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
@@ -442,7 +442,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
@@ -471,8 +471,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
if constexpr(ck_tile::is_same_v<VLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -59,7 +59,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
@@ -138,9 +138,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
void* smem_ptr) const
|
||||
{
|
||||
static_assert(
|
||||
is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
@@ -415,7 +415,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F);
|
||||
// store & prefetch next v, after the max reduction
|
||||
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
@@ -539,8 +539,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
|
||||
|
||||
if constexpr(ck_tile::is_same_v<VLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
|
||||
@@ -56,7 +56,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
@@ -119,9 +119,9 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
void* smem_ptr) const
|
||||
{
|
||||
static_assert(
|
||||
is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
@@ -425,7 +425,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
@@ -453,8 +453,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
if constexpr(ck_tile::is_same_v<VLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
|
||||
@@ -114,9 +114,9 @@ struct BlockFmhaPipelineQSKSVS
|
||||
void* smem_ptr) const
|
||||
{
|
||||
static_assert(
|
||||
is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
@@ -434,7 +434,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
@@ -463,8 +463,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
if constexpr(ck_tile::is_same_v<VLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
|
||||
@@ -4,7 +4,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
// TODO: remove this
|
||||
#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
|
||||
@@ -76,24 +80,24 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
Problem::BlockFmhaShape::kK0>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
if constexpr(is_same_v<typename Problem::QDataType, half_t> &&
|
||||
is_same_v<typename Problem::KDataType, half_t> &&
|
||||
is_same_v<typename Problem::SaccDataType, float>)
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(is_same_v<typename Problem::QDataType, bhalf_t> &&
|
||||
is_same_v<typename Problem::KDataType, bhalf_t> &&
|
||||
is_same_v<typename Problem::SaccDataType, float>)
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(Problem::kIsFp8)
|
||||
{
|
||||
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
|
||||
return warp::WarpGemmImpl<
|
||||
warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
|
||||
return WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType>,
|
||||
2,
|
||||
@@ -201,24 +205,24 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
Problem::BlockFmhaShape::kK0>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
if constexpr(is_same_v<typename Problem::QDataType, half_t> &&
|
||||
is_same_v<typename Problem::KDataType, half_t> &&
|
||||
is_same_v<typename Problem::SaccDataType, float>)
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(is_same_v<typename Problem::QDataType, bhalf_t> &&
|
||||
is_same_v<typename Problem::KDataType, bhalf_t> &&
|
||||
is_same_v<typename Problem::SaccDataType, float>)
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{};
|
||||
}
|
||||
else if constexpr(Problem::kIsFp8)
|
||||
{
|
||||
constexpr index_t swizzle_factor = 4; // TODO: hard coded here
|
||||
return warp::WarpGemmImpl<
|
||||
warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
|
||||
return WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType>,
|
||||
2,
|
||||
@@ -337,7 +341,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
@@ -762,7 +766,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
if constexpr(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1; // P
|
||||
@@ -857,7 +861,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
{
|
||||
// This descriptor only used when V layout is seqlen * hdim
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
static_assert(ck_tile::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
@@ -914,15 +918,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr(Problem::kIsFp8)
|
||||
{
|
||||
return warp::WarpGemmImpl<
|
||||
warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
|
||||
return WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType>,
|
||||
2>>{};
|
||||
// return
|
||||
// warp::WarpGemmImpl<warp::WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
// warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
|
||||
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
|
||||
// Problem::PDataType, typename Problem::VDataType>>>{};
|
||||
}
|
||||
else
|
||||
|
||||
@@ -28,3 +28,4 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -44,9 +44,9 @@ struct BlockGemmARegBGmemCRegV1
|
||||
void* smem_ptr) const
|
||||
{
|
||||
static_assert(
|
||||
is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||
is_same_v<BDataType, remove_cv_t<typename BBlockGmemWindowTmp::DataType>> &&
|
||||
is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockGmemWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensor{}.get_lengths()[number<0>{}];
|
||||
@@ -90,9 +90,10 @@ struct BlockGemmARegBGmemCRegV1
|
||||
const BBlockGmemWindowTmp& b_block_gmem_window_tmp,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
static_assert(is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||
is_same_v<BDataType, remove_cv_t<typename BBlockGmemWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockGmemWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensor{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockGmemWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
|
||||
@@ -28,10 +28,11 @@ struct BlockGemmARegBSmemCRegV1
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
@@ -126,9 +127,9 @@ struct BlockGemmARegBSmemCRegV1
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
@@ -184,9 +185,10 @@ struct BlockGemmARegBSmemCRegV1
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
|
||||
@@ -14,9 +14,9 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr(is_same_v<typename Problem::ADataType, half_t> &&
|
||||
is_same_v<typename Problem::BDataType, half_t> &&
|
||||
is_same_v<typename Problem::CDataType, float>)
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
@@ -43,9 +43,9 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<typename Problem::ADataType, bhalf_t> &&
|
||||
is_same_v<typename Problem::BDataType, bhalf_t> &&
|
||||
is_same_v<typename Problem::CDataType, float>)
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
|
||||
@@ -28,10 +28,11 @@ struct BlockGemmARegBSmemCRegV2
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
@@ -126,9 +127,9 @@ struct BlockGemmARegBSmemCRegV2
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
|
||||
@@ -28,9 +28,9 @@ struct BlockGemmASmemBSmemCRegV1
|
||||
const ABlockWindowTmp& a_block_window_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
|
||||
is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
|
||||
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
|
||||
@@ -14,9 +14,9 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr(is_same_v<typename Problem::ADataType, half_t> &&
|
||||
is_same_v<typename Problem::BDataType, half_t> &&
|
||||
is_same_v<typename Problem::CDataType, float>)
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
@@ -42,9 +42,9 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<typename Problem::ADataType, bhalf_t> &&
|
||||
is_same_v<typename Problem::BDataType, bhalf_t> &&
|
||||
is_same_v<typename Problem::CDataType, float>)
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
|
||||
@@ -47,8 +47,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
|
||||
@@ -47,8 +47,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV2
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -74,8 +75,8 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
using BDataType = typename Impl::BDataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename vector_type_maker<typename Impl::AVecType, kKIter>::type::type;
|
||||
using BVecType = typename vector_type_maker<typename Impl::BVecType, kKIter>::type::type;
|
||||
using AVecType = array<ADataType, Impl::AVecType::size() * kKIter>;
|
||||
using BVecType = array<BDataType, Impl::BVecType::size() * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
@@ -111,33 +112,27 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
|
||||
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
|
||||
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
a_vector.template AsType<typename Impl::AVecType>()[iKIter],
|
||||
b_vector.template AsType<typename Impl::BVecType>()[iKIter]);
|
||||
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
|
||||
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
// c = a * b
|
||||
auto c_vec = Impl{}(a_vector.template AsType<typename Impl::AVecType>()[I0],
|
||||
b_vector.template AsType<typename Impl::BVecType>()[I0]);
|
||||
auto c_vec = Impl{}(a_vec.template get_as<typename Impl::AVecType>()[I0],
|
||||
b_vec.template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
// c += a * b
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
a_vector.template AsType<typename Impl::AVecType>()[iKIter],
|
||||
b_vector.template AsType<typename Impl::BVecType>()[iKIter]);
|
||||
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
return c_vec;
|
||||
@@ -274,8 +269,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename vector_type_maker<typename Impl::BVecType, kKIter>::type::type;
|
||||
using BVecType = typename vector_type_maker<typename Impl::AVecType, kKIter>::type::type;
|
||||
using AVecType = array<ADataType, Impl::AVecType::size() * kKIter>;
|
||||
using BVecType = array<BDataType, Impl::BVecType::size() * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
@@ -311,34 +306,27 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
|
||||
|
||||
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
|
||||
|
||||
// swap A and B, value and type
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
b_vector.template AsType<typename Impl::AVecType>()[iKIter],
|
||||
a_vector.template AsType<typename Impl::BVecType>()[iKIter]);
|
||||
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
|
||||
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
// swap A and B, value and type
|
||||
auto c_vec = Impl{}(b_vector.template AsType<typename Impl::AVecType>()[I0],
|
||||
a_vector.template AsType<typename Impl::BVecType>()[I0]);
|
||||
auto c_vec = Impl{}(a_vec.template get_as<typename Impl::AVecType>()[I0],
|
||||
b_vec.template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
b_vector.template AsType<typename Impl::AVecType>()[iKIter],
|
||||
a_vector.template AsType<typename Impl::BVecType>()[iKIter]);
|
||||
a_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
b_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
return c_vec;
|
||||
@@ -355,8 +343,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename vector_type_maker<typename Impl::BVecType, kKIter>::type::type;
|
||||
using BVecType = typename vector_type_maker<typename Impl::AVecType, kKIter>::type::type;
|
||||
using AVecType = array<ADataType, Impl::AVecType::size() * kKIter>;
|
||||
using BVecType = array<BDataType, Impl::BVecType::size() * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
@@ -418,34 +406,27 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
|
||||
|
||||
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
|
||||
|
||||
// swap A and B, value and type
|
||||
static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
b_vector.template AsType<typename Impl::AVecType>()[iKIter],
|
||||
a_vector.template AsType<typename Impl::BVecType>()[iKIter]);
|
||||
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
const auto a_vector = typename vector_type_maker<AVecType, 1>::type{a_vec};
|
||||
const auto b_vector = typename vector_type_maker<BVecType, 1>::type{b_vec};
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
// swap A and B, value and type
|
||||
auto c_vec = Impl{}(b_vector.template AsType<typename Impl::AVecType>()[I0],
|
||||
a_vector.template AsType<typename Impl::BVecType>()[I0]);
|
||||
auto c_vec = Impl{}(b_vec.template get_as<typename Impl::AVecType>()[I0],
|
||||
a_vec.template get_as<typename Impl::BVecType>()[I0]);
|
||||
|
||||
static_for<1, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
b_vector.template AsType<typename Impl::AVecType>()[iKIter],
|
||||
a_vector.template AsType<typename Impl::BVecType>()[iKIter]);
|
||||
b_vec.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
a_vec.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
});
|
||||
|
||||
return c_vec;
|
||||
|
||||
@@ -10,13 +10,13 @@ namespace ck_tile {
|
||||
// FP16
|
||||
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
{
|
||||
using ADataType = half_t;
|
||||
using BDataType = half_t;
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = typename vector_type<half_t, 4>::type;
|
||||
using BVecType = typename vector_type<half_t, 4>::type;
|
||||
using CVecType = typename vector_type<float, 16>::type;
|
||||
using AVecType = array<fp16_t, 4>;
|
||||
using BVecType = array<fp16_t, 4>;
|
||||
using CVecType = array<float, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
@@ -36,25 +36,37 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
|
||||
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, CVecType{0.f}, 0, 0, 0);
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
fp32x16_t{0.f},
|
||||
0,
|
||||
0,
|
||||
0));
|
||||
}
|
||||
};
|
||||
|
||||
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
{
|
||||
using ADataType = half_t;
|
||||
using BDataType = half_t;
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = typename vector_type<half_t, 4>::type;
|
||||
using BVecType = typename vector_type<half_t, 4>::type;
|
||||
using CVecType = typename vector_type<float, 4>::type;
|
||||
using AVecType = array<fp16_t, 4>;
|
||||
using BVecType = array<fp16_t, 4>;
|
||||
using CVecType = array<float, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
@@ -74,26 +86,38 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
c_vec.template get_as<fp32x4_t>()[number<0>{}] =
|
||||
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
c_vec.template get_as<fp32x4_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, CVecType{0.f}, 0, 0, 0);
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<fp16x4_t>()[number<0>{}],
|
||||
fp32x4_t{0.f},
|
||||
0,
|
||||
0,
|
||||
0));
|
||||
}
|
||||
};
|
||||
|
||||
// Bf16
|
||||
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
{
|
||||
using ADataType = bhalf_t;
|
||||
using BDataType = bhalf_t;
|
||||
using ADataType = bf16_t;
|
||||
using BDataType = bf16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = typename vector_type<bhalf_t, 4>::type;
|
||||
using BVecType = typename vector_type<bhalf_t, 4>::type;
|
||||
using CVecType = typename vector_type<float, 16>::type;
|
||||
using AVecType = array<bf16_t, 4>;
|
||||
using BVecType = array<bf16_t, 4>;
|
||||
using CVecType = array<float, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
@@ -113,25 +137,37 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
|
||||
a_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, CVecType{0.f}, 0, 0, 0);
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
fp32x16_t{0.f},
|
||||
0,
|
||||
0,
|
||||
0));
|
||||
}
|
||||
};
|
||||
|
||||
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
{
|
||||
using ADataType = bhalf_t;
|
||||
using BDataType = bhalf_t;
|
||||
using ADataType = bf16_t;
|
||||
using BDataType = bf16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = typename vector_type<bhalf_t, 4>::type;
|
||||
using BVecType = typename vector_type<bhalf_t, 4>::type;
|
||||
using CVecType = typename vector_type<float, 4>::type;
|
||||
using AVecType = array<bf16_t, 4>;
|
||||
using BVecType = array<bf16_t, 4>;
|
||||
using CVecType = array<float, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
@@ -151,13 +187,25 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
c_vec.template get_as<fp32x4_t>()[number<0>{}] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
a_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
c_vec.template get_as<fp32x4_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, CVecType{0.f}, 0, 0, 0);
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
a_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
b_vec.template get_as<bf16x4_t>()[number<0>{}],
|
||||
fp32x4_t{0.f},
|
||||
0,
|
||||
0,
|
||||
0));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -169,9 +217,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
using BDataType = BType_;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = typename vector_type<ADataType, 8>::type;
|
||||
using BVecType = typename vector_type<BDataType, 8>::type;
|
||||
using CVecType = typename vector_type<CDataType, 16>::type;
|
||||
using AVecType = array<ADataType, 8>;
|
||||
using BVecType = array<BDataType, 8>;
|
||||
using CVecType = array<CDataType, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
@@ -192,27 +240,49 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
if constexpr(is_same_v<ADataType, fp8_t> && is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(is_same_v<ADataType, fp8_t> && is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(is_same_v<ADataType, bf8_t> && is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(is_same_v<ADataType, bf8_t> && is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec),
|
||||
bit_cast<long>(b_vec),
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec),
|
||||
bit_cast<long>(b_vec),
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec),
|
||||
bit_cast<long>(b_vec),
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec),
|
||||
bit_cast<long>(b_vec),
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
vector_type<ADataType, 8> a_(a_vec);
|
||||
vector_type<BDataType, 8> b_(b_vec);
|
||||
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 = type_convert<float>(a_.template AsType<ADataType>()[number<k>{}]);
|
||||
float b_f32 = type_convert<float>(b_.template AsType<BDataType>()[number<k>{}]);
|
||||
float a_f32 = type_convert<float>(a_vec.template get_as<ADataType>()[number<k>{}]);
|
||||
float b_f32 = type_convert<float>(b_vec.template get_as<BDataType>()[number<k>{}]);
|
||||
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
|
||||
c_vec.template get_as<fp32x16_t>()[number<0>{}] = __builtin_amdgcn_mfma_f32_32x32x2f32(
|
||||
a_f32, b_f32, c_vec.template get_as<fp32x16_t>()[number<0>{}], 0, 0, 0);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
@@ -220,18 +290,18 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
if constexpr(is_same_v<ADataType, fp8_t> && is_same_v<BDataType, fp8_t>)
|
||||
return __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0);
|
||||
else if constexpr(is_same_v<ADataType, fp8_t> && is_same_v<BDataType, bf8_t>)
|
||||
return __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0);
|
||||
else if constexpr(is_same_v<ADataType, bf8_t> && is_same_v<BDataType, fp8_t>)
|
||||
return __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0);
|
||||
else if constexpr(is_same_v<ADataType, bf8_t> && is_same_v<BDataType, bf8_t>)
|
||||
return __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0);
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -29,14 +30,14 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
|
||||
|
||||
// bf16
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bhalf_t, ck_tile::bhalf_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
|
||||
|
||||
// fp8
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
|
||||
@@ -33,9 +33,9 @@ struct WarpGemmImpl
|
||||
|
||||
CK_TILE_DEVICE void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const
|
||||
{
|
||||
using AVec = typename vector_type<ADataType, AWarpTensor::get_thread_buffer_size()>::type;
|
||||
using BVec = typename vector_type<BDataType, BWarpTensor::get_thread_buffer_size()>::type;
|
||||
using CVec = typename vector_type<CDataType, CWarpTensor::get_thread_buffer_size()>::type;
|
||||
using AVec = array<ADataType, AWarpTensor::get_thread_buffer_size()>;
|
||||
using BVec = array<BDataType, BWarpTensor::get_thread_buffer_size()>;
|
||||
using CVec = array<CDataType, CWarpTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
@@ -53,9 +53,9 @@ struct WarpGemmImpl
|
||||
{
|
||||
CWarpTensor c;
|
||||
|
||||
using AVec = typename vector_type<ADataType, AWarpTensor::get_thread_buffer_size()>::type;
|
||||
using BVec = typename vector_type<BDataType, BWarpTensor::get_thread_buffer_size()>::type;
|
||||
using CVec = typename vector_type<CDataType, CWarpTensor::get_thread_buffer_size()>::type;
|
||||
using AVec = array<ADataType, AWarpTensor::get_thread_buffer_size()>;
|
||||
using BVec = array<BDataType, BWarpTensor::get_thread_buffer_size()>;
|
||||
using CVec = array<CDataType, CWarpTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
|
||||
@@ -4,3 +4,4 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -188,7 +188,7 @@ CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
|
||||
using InDataType = typename InDistributedTensor_::DataType;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
|
||||
static_assert(is_same_v<InDataType, remove_cvref_t<InDataType_>>, "wrong!");
|
||||
static_assert(std::is_same_v<InDataType, remove_cvref_t<InDataType_>>, "wrong!");
|
||||
|
||||
// declare acc_tensor
|
||||
constexpr auto acc_dstr =
|
||||
|
||||
@@ -2,9 +2,11 @@ import pathlib
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import os
|
||||
import copy
|
||||
|
||||
NS = 'ck_tile'
|
||||
OPS = 'ops'
|
||||
OPS_COMMON = 'common' # common header will be duplicated into ops/* other module
|
||||
|
||||
HEADER_COMMON = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
@@ -54,6 +56,15 @@ class submodule_t:
|
||||
f.write(f'#include \"{header_path}\"\n')
|
||||
f.write('\n')
|
||||
# print(self.m)
|
||||
# restructure common
|
||||
for k, v in self.m.items():
|
||||
if k == OPS and OPS_COMMON in v.keys():
|
||||
common_list = copy.deepcopy(v[OPS_COMMON])
|
||||
# v.pop(OPS_COMMON)
|
||||
for km in v.keys():
|
||||
if km != OPS_COMMON:
|
||||
v[km].extend(common_list)
|
||||
|
||||
for k, v in self.m.items():
|
||||
if k == OPS:
|
||||
for km, kv in v.items():
|
||||
|
||||
Reference in New Issue
Block a user