mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
remove ck_tile example from default cmake target like all/install/check
This commit is contained in:
@@ -49,6 +49,7 @@
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
|
||||
@@ -20,6 +20,8 @@
|
||||
#define CK_TILE_DEVICE_EXTERN
|
||||
#endif
|
||||
|
||||
// minimal arch
|
||||
|
||||
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
|
||||
#endif
|
||||
|
||||
22
include/ck_tile/core/utility/ignore.hpp
Normal file
22
include/ck_tile/core/utility/ignore.hpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// https://en.cppreference.com/w/cpp/utility/tuple/ignore
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
struct ignore_t
|
||||
{
|
||||
template <typename T>
|
||||
constexpr void operator=(T&&) const noexcept
|
||||
{
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
inline constexpr detail::ignore_t ignore;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -36,14 +36,28 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -75,14 +89,28 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -115,14 +143,52 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__)
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__)
|
||||
CVecType c_vec{0.f};
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
return c_vec;
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -154,14 +220,52 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__)
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__)
|
||||
CVecType c_vec{0.f};
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
return c_vec;
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -208,7 +312,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
|
||||
@@ -219,12 +323,17 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
|
||||
});
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
@@ -237,6 +346,24 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
CVecType c_vec{0.f};
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
|
||||
.template get_as<ADataType>()[number<k>{}]);
|
||||
float b_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
|
||||
.template get_as<BDataType>()[number<k>{}]);
|
||||
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
|
||||
});
|
||||
return c_vec;
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user