mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
This commit is contained in:
3066
include/ck_tile/core/arch/amd_buffer_addressing.hpp
Normal file
3066
include/ck_tile/core/arch/amd_buffer_addressing.hpp
Normal file
File diff suppressed because it is too large
Load Diff
2947
include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Normal file
2947
include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Normal file
File diff suppressed because it is too large
Load Diff
124
include/ck_tile/core/arch/amd_buffer_coherence.hpp
Normal file
124
include/ck_tile/core/arch/amd_buffer_coherence.hpp
Normal file
@@ -0,0 +1,124 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// memory coherency bit for buffer store/load instruction
|
||||
// check ISA manual for each GFX target
|
||||
// e.g. for
|
||||
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf,
|
||||
// page 67~68
|
||||
enum struct amd_buffer_coherence_enum
|
||||
{
|
||||
coherence_default = 0, // default value
|
||||
#if defined(__gfx12__)
|
||||
// Temporal hint
|
||||
RT = 0, // regular temporal
|
||||
NT = 1, // non temporal
|
||||
HT = 2, // high priority temporal
|
||||
LU = 3, // last use (load op)
|
||||
WB = 3, // same as HT, overrides WR in far cache (store op)
|
||||
NT_RT = 4, // non temporal for near cache, regular for far cache
|
||||
RT_NT = 5, // regular for near cache, non-temporal for far cache
|
||||
NT_HT = 6, // non temporal for near cache, high priority for far cache
|
||||
NT_WB = 7, // non temporal for near cache, WB for far cache
|
||||
// (store op, reserved for load op)
|
||||
// Scope
|
||||
CU = 0,
|
||||
SE = 8,
|
||||
DEVICE = 16,
|
||||
SYSTEM = 24,
|
||||
// Temporal Hint for CU
|
||||
CU_RT = RT | CU,
|
||||
CU_NT = NT | CU,
|
||||
CU_HT = HT | CU,
|
||||
CU_LU = LU | CU,
|
||||
CU_WB = WB | CU,
|
||||
CU_NT_RT = NT_RT | CU,
|
||||
CU_RT_NT = RT_NT | CU,
|
||||
CU_NT_HT = NT_HT | CU,
|
||||
CU_NT_WB = NT_WB | CU,
|
||||
// Temporal Hint for SE
|
||||
SE_RT = RT | SE,
|
||||
SE_NT = NT | SE,
|
||||
SE_HT = HT | SE,
|
||||
SE_LU = LU | SE,
|
||||
SE_WB = WB | SE,
|
||||
SE_NT_RT = NT_RT | SE,
|
||||
SE_RT_NT = RT_NT | SE,
|
||||
SE_NT_HT = NT_HT | SE,
|
||||
SE_NT_WB = NT_WB | SE,
|
||||
// Temporal Hint for DEVICE
|
||||
DEVICE_RT = RT | DEVICE,
|
||||
DEVICE_NT = NT | DEVICE,
|
||||
DEVICE_HT = HT | DEVICE,
|
||||
DEVICE_LU = LU | DEVICE,
|
||||
DEVICE_WB = WB | DEVICE,
|
||||
DEVICE_NT_RT = NT_RT | DEVICE,
|
||||
DEVICE_RT_NT = RT_NT | DEVICE,
|
||||
DEVICE_NT_HT = NT_HT | DEVICE,
|
||||
DEVICE_NT_WB = NT_WB | DEVICE,
|
||||
// Temporal Hint for SYSTEM
|
||||
SYSTEM_RT = RT | SYSTEM,
|
||||
SYSTEM_NT = NT | SYSTEM,
|
||||
SYSTEM_HT = HT | SYSTEM,
|
||||
SYSTEM_LU = LU | SYSTEM,
|
||||
SYSTEM_WB = WB | SYSTEM,
|
||||
SYSTEM_NT_RT = NT_RT | SYSTEM,
|
||||
SYSTEM_RT_NT = RT_NT | SYSTEM,
|
||||
SYSTEM_NT_HT = NT_HT | SYSTEM,
|
||||
SYSTEM_NT_WB = NT_WB | SYSTEM,
|
||||
|
||||
// GFX942 and GFX950 compatiblity
|
||||
GROUP_NT0 = CU_RT,
|
||||
GROUP_NT1 = CU_NT,
|
||||
DEVICE_NT0 = DEVICE_RT,
|
||||
DEVICE_NT1 = DEVICE_NT,
|
||||
SYSTEM_NT0 = SYSTEM_RT,
|
||||
SYSTEM_NT1 = SYSTEM_NT,
|
||||
// Other archs compatiblity
|
||||
glc = DEVICE_NT,
|
||||
slc = SYSTEM_NT,
|
||||
glc_slc = DEVICE_NT | SYSTEM_NT,
|
||||
|
||||
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
|
||||
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
|
||||
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
|
||||
#elif defined(__gfx942__) || defined(__gfx950__)
|
||||
|
||||
WAVE = 0,
|
||||
GROUP = 1,
|
||||
DEVICE = 16,
|
||||
SYSTEM = 17,
|
||||
NT0 = 0,
|
||||
NT1 = 2,
|
||||
|
||||
WAVE_NT0 = NT0 | WAVE,
|
||||
WAVE_NT1 = NT1 | WAVE,
|
||||
GROUP_NT0 = NT0 | GROUP,
|
||||
GROUP_NT1 = NT1 | GROUP,
|
||||
DEVICE_NT0 = NT0 | DEVICE,
|
||||
DEVICE_NT1 = NT1 | DEVICE,
|
||||
SYSTEM_NT0 = NT0 | SYSTEM,
|
||||
SYSTEM_NT1 = NT1 | SYSTEM,
|
||||
|
||||
// Other archs compatiblity
|
||||
glc = DEVICE_NT1,
|
||||
slc = SYSTEM_NT1,
|
||||
glc_slc = DEVICE_NT1 | SYSTEM_NT1,
|
||||
#else
|
||||
glc = 1,
|
||||
slc = 2,
|
||||
glc_slc = 3,
|
||||
|
||||
// Other archs compatiblity
|
||||
DEVICE_NT0 = 0,
|
||||
SYSTEM_NT0 = 0,
|
||||
DEVICE_NT1 = glc,
|
||||
SYSTEM_NT1 = slc,
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
88
include/ck_tile/core/arch/amd_transpose_load_encoding.hpp
Normal file
88
include/ck_tile/core/arch/amd_transpose_load_encoding.hpp
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this generate wave level tile distribution
|
||||
template <typename T, index_t LaneGroupSize = 16, typename = void>
|
||||
struct LaneGroupTransposeTraits;
|
||||
|
||||
template <typename T, index_t LaneGroupSize>
|
||||
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 2>>
|
||||
{
|
||||
static_assert(LaneGroupSize == 16 || LaneGroupSize == 32 || LaneGroupSize == 64,
|
||||
"LaneGroupSize must be 16, 32, or 64");
|
||||
// before transpose, 4x16
|
||||
static constexpr index_t ksecondDim = 4;
|
||||
static constexpr index_t kleadDim = LaneGroupSize;
|
||||
// after transpose, 16x4
|
||||
static constexpr index_t ksecondDimT = LaneGroupSize;
|
||||
static constexpr index_t kleadDimT = 4;
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 4, 4>>,
|
||||
tuple<sequence<1, 2, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2, 3>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 4>>;
|
||||
};
|
||||
|
||||
template <typename T, index_t LaneGroupSize>
|
||||
struct LaneGroupTransposeTraits<T, LaneGroupSize, std::enable_if_t<sizeof(T) == 1>>
|
||||
{
|
||||
static constexpr index_t ksecondDim = 8;
|
||||
static constexpr index_t kleadDim = LaneGroupSize;
|
||||
|
||||
static constexpr index_t ksecondDimT = LaneGroupSize;
|
||||
static constexpr index_t kleadDimT = 8;
|
||||
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, LaneGroupSize / 16, 2, 8>>,
|
||||
tuple<sequence<1, 2, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2, 3>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 4>>;
|
||||
};
|
||||
|
||||
/*
|
||||
* @brief This function is used to generate the transposed distribution encoding
|
||||
* for the given data type and distribution dimensions.
|
||||
*
|
||||
* @tparam T The data type of the elements in the tensor.
|
||||
* @tparam kOuterDistDim0 The outer distribution dimension 0, which is outer dimension for stride.
|
||||
* @tparam kOuterDistDim1 The outer distribution dimension 1, which is inner dimension for stride.
|
||||
* @tparam kInnerDistDim0 The inner distribution dimension 0, which is outer dimension for
|
||||
* consecutive.
|
||||
* @tparam kInnerDistDim1 The inner distribution dimension 1, which is inner dimension for
|
||||
* consecutive.
|
||||
*/
|
||||
template <typename T,
|
||||
index_t LaneGroupSize,
|
||||
index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
|
||||
{
|
||||
return typename LaneGroupTransposeTraits<T, LaneGroupSize>::
|
||||
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>{};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
1216
include/ck_tile/core/arch/arch.hpp
Normal file
1216
include/ck_tile/core/arch/arch.hpp
Normal file
File diff suppressed because it is too large
Load Diff
529
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
Normal file
529
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
Normal file
@@ -0,0 +1,529 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
|
||||
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
|
||||
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
|
||||
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T, typename ComputeType>
|
||||
CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
|
||||
{
|
||||
return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
|
||||
{
|
||||
bf16x2_t rtn;
|
||||
rtn[0] = add<bf16_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<bf16_t, float>(a[1], b[1]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
|
||||
{
|
||||
bf16x4_t rtn;
|
||||
rtn[0] = add<bf16_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<bf16_t, float>(a[1], b[1]);
|
||||
rtn[2] = add<bf16_t, float>(a[2], b[2]);
|
||||
rtn[3] = add<bf16_t, float>(a[3], b[3]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp16x2_t add_f16x2_t(const fp16x2_t& a, const fp16x2_t& b)
|
||||
{
|
||||
fp16x2_t rtn;
|
||||
rtn[0] = add<fp16_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<fp16_t, float>(a[1], b[1]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
|
||||
{
|
||||
fp8x4_t rtn;
|
||||
rtn[0] = add<fp8_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<fp8_t, float>(a[1], b[1]);
|
||||
rtn[2] = add<fp8_t, float>(a[2], b[2]);
|
||||
rtn[3] = add<fp8_t, float>(a[3], b[3]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
|
||||
{
|
||||
fp8x8_t rtn;
|
||||
rtn[0] = add<fp8_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<fp8_t, float>(a[1], b[1]);
|
||||
rtn[2] = add<fp8_t, float>(a[2], b[2]);
|
||||
rtn[3] = add<fp8_t, float>(a[3], b[3]);
|
||||
rtn[4] = add<fp8_t, float>(a[4], b[4]);
|
||||
rtn[5] = add<fp8_t, float>(a[5], b[5]);
|
||||
rtn[6] = add<fp8_t, float>(a[6], b[6]);
|
||||
rtn[7] = add<fp8_t, float>(a[7], b[7]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
|
||||
{
|
||||
bf8x4_t rtn;
|
||||
rtn[0] = add<bf8_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<bf8_t, float>(a[1], b[1]);
|
||||
rtn[2] = add<bf8_t, float>(a[2], b[2]);
|
||||
rtn[3] = add<bf8_t, float>(a[3], b[3]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
|
||||
{
|
||||
bf8x8_t rtn;
|
||||
rtn[0] = add<bf8_t, float>(a[0], b[0]);
|
||||
rtn[1] = add<bf8_t, float>(a[1], b[1]);
|
||||
rtn[2] = add<bf8_t, float>(a[2], b[2]);
|
||||
rtn[3] = add<bf8_t, float>(a[3], b[3]);
|
||||
rtn[4] = add<bf8_t, float>(a[4], b[4]);
|
||||
rtn[5] = add<bf8_t, float>(a[5], b[5]);
|
||||
rtn[6] = add<bf8_t, float>(a[6], b[6]);
|
||||
rtn[7] = add<bf8_t, float>(a[7], b[7]);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
// Caution: DO NOT REMOVE
|
||||
// intentionally have only declaration but no definition to cause compilation failure when trying to
|
||||
// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
|
||||
// each datatype.
|
||||
template <typename X>
|
||||
CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
|
||||
{
|
||||
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
|
||||
__builtin_amdgcn_global_atomic_fadd_v2bf16(c_style_pointer_cast<bf16x2_t*>(p_dst), x);
|
||||
#else
|
||||
union U32BF162_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
bf16x2_t* bf162_a;
|
||||
};
|
||||
|
||||
union U32BF162
|
||||
{
|
||||
uint32_t u32;
|
||||
bf16x2_t bf162;
|
||||
};
|
||||
|
||||
U32BF162_ADDR dword_addr;
|
||||
U32BF162 cur_v;
|
||||
U32BF162 new_;
|
||||
uint32_t old_v, new_v;
|
||||
dword_addr.bf162_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
|
||||
{
|
||||
// Union to treat the pointer as either bf16x4_t* or uint64_t*:
|
||||
union U64BF164_ADDR
|
||||
{
|
||||
uint64_t* u64_a;
|
||||
bf16x4_t* bf164_a;
|
||||
};
|
||||
|
||||
// Union to treat the data as either bf16x4_t or 64-bit integer
|
||||
union U64BF164
|
||||
{
|
||||
uint64_t u64;
|
||||
bf16x4_t bf164;
|
||||
};
|
||||
|
||||
U64BF164_ADDR addr;
|
||||
addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location
|
||||
|
||||
// First read (non-atomic) of the old value
|
||||
U64BF164 cur_v;
|
||||
cur_v.u64 = *addr.u64_a;
|
||||
|
||||
U64BF164 new_v_union;
|
||||
uint64_t old_v, new_v;
|
||||
|
||||
do
|
||||
{
|
||||
// old 64 bits
|
||||
old_v = cur_v.u64;
|
||||
|
||||
// Add elementwise in bf16
|
||||
new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
|
||||
new_v = new_v_union.u64;
|
||||
|
||||
// Attempt the 64-bit CAS
|
||||
cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);
|
||||
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
|
||||
{
|
||||
union U32FP84_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
fp8x4_t* fp84_a;
|
||||
};
|
||||
|
||||
union U32FP84
|
||||
{
|
||||
uint32_t u32;
|
||||
fp8x4_t fp84;
|
||||
};
|
||||
|
||||
U32FP84_ADDR dword_addr;
|
||||
U32FP84 cur_v;
|
||||
U32FP84 new_;
|
||||
uint32_t old_v, new_v;
|
||||
|
||||
dword_addr.fp84_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.fp84 = add_fp8x4_t(cur_v.fp84, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
|
||||
{
|
||||
union U32BF84_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
bf8x4_t* bf84_a;
|
||||
};
|
||||
|
||||
union U32BF84
|
||||
{
|
||||
uint32_t u32;
|
||||
bf8x4_t bf84;
|
||||
};
|
||||
|
||||
U32BF84_ADDR dword_addr;
|
||||
U32BF84 cur_v;
|
||||
U32BF84 new_;
|
||||
uint32_t old_v, new_v;
|
||||
|
||||
dword_addr.bf84_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.bf84 = add_bf8x4_t(cur_v.bf84, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for fp8x8_t
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
|
||||
{
|
||||
// Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer.
|
||||
union U64FP88_ADDR
|
||||
{
|
||||
uint64_t* u64_a; // pointer to 64-bit integer
|
||||
fp8x8_t* fp88_a; // pointer to fp8x8_t
|
||||
};
|
||||
|
||||
union U64FP88
|
||||
{
|
||||
uint64_t u64;
|
||||
fp8x8_t fp88;
|
||||
};
|
||||
|
||||
U64FP88_ADDR dword_addr;
|
||||
U64FP88 cur_v;
|
||||
U64FP88 new_v_union;
|
||||
uint64_t old_v, new_v;
|
||||
|
||||
// Point to the destination as both fp8x8_t* and uint64_t*.
|
||||
dword_addr.fp88_a = p_dst;
|
||||
// Initial read of 64 bits from memory
|
||||
cur_v.u64 = *dword_addr.u64_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u64;
|
||||
// Add each fp8 element using your add_fp8x8_t(...) routine
|
||||
new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
|
||||
new_v = new_v_union.u64;
|
||||
|
||||
// Attempt 64-bit CAS
|
||||
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for bf8x8_t
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
|
||||
{
|
||||
union U64BF88_ADDR
|
||||
{
|
||||
uint64_t* u64_a;
|
||||
bf8x8_t* bf88_a;
|
||||
};
|
||||
|
||||
union U64BF88
|
||||
{
|
||||
uint64_t u64;
|
||||
bf8x8_t bf88;
|
||||
};
|
||||
|
||||
U64BF88_ADDR dword_addr;
|
||||
U64BF88 cur_v;
|
||||
U64BF88 new_v_union;
|
||||
uint64_t old_v, new_v;
|
||||
|
||||
dword_addr.bf88_a = p_dst;
|
||||
// Read the original 64 bits
|
||||
cur_v.u64 = *dword_addr.u64_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u64;
|
||||
// Add each bf8 element using your add_bf8x8_t(...) routine
|
||||
new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
|
||||
new_v = new_v_union.u64;
|
||||
|
||||
// 64-bit CAS loop
|
||||
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
|
||||
} while(cur_v.u64 != old_v);
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic add for fp16x2_t
|
||||
//
|
||||
template <>
|
||||
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, fp16x2_t const& x)
|
||||
{
|
||||
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
|
||||
__builtin_amdgcn_global_atomic_fadd_v2f16(c_style_pointer_cast<fp16x2_t*>(p_dst), x);
|
||||
#else
|
||||
union U32F162_ADDR
|
||||
{
|
||||
uint32_t* u32_a;
|
||||
fp16x2_t* f162_a;
|
||||
};
|
||||
|
||||
union U32F162
|
||||
{
|
||||
uint32_t u32;
|
||||
fp16x2_t f162;
|
||||
};
|
||||
|
||||
U32F162_ADDR dword_addr;
|
||||
U32F162 cur_v;
|
||||
U32F162 new_;
|
||||
uint32_t old_v, new_v;
|
||||
dword_addr.f162_a = p_dst;
|
||||
cur_v.u32 = *dword_addr.u32_a;
|
||||
|
||||
do
|
||||
{
|
||||
old_v = cur_v.u32;
|
||||
new_.f162 = add_f16x2_t(cur_v.f162, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
|
||||
"The granularity of the thread buffer is unsupported on the hardware!");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
constexpr auto I3 = number<3>{};
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<float>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 2, x.template get_as<float>()[I2]);
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 3, x.template get_as<float>()[I3]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, double>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return atomicAdd(p_dst, bit_cast<double>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicAdd(c_style_pointer_cast<double*>(p_dst), x.template get_as<double>()[I0]);
|
||||
atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, x.template get_as<double>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<int32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, uint32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicAdd(p_dst, bit_cast<uint32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value)
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1,
|
||||
x.template get_as<bf16x4_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, fp8_t>::value)
|
||||
{
|
||||
if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 8)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 16)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1, x.template get_as<fp8x8_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf8_t>::value)
|
||||
{
|
||||
if constexpr(N == 4)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 8)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
|
||||
}
|
||||
if constexpr(N == 16)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
|
||||
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, fp16_t>::value)
|
||||
{
|
||||
static_for<0, N / 2, 1>{}([&](auto i) {
|
||||
atomic_add(c_style_pointer_cast<fp16x2_t*>(p_dst) + i,
|
||||
x.template get_as<fp16x2_t>()[i]);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, double>::value && (N == 1)),
|
||||
"wrong! not implemented");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<float>(x));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
atomicMax(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
|
||||
atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, double>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<double>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<int32_t>(x));
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, uint32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
atomicMax(p_dst, bit_cast<uint32_t>(x));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
128
include/ck_tile/core/arch/mma/amdgcn_mma.hpp
Normal file
128
include/ck_tile/core/arch/mma/amdgcn_mma.hpp
Normal file
@@ -0,0 +1,128 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct Unsupported
|
||||
* @brief Meta-tag to indicate unsupported amdgcn_mma instance.
|
||||
*/
|
||||
struct Unsupported;
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
/**
|
||||
* @concept MmaOpI
|
||||
* @brief Expresses the meta-data interface required for each MmaOp policy.
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
concept MmaOpI = requires(MmaOp op) {
|
||||
// Requires an op context
|
||||
typename MmaOp::OpType;
|
||||
|
||||
// Captures types for inputs / outputs to mma function
|
||||
typename MmaOp::AVecType;
|
||||
typename MmaOp::BVecType;
|
||||
typename MmaOp::CVecType;
|
||||
|
||||
// Captures CK-specific layout properties
|
||||
{ MmaOp::kAMBlock } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kBNBlock } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kAMLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kBNLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kABKLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kABKPerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCMLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCNLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCM0PerLane } -> std::convertible_to<unsigned int>;
|
||||
{ MmaOp::kCM1PerLane } -> std::convertible_to<unsigned int>;
|
||||
|
||||
// Static exec function
|
||||
{
|
||||
MmaOp::exec(
|
||||
typename MmaOp::AVecType{}, typename MmaOp::BVecType{}, typename MmaOp::CVecType{})
|
||||
} -> std::convertible_to<typename MmaOp::CVecType>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @class amdgcn_mma
|
||||
* @brief This is the default MmaOp policy.
|
||||
* Instances of this class are to be used as MmaOp policies.
|
||||
* Light builtin wrapper for mfma / wmma instructions. This class's job is to
|
||||
* provide a uniform interface to invoke the appropriate instruction
|
||||
* based on the template parameters provided. This interface is to bridge
|
||||
* the gap between the ck_tile API types and the native __builtin types.
|
||||
* @tparam ADataType Datatype of input A
|
||||
* @tparam BDataType Datatype of input B
|
||||
* @tparam CDataType Datatype of accumulator
|
||||
* @tparam BlockM M-dimension of mma block
|
||||
* @tparam BlockN N-dimension of mma block
|
||||
* @tparam BlockK K-dimension of mma block
|
||||
* @tparam CtrlFlags Control flags for mma operation
|
||||
* @tparam CompilerTarget The current compiler target
|
||||
* @tparam Enabler SFINAE enabler
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockK,
|
||||
typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily_,
|
||||
typename Enabler = void>
|
||||
struct amdgcn_mma
|
||||
{
|
||||
// The base instance is unsupported because there is no __builtin to wrap.
|
||||
using OpType = Unsupported;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::UNDEFINED;
|
||||
|
||||
// Interface types for A, B, C vectors types
|
||||
using AVecType = ext_vector_t<ADataType, 1>;
|
||||
using BVecType = ext_vector_t<BDataType, 1>;
|
||||
using CVecType = ext_vector_t<CDataType, 1>;
|
||||
|
||||
// Layout constants - default to 0
|
||||
static constexpr index_t kAMBlock = 0;
|
||||
static constexpr index_t kBNBlock = 0;
|
||||
|
||||
static constexpr index_t kAMLane = 0;
|
||||
static constexpr index_t kBNLane = 0;
|
||||
static constexpr index_t kABKLane = 0;
|
||||
static constexpr index_t kABKPerLane = 0;
|
||||
|
||||
static constexpr index_t kCMLane = 0;
|
||||
static constexpr index_t kCNLane = 0;
|
||||
static constexpr index_t kCM0PerLane = 0;
|
||||
static constexpr index_t kCM1PerLane = 0;
|
||||
|
||||
// This is a default pass-through implementation that doesn't do anything practical.
|
||||
CK_TILE_DEVICE static CVecType const&
|
||||
exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC)
|
||||
{
|
||||
ignore(regsA, regsB);
|
||||
return regsC; // No-op, just return C
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// Include the implementations
|
||||
#include "wmma/wmma.hpp"
|
||||
#include "mfma/mfma.hpp"
|
||||
#include "sparse/sparse.hpp"
|
||||
10
include/ck_tile/core/arch/mma/mfma/mfma.hpp
Normal file
10
include/ck_tile/core/arch/mma/mfma/mfma.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// Include the architecture-specific MFMA implementations and traits
|
||||
#include "mfma_gfx9.hpp"
|
||||
#include "mfma_traits.hpp"
|
||||
#include "mfma_selector.hpp"
|
||||
#include "mfma_transforms.hpp"
|
||||
168
include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp
Normal file
168
include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp
Normal file
@@ -0,0 +1,168 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mfma_traits.hpp"
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed.
|
||||
// This is because some built-ins are only available on certain target ids.
|
||||
// We can also do things such add some padding specializations for when we need to use
|
||||
// smaller values of K that aren't directly supported by the built-ins.
|
||||
// For flexibility, it is recommended that for each backend wrapper it supports at least
|
||||
// one packed register for each input to be able to process smaller K values by padding.
|
||||
|
||||
/**
|
||||
* @struct DefaultMmaCtrlFlags
|
||||
* @brief Default MFMA flags, no broadcasting or rotation of inputs
|
||||
*/
|
||||
struct DefaultMfmaCtrlFlags
|
||||
{
|
||||
static constexpr uint32_t Cbsz = 0; // CBSZ flag, default 0
|
||||
static constexpr uint32_t Abid = 0; // ABID flag, default 0
|
||||
static constexpr uint32_t Blgp = 0; // BLGP flag, default 0
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
|
||||
/**
|
||||
* @concept CtrlFlagsGfx9I
|
||||
* @brief Expresses the interface of required members for each CtrlFlags type on Gfx9
|
||||
*/
|
||||
template <typename CtrlFlags>
|
||||
concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) {
|
||||
// Flag members for Gfx9 MFMA instructions
|
||||
{ CtrlFlags::Cbsz } -> std::convertible_to<int>;
|
||||
{ CtrlFlags::Abid } -> std::convertible_to<int>;
|
||||
{ CtrlFlags::Blgp } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for MFMA on GFX9 targets
|
||||
*
|
||||
* This specialization implements the MFMA instruction for fp16_t A and B
|
||||
* matrices, and fp32_t accumulator matrix, with 16x16x16 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
{
|
||||
// Mfma operation type
|
||||
using OpType = MfmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 4>;
|
||||
using CVecType = ext_vector_t<fp32_t, 4>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 4;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_16x16x16f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for MFMA on GFX950 targets
|
||||
*
|
||||
* This specialization implements the MFMA instruction for fp16_t A and B
|
||||
* matrices, and fp32_t accumulator matrix, with 16x16x32 block sizes.
|
||||
*
|
||||
* @tparam CtrlFlags Control flags for the MFMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx9I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
32u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_id_t<CompilerTarget, amdgcn_target_id::GFX950>>
|
||||
{
|
||||
using OpType = MfmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Packed register types
|
||||
using AVecType = ext_vector_t<fp16_t, 8>;
|
||||
using BVecType = ext_vector_t<fp16_t, 8>;
|
||||
using CVecType = ext_vector_t<fp32_t, 4>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 8;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_mfma_f32_16x16x32_f16(aVec,
|
||||
bVec,
|
||||
cVec,
|
||||
static_cast<int>(CtrlFlags::Cbsz),
|
||||
static_cast<int>(CtrlFlags::Abid),
|
||||
static_cast<int>(CtrlFlags::Blgp))};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
195
include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp
Normal file
195
include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp
Normal file
@@ -0,0 +1,195 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
#include "mfma_traits.hpp"
|
||||
#include "mfma_gfx9.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class MfmaDefaultSelector
|
||||
* @brief Implements a default MFMA selector strategy for gfx9 target architectures.
|
||||
* This implements the K dimension search strategy to find the largest supported MFMA
|
||||
* instruction for the given M/N block sizes and datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through
|
||||
implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Block M dimension size
|
||||
* @tparam BlockN Block N dimension size
|
||||
* @tparam BlockKTest Current Block K dimension size to test
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @note Here we assume that BlockKTest is always a power-of-two integer.
|
||||
* The search strategy starts from a maximum BlockKTest size down to 1u by halving
|
||||
* each time.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_gfx9_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
struct MfmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// Define our candidate MFMA implementation for the current parameters
|
||||
using CandidateOp =
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u
|
||||
// and fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
CandidateOp,
|
||||
typename MfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest / 2u,
|
||||
CompilerTarget>::SelectedOp>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MfmaDefaultSelector
|
||||
* @brief Implements the base case for the default MFMA selector when no supported instruction is
|
||||
* found.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Block M dimension size
|
||||
* @tparam BlockN Block N dimension size
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
typename CompilerTarget> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
struct MfmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
|
||||
{
|
||||
// Default unsupported pass-through if no instruction is found
|
||||
using SelectedOp =
|
||||
amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
1u,
|
||||
DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the gfx9 default MMA selector strategy for wave-wise MMA decomposition.
|
||||
* This implements the M/N block size search strategy to find the largest supported MFMA
|
||||
* instruction for the given datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_family_gfx9_t<CompilerTarget>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::DENSE>>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common MFMA shapes.
|
||||
// Start searching from the largest K dimension MFMA shape down to the smallest.
|
||||
using CandidateOp4x4 =
|
||||
typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 4u, 4u, 4u, CompilerTarget>::
|
||||
SelectedOp;
|
||||
using CandidateOp16x16 = typename MfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
16u,
|
||||
16u,
|
||||
128u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
using CandidateOp32x32 = typename MfmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
32u,
|
||||
32u,
|
||||
64u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Default operation triggers pass-through
|
||||
using DefaultOp =
|
||||
typename MfmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
|
||||
SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits4x4 = MmaOpTraits<CandidateOp4x4>;
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the MFMA shape
|
||||
static constexpr bool IsSupported4x4 =
|
||||
CandidateTraits4x4::IsSupported && (FragM % CandidateTraits4x4::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits4x4::BlockN == 0u) && (FragK % CandidateTraits4x4::BlockK == 0u);
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
|
||||
(FragM % CandidateTraits32x32::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits32x32::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits32x32::BlockK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported MFMA operation for the given fragment shape
|
||||
using SelectedOp = std::conditional_t<
|
||||
IsSupported32x32,
|
||||
CandidateOp32x32,
|
||||
std::conditional_t<IsSupported16x16,
|
||||
CandidateOp16x16,
|
||||
std::conditional_t<IsSupported4x4, CandidateOp4x4, DefaultOp>>>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
44
include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp
Normal file
44
include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct MfmaOp
|
||||
* @brief Meta-tag for the MFMA operation. This will be used in the MmaOp policies to
|
||||
* identify the operation as an MFMA operation.
|
||||
*/
|
||||
struct MfmaOp;
|
||||
|
||||
/**
|
||||
* @class is_mma_op_mfma
|
||||
* @brief Trait to check if MmaOp is an MFMA operation
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
template <typename MmaOp, typename = void>
|
||||
struct is_mma_op_mfma : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct is_mma_op_mfma
|
||||
* @brief MmaOp specialization for MFMA operations, confirming the OpType matches MfmaOp
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
// TODO: c++20 requires
|
||||
struct is_mma_op_mfma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpType, MfmaOp>>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Convenience evaluator for is_mma_op_mfma trait
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
static constexpr bool is_mma_op_mfma_v = is_mma_op_mfma<MmaOp>::value;
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
38
include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp
Normal file
38
include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsGfx9
|
||||
* @brief Implements the default MMA transforms for gfx9 targets
|
||||
*/
|
||||
struct MmaDefaultTransformsGfx9
|
||||
{
|
||||
using ATransform = PassThroughTransform;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PassThroughTransform;
|
||||
using DTransform = PassThroughTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaTransformsDefaultSelector
|
||||
* @brief Implements the default MMA transforms selection for gfx9 targets
|
||||
* @tparam MmaOp Mma operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx9_t<CompilerTarget>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsGfx9;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
235
include/ck_tile/core/arch/mma/mma.hpp
Normal file
235
include/ck_tile/core/arch/mma/mma.hpp
Normal file
@@ -0,0 +1,235 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
#include "amdgcn_mma.hpp"
|
||||
#include "mma_selector.hpp"
|
||||
#include "mma_traits.hpp"
|
||||
#include "mma_transforms.hpp"
|
||||
|
||||
#include "mfma/mfma.hpp"
|
||||
#include "wmma/wmma.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/*! @enum MmaAccumPolicy
|
||||
* @brief Accumulation order for Mma decomposition
|
||||
*/
|
||||
enum struct MmaAccumPolicy
|
||||
{
|
||||
// Decomposition and accumulation in row-major block order
|
||||
ROW_MAJOR,
|
||||
// Decomposition and accumulation in col-major block order
|
||||
COL_MAJOR
|
||||
};
|
||||
|
||||
/**
|
||||
* @class Mma
|
||||
* @brief Driver for the wave-tile Mma operation. Given a backend block-wise MmaOp implementation
|
||||
* (e.g., mfma or wmma), this class performs block-wise decomposition to matrix-multiply input
|
||||
* fragments of (A: FragM x FragK) x (B: FragK x FragN) and accumulates results into output fragment
|
||||
* (C: FragM x FragN).
|
||||
* @tparam ADataType Data type of input fragment A
|
||||
* @tparam BDataType Data type of input fragment B
|
||||
* @tparam CDataType Data type of input/output fragment C (accumulator)
|
||||
* @tparam FragM Mma fragment M dimension
|
||||
* @tparam FragN Mma fragment K dimension
|
||||
* @tparam FragK Mma fragment M dimension
|
||||
* @tparam AccumPolicy The block order of the accumulation registers (row major or col major block
|
||||
* order)
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam MmaOp The backend wrapper class that will perform block-wise mma op (e.g., mfma or
|
||||
* wmma)
|
||||
* @tparam MmaTransforms The set of transforms to be applied to input/output fragments
|
||||
* @par This is an example of an Mma decomposition driver class that can be used in a wave-tile
|
||||
* context. Given a fragment size, we can decompose the fragment into smaller block-wise mma ops
|
||||
* that are natively supported by the hardware (e.g., mfma or wmma). The class also supports
|
||||
* applying transforms to the input/output fragments as needed (e.g., layout conversions, data type
|
||||
* conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the
|
||||
* output fragment. This is a powerful example of how to build a flexible and reusable mma driver
|
||||
* that can adapt to different hardware capabilities and requirements.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
MmaOpFamily OpFamily,
|
||||
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
|
||||
typename CompilerTarget =
|
||||
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
|
||||
// get_compiler_target(),
|
||||
typename MmaOp =
|
||||
typename MmaDefaultSelector<ADataType, // TODO: c++20 MmaOpI MmaOp = typename
|
||||
// MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
OpFamily>::SelectedOp,
|
||||
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
|
||||
typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
|
||||
struct WaveWiseMma
|
||||
{
|
||||
|
||||
using BlockWiseMmaOp = MmaOp;
|
||||
using BlockWiseMmaOpTraits = MmaOpTraits<BlockWiseMmaOp>;
|
||||
|
||||
// Block dimensions
|
||||
constexpr static uint32_t BlockM = BlockWiseMmaOpTraits::BlockM;
|
||||
constexpr static uint32_t BlockN = BlockWiseMmaOpTraits::BlockN;
|
||||
constexpr static uint32_t BlockK = BlockWiseMmaOpTraits::BlockK;
|
||||
|
||||
// Block counts for decomposition
|
||||
constexpr static uint32_t BlocksM = FragM / BlockM;
|
||||
constexpr static uint32_t BlocksN = FragN / BlockN;
|
||||
constexpr static uint32_t BlocksK = FragK / BlockK;
|
||||
constexpr static uint32_t BlocksC = BlocksM * BlocksN;
|
||||
|
||||
// Vector types for packed registers in each block
|
||||
using AVecType = typename BlockWiseMmaOpTraits::AVecType;
|
||||
using BVecType = typename BlockWiseMmaOpTraits::BVecType;
|
||||
using CVecType = typename BlockWiseMmaOpTraits::CVecType;
|
||||
|
||||
// Buffer types for fragments
|
||||
using ABufferType = AVecType[BlocksM][BlocksK];
|
||||
using BBufferType = BVecType[BlocksN][BlocksK];
|
||||
using CBufferType = CVecType[BlocksM][BlocksN];
|
||||
|
||||
// Transforms
|
||||
using ATransform = typename MmaTransforms::ATransform;
|
||||
using BTransform = typename MmaTransforms::BTransform;
|
||||
using CTransform = typename MmaTransforms::CTransform;
|
||||
using DTransform = typename MmaTransforms::DTransform;
|
||||
|
||||
// Sanity checks
|
||||
static_assert(FragM >= BlockM, "FragM must be larger than BlockM");
|
||||
static_assert(FragN >= BlockN, "FragN must be larger than BlockN");
|
||||
static_assert(FragK >= BlockK, "FragK must be larger than BlockK");
|
||||
static_assert(FragM % BlockM == 0u, "FragM must be a multiple of BlockM");
|
||||
static_assert(FragN % BlockN == 0u, "FragN must be a multiple of BlockN");
|
||||
static_assert(FragK % BlockK == 0u, "FragK must be a multiple of BlockK");
|
||||
|
||||
private:
|
||||
template <typename DstT, typename SrcT>
|
||||
CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer)
|
||||
{
|
||||
// TODO: Implement formatting logic as needed.
|
||||
// This is intended to convert input fragments to the native vector types
|
||||
// required by the BlockWiseMma operation for iteration
|
||||
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
|
||||
return reinterpret_cast<DstT const&>(inputBuffer);
|
||||
}
|
||||
|
||||
template <typename DstT, typename SrcT>
|
||||
CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer)
|
||||
{
|
||||
// TODO: Implement formatting logic as needed.
|
||||
// This is intended to convert input fragments to the native vector types
|
||||
// required by the BlockWiseMma operation for iteration
|
||||
static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer");
|
||||
return reinterpret_cast<DstT&>(inputBuffer);
|
||||
}
|
||||
|
||||
/*! @brief Execute Mma in row-major accumulation order.
|
||||
* @tparam VecTA The input fragment A vector type
|
||||
* @tparam VecTB The input fragment B vector type
|
||||
* @tparam VecTC The input/output fragment C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
// We implement an example wave-tile pipeline here.
|
||||
// First, we apply the necessary transforms to the input fragments,
|
||||
// then we convert the result into buffers of native vector formats
|
||||
// that we can easily index. Native vector formats are necessary inputs
|
||||
// to the given MmaOp exec function.
|
||||
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
|
||||
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
|
||||
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
|
||||
|
||||
// "Col-major" accumulation over the M-dimension blocks first.
|
||||
// Pseudo code here, but we would basically iterate over the blocks in col-major order
|
||||
for(uint32_t bn = 0u; bn < BlocksN; ++bn)
|
||||
{
|
||||
for(uint32_t bm = 0u; bm < BlocksM; ++bm)
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < BlocksK; ++bk)
|
||||
{
|
||||
c_frag[bm][bn] =
|
||||
BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert native vector results back to the output fragment format
|
||||
// and then return after we apply the final output transform.
|
||||
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
|
||||
}
|
||||
|
||||
/*! @brief Execute Mma in row-major accumulation order.
|
||||
* @tparam VecTA The input fragment A vector type
|
||||
* @tparam VecTB The input fragment B vector type
|
||||
* @tparam VecTC The input/output fragment C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
// We implement an example wave-tile pipeline here.
|
||||
// First, we apply the necessary transforms to the input fragments,
|
||||
// then we convert the result into buffers of native vector formats
|
||||
// that we can easily index. Native vector formats are necessary inputs
|
||||
// to the given MmaOp exec function.
|
||||
auto a_frag = formatBuffer<ABufferType>(ATransform::exec(a));
|
||||
auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
|
||||
auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
|
||||
|
||||
// "Row-major" accumulation over the N-dimension blocks first.
|
||||
// Pseudo code here, but we would basically iterate over the blocks in row-major order.
|
||||
// We also have to ensure that the incoming vector fragments are converted to native vector
|
||||
// types before passing to the BlockWiseMma exec function.
|
||||
for(uint32_t bm = 0u; bm < BlocksM; ++bm)
|
||||
{
|
||||
for(uint32_t bn = 0u; bn < BlocksN; ++bn)
|
||||
{
|
||||
for(uint32_t bk = 0u; bk < BlocksK; ++bk)
|
||||
{
|
||||
c_frag[bm][bn] =
|
||||
BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert native vector results back to the output fragment format
|
||||
// and then return after we apply the final output transform.
|
||||
return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
|
||||
}
|
||||
|
||||
public:
|
||||
/*! @brief Forward to Mma operation with specified accumulation order.
|
||||
* @tparam VecTA The input fragment A vector type
|
||||
* @tparam VecTB The input fragment B vector type
|
||||
* @tparam VecTC The input/output fragment C vector type
|
||||
*/
|
||||
template <typename VecTA, typename VecTB, typename VecTC>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum)
|
||||
{
|
||||
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
|
||||
{
|
||||
return exec_row_major(
|
||||
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
|
||||
}
|
||||
else // if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
|
||||
{
|
||||
return exec_col_major(
|
||||
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
48
include/ck_tile/core/arch/mma/mma_op_family.hpp
Normal file
48
include/ck_tile/core/arch/mma/mma_op_family.hpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @enum MmaOpFamily
|
||||
* @brief Enumeration that defines mma op families and
|
||||
*/
|
||||
enum struct MmaOpFamily
|
||||
{
|
||||
UNDEFINED = 0,
|
||||
DENSE,
|
||||
SPARSE,
|
||||
SCALE,
|
||||
};
|
||||
|
||||
/**
|
||||
* @class is_ctrl_fis_mma_op_of_familylag_of_family
|
||||
* @brief Meta-function to check if MmaOp is of the specified MmaOpFamily
|
||||
* @tparam Family Control flag family
|
||||
* @tparam MmaOp amdgcn struct specialization type
|
||||
*/
|
||||
template <MmaOpFamily Family, typename MmaOp, typename = void>
|
||||
struct is_mma_op_of_family : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct is_mma_op_of_family
|
||||
* @brief Specialization for Family == MmaOp::OpFamily detection
|
||||
*/
|
||||
template <MmaOpFamily Family, typename MmaOp>
|
||||
struct is_mma_op_of_family<Family, MmaOp, std::enable_if_t<Family == MmaOp::OpFamily>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Convenience evaluator for is_mma_op_of_family trait
|
||||
* @tparam Family Desired control flag family
|
||||
* @tparam MmaOp The amdgcn struct specialization type to check
|
||||
*/
|
||||
template <MmaOpFamily Family, typename MmaOp>
|
||||
static constexpr bool is_mma_op_of_family_v = is_mma_op_of_family<Family, MmaOp>::value;
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
75
include/ck_tile/core/arch/mma/mma_selector.hpp
Normal file
75
include/ck_tile/core/arch/mma/mma_selector.hpp
Normal file
@@ -0,0 +1,75 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class MmaDefaultSelector
|
||||
* @brief Implements a default mma selector strategy for the current target architecture.
|
||||
* This is simply intended as a default selection strategy for mma instruction operations.
|
||||
* Given the particular datatypes and Fragment dimensions, the selector will attempt to
|
||||
* select the instruction with the largest K dimension that is supported on the current target
|
||||
* architecture.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Fragment M dimension
|
||||
* @tparam FragN Fragment N dimension
|
||||
* @tparam FragK Fragment K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
* @tparam Enable SFINAE enabler
|
||||
* @note Here we distinguish that Fragment MNK sizes from Block MNK sizes used in the actual MMA
|
||||
* operation. Fragment sizes correspond to the overall tile size being computed, while Block sizes
|
||||
* correspond to the size of the individual MMA instructions being used to compute the overall in
|
||||
* block-wise. The Fragment sizes must be multiples of the Block sizes and in general larger than or
|
||||
* equal to the Block sizes.
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily,
|
||||
typename Enable = void>
|
||||
// TODO c++20 requires
|
||||
struct MmaDefaultSelector
|
||||
{
|
||||
// By default, no selection is made, and we fall back to a pass-through unsupported
|
||||
// implementation. This is because we do not have any knowledge of the target architecture here.
|
||||
using SelectedOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
void,
|
||||
amdgcn_target<>,
|
||||
MmaOpFamily::UNDEFINED>;
|
||||
};
|
||||
|
||||
#if CK_TILE_CONCEPTS
|
||||
|
||||
/**
|
||||
* @concept MmaSelectorI
|
||||
* @brief Expresses the required members for each MmaSelector class.
|
||||
*/
|
||||
template <typename MmaSelector>
|
||||
concept MmaSelectorI = requires(MmaSelector op) {
|
||||
// Selectors should have a resulting SelectedOp type
|
||||
typename MmaSelector::SelectedOp;
|
||||
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
// Include the implementations
|
||||
#include "wmma/wmma_selector.hpp"
|
||||
#include "mfma/mfma_selector.hpp"
|
||||
164
include/ck_tile/core/arch/mma/mma_traits.hpp
Normal file
164
include/ck_tile/core/arch/mma/mma_traits.hpp
Normal file
@@ -0,0 +1,164 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "mfma/mfma_traits.hpp"
|
||||
#include "wmma/wmma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
 * @class is_mma_op_supported
 * @brief Trait to check if MmaOp is supported
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 * @note Defaults to true; the partial specialization below flags ops whose
 *       OpType is `Unsupported` as false.
 */
// TODO: c++20 template <MmaOpI MmaOp, typename = void>
template <typename MmaOp, typename = void>
struct is_mma_op_supported : std::true_type
{
};
|
||||
|
||||
/**
 * @struct is_mma_op_supported
 * @brief The MmaOp is unsupported specialization
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 * @note Matches any MmaOp whose nested OpType alias is `Unsupported`
 *       (i.e. no backend implementation exists for that shape/target).
 */
// TODO: c++20 template <MmaOpI MmaOp>
template <typename MmaOp>
// TODO: c++20 requires
struct is_mma_op_supported<MmaOp,
                           std::enable_if_t<std::is_same_v<typename MmaOp::OpType, Unsupported>>>
    : std::false_type
{
};
|
||||
|
||||
/**
 * @brief Convenience evaluation of is_mma_op_supported
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 */
// TODO: c++20 template <MmaOpI MmaOp>
template <typename MmaOp>
static constexpr bool is_mma_op_supported_v = is_mma_op_supported<MmaOp>::value;
|
||||
|
||||
/**
 * @class MmaOpParams
 * @brief Reflects the template parameters of a given MmaOp
 * @tparam MmaOp The matrix multiply-accumulate operation type to check
 * @note Declaration only; defined via partial specialization on amdgcn_mma below.
 */
// TODO: c++20 template <MmaOpI MmaOp>
template <typename MmaOp>
struct MmaOpParams;
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
|
||||
/**
 * @concept MmaOpParamsI
 * @brief Expresses the required members for each MmaOp
 */
template <typename MmaOpParams>
concept MmaOpParamsI = requires(MmaOpParams op) {
    // Capture template parameters
    typename MmaOpParams::ADataType;
    typename MmaOpParams::BDataType;
    typename MmaOpParams::CDataType;
    typename MmaOpParams::CtrlFlags;

    // Block shape and target/family reflection
    { MmaOpParams::BlockM } -> std::convertible_to<unsigned int>;
    { MmaOpParams::BlockN } -> std::convertible_to<unsigned int>;
    { MmaOpParams::BlockK } -> std::convertible_to<unsigned int>;
    { MmaOpParams::GfxTargetId } -> std::convertible_to<amdgcn_target_arch_id>;
    { MmaOpParams::Family } -> std::convertible_to<MmaOpFamily>;
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
 * @struct MmaOpParams
 * @brief Reflects the template parameters of a given MmaOp
 * @tparam ADataType_ Data type of matrix A
 * @tparam BDataType_ Data type of matrix B
 * @tparam CDataType_ Data type of the accumulator
 * @tparam BlockM_ Size of the M dimension
 * @tparam BlockN_ Size of the N dimension
 * @tparam BlockK_ Size of the K dimension
 * @tparam CtrlFlags_ Control flags for the MMA operation
 * @tparam CompilerTarget_ The compiler target
 */
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          uint32_t BlockM_,
          uint32_t BlockN_,
          uint32_t BlockK_,
          typename CtrlFlags_,
          typename CompilerTarget_,
          MmaOpFamily OpFamily_>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget_>
struct MmaOpParams<amdgcn_mma<ADataType_,
                              BDataType_,
                              CDataType_,
                              BlockM_,
                              BlockN_,
                              BlockK_,
                              CtrlFlags_,
                              CompilerTarget_,
                              OpFamily_>>
{
    // Capture incoming template parameters
    using ADataType = ADataType_;
    using BDataType = BDataType_;
    using CDataType = CDataType_;
    static constexpr uint32_t BlockM = BlockM_;
    static constexpr uint32_t BlockN = BlockN_;
    static constexpr uint32_t BlockK = BlockK_;
    using CtrlFlags      = CtrlFlags_;
    using CompilerTarget = CompilerTarget_;
    // NOTE(review): this data member shadows the enum type name `MmaOpFamily`
    // in derived scopes; two-phase lookup keeps it compiling in templates,
    // but consider renaming (e.g. to `Family`) to avoid the hazard.
    static constexpr auto MmaOpFamily = OpFamily_;
    // TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_;
};
|
||||
|
||||
/**
 * @class MmaOpTraits
 * @brief Reflects the template parameters and static members of a given MmaOp.
 * @tparam MmaOp The matrix multiply-accumulate operation
 */
template <typename MmaOp>
// TODO: c++20 template <MmaOpI MmaOp>
// TODO: c++20 requires MmaOpParamsI<MmaOpParams<MmaOp>>
struct MmaOpTraits : public MmaOpParams<MmaOp>
{
    // Capture internal MmaOp static members
    using OpType   = typename MmaOp::OpType;
    using AVecType = typename MmaOp::AVecType;
    using BVecType = typename MmaOp::BVecType;
    using CVecType = typename MmaOp::CVecType;

    static constexpr MmaOpFamily OpFamily = MmaOp::OpFamily;

    // Capture layout parameters (per-lane / per-block register layout of the op)
    static constexpr index_t kAMBlock    = MmaOp::kAMBlock;
    static constexpr index_t kBNBlock    = MmaOp::kBNBlock;
    static constexpr index_t kAMLane     = MmaOp::kAMLane;
    static constexpr index_t kBNLane     = MmaOp::kBNLane;
    static constexpr index_t kABKLane    = MmaOp::kABKLane;
    static constexpr index_t kABKPerLane = MmaOp::kABKPerLane;
    static constexpr index_t kCMLane     = MmaOp::kCMLane;
    static constexpr index_t kCNLane     = MmaOp::kCNLane;
    static constexpr index_t kCM0PerLane = MmaOp::kCM0PerLane;
    static constexpr index_t kCM1PerLane = MmaOp::kCM1PerLane;

    // Additional traits to identify the type of MmaOp at compile time
    constexpr static bool IsMfma   = is_mma_op_mfma_v<MmaOp>;
    constexpr static bool IsWmma   = is_mma_op_wmma_v<MmaOp>;
    constexpr static bool IsDense  = OpFamily == MmaOpFamily::DENSE;
    constexpr static bool IsSparse = OpFamily == MmaOpFamily::SPARSE;
    constexpr static bool IsScale  = OpFamily == MmaOpFamily::SCALE;
    // Supported means: a backend implementation exists AND the family is defined
    constexpr static bool IsSupported =
        is_mma_op_supported_v<MmaOp> && OpFamily != MmaOpFamily::UNDEFINED;
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
49
include/ck_tile/core/arch/mma/mma_transforms.hpp
Normal file
49
include/ck_tile/core/arch/mma/mma_transforms.hpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct PassThroughTransform
|
||||
* @brief A no-op transform that passes through the input as-is.
|
||||
*/
|
||||
struct PassThroughTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
 * @class MmaTransformsDefaultSelector
 * @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget
 * @tparam MmaOp The Mma operation type
 * @tparam CompilerTarget The compiler target
 * @tparam Enable SFINAE parameter for specialization
 * @note Declaration only; family-specific specializations (e.g. sparse) define
 *       the SelectedTransforms member.
 */
template <typename MmaOp, typename CompilerTarget, typename Enable = void>
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id CompilerTarget, typename Enable = void>
struct MmaTransformsDefaultSelector;
|
||||
|
||||
#if CK_TILE_CONCEPTS
|
||||
|
||||
/**
 * @concept MmaTransformsI
 * @brief Expresses the interface of required members for each MmaTransforms type.
 */
template <typename MmaTransforms>
concept MmaTransformsI = requires(MmaTransforms transforms) {
    // Transforms should define ATransform, BTransform, CTransform, and DTransform types
    typename MmaTransforms::ATransform;
    typename MmaTransforms::BTransform;
    typename MmaTransforms::CTransform;
    typename MmaTransforms::DTransform;
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
151
include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp
Normal file
151
include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp
Normal file
@@ -0,0 +1,151 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
 * @class SparseMfmaDefaultSelector
 * @brief Implements a default sparse MFMA selector strategy. The SelectedOp can be unsupported.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam BlockM Size of the M dimension
 * @tparam BlockN Size of the N dimension
 * @tparam BlockKTest Size of the K dimension
 * @tparam CompilerTarget The compiler target
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t BlockM,
          uint32_t BlockN,
          uint32_t BlockKTest,
          typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) && is_power_of_two_integer(BlockKTest))
struct SparseMfmaDefaultSelector
{
    private:
    // Define our candidate MFMA implementation for the current parameters
    using CandidateOp = amdgcn_mma<ADataType,
                                   BDataType,
                                   CDataType,
                                   BlockM,
                                   BlockN,
                                   BlockKTest,
                                   DefaultSparseMfmaCtrlFlags,
                                   CompilerTarget,
                                   MmaOpFamily::SPARSE>;

    using CandidateTraits = MmaOpTraits<CandidateOp>;

    public:
    // If the candidate is supported (e.g., a backend implementation exists), then select it.
    // Otherwise, fall back to the unsupported pass-through implementation.
    using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
                                          CandidateOp,
                                          amdgcn_mma<ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     BlockM,
                                                     BlockN,
                                                     BlockKTest,
                                                     void,
                                                     amdgcn_target<>,
                                                     MmaOpFamily::UNDEFINED>>;
};
|
||||
|
||||
/**
 * @struct MmaDefaultSelector
 * @brief Implements the CDNA default MMA selector strategy for sparse MFMA.
 * If no supported instruction is found, falls back to an unsupported pass-through implementation.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam FragM Size of the M dimension of the fragment to decompose
 * @tparam FragN Size of the N dimension of the fragment to decompose
 * @tparam FragK Size of the K dimension of the fragment to decompose
 * @tparam CompilerTarget The compiler target
 * @tparam OpFamily The MMA operation family
 * @note Enabled only for GFX942/GFX950 targets with OpFamily == SPARSE.
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t FragM,
          uint32_t FragN,
          uint32_t FragK,
          typename CompilerTarget,
          MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
struct MmaDefaultSelector<ADataType,
                          BDataType,
                          CDataType,
                          FragM,
                          FragN,
                          FragK,
                          CompilerTarget,
                          OpFamily,
                          enable_if_all<std::enable_if_t<is_any_value_of(CompilerTarget::TARGET_ID,
                                                                         amdgcn_target_id::GFX942,
                                                                         amdgcn_target_id::GFX950)>,
                                        std::enable_if_t<OpFamily == MmaOpFamily::SPARSE>>>
{
    private:
    // Provide the default depth-K search strategy for each class of common MFMA shapes.
    // Start searching from the largest K dimension MFMA shape down to the smallest.
    using CandidateOp16x16 = typename SparseMfmaDefaultSelector<ADataType,
                                                                BDataType,
                                                                CDataType,
                                                                16u,
                                                                16u,
                                                                32u,
                                                                CompilerTarget>::SelectedOp;
    using CandidateOp32x32 = typename SparseMfmaDefaultSelector<ADataType,
                                                                BDataType,
                                                                CDataType,
                                                                32u,
                                                                32u,
                                                                64u,
                                                                CompilerTarget>::SelectedOp;

    // Default operation triggers pass-through (1x1x1 has no backend, so it is Unsupported)
    using DefaultOp = typename SparseMfmaDefaultSelector<ADataType,
                                                         BDataType,
                                                         CDataType,
                                                         1u,
                                                         1u,
                                                         1u,
                                                         CompilerTarget>::SelectedOp;

    // Traits for each candidate
    using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
    using CandidateTraits32x32 = MmaOpTraits<CandidateOp32x32>;

    // Check if each candidate is supported for the given fragment sizes
    // For this case, we require the fragment sizes to be multiples of the MFMA shape
    static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
                                             (FragM % CandidateTraits16x16::BlockM == 0u) &&
                                             (FragN % CandidateTraits16x16::BlockN == 0u) &&
                                             (FragK % CandidateTraits16x16::BlockK == 0u);
    static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported &&
                                             (FragM % CandidateTraits32x32::BlockM == 0u) &&
                                             (FragN % CandidateTraits32x32::BlockN == 0u) &&
                                             (FragK % CandidateTraits32x32::BlockK == 0u);

    public:
    // Select the largest supported MFMA operation for the given fragment shape
    using SelectedOp =
        std::conditional_t<IsSupported32x32,
                           CandidateOp32x32,
                           std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>>;
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
108
include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp
Normal file
108
include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp
Normal file
@@ -0,0 +1,108 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
 * @struct DefaultSparseMfmaCtrlFlags
 * @brief Default MFMA sparse flags, select (VGPR[srcC][7..0]) if srcC is
 * 16-bit or (VGPR[srcC][15..0]) if srcC is 8-bit.
 * @note i.e. the FIRST set of sparsity-index bits is used; see
 *       SparseCompressionIndex for the alternatives.
 */
struct DefaultSparseMfmaCtrlFlags
{
    static constexpr SparseCompressionIndex CompressionIndex = SparseCompressionIndex::FIRST;
};
|
||||
|
||||
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
#include <concepts>
|
||||
|
||||
/**
 * @concept SparseMfmaCtrlFlags
 * @brief Expresses the interface of required members for each CtrlFlags type
 */
template <typename CtrlFlags>
concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) {
    // Flag members for sparse MFMA instructions
    { CtrlFlags::CompressionIndex } -> std::convertible_to<SparseCompressionIndex>;
};
|
||||
|
||||
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
|
||||
|
||||
/**
 * @struct amdgcn_mma
 * @brief Specialization of amdgcn_mma for Sparse MFMA (SMFMA) on GFX942, GFX950 targets
 *
 * This specialization implements the SMFMA instruction for fp16_t A and B
 * matrices with structured sparsity, fp32_t accumulator, with 16x16x32 block sizes.
 *
 * @tparam CtrlFlags Control flags for the Sparse MFMA operation
 * @tparam CompilerTarget Current compiler target
 */
// TODO: c++20 template <CtrlFlagsSparseMfmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<
    fp16_t,
    fp16_t,
    fp32_t,
    16u,
    16u,
    32u,
    CtrlFlags,
    CompilerTarget,
    MmaOpFamily::SPARSE,
    std::enable_if_t<is_any_value_of(
        CompilerTarget::TARGET_ID, amdgcn_target_id::GFX942, amdgcn_target_id::GFX950)>>
{
    using OpType                          = MfmaOp;
    static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE;

    // Logical A/B vector width per lane (before 2:4 compression)
    static constexpr index_t ABVecN = 8;

    using AVecType = ext_vector_t<fp16_t, ABVecN>;
    using BVecType = ext_vector_t<fp16_t, ABVecN>;
    using CVecType = ext_vector_t<fp32_t, 4>;

    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;

    static constexpr index_t kAMLane     = 16;
    static constexpr index_t kBNLane     = 16;
    static constexpr index_t kABKLane    = 4;
    static constexpr index_t kABKPerLane = 8;

    static constexpr index_t kCMLane     = 4;
    static constexpr index_t kCNLane     = 16;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 4;

    // 2:4 structured sparsity halves the effective A operand width
    static constexpr index_t kCompressionRatio = 2;

    /// @brief Executes one 16x16x32 sparse MFMA; returns the updated accumulator.
    /// @param aVec A operand (mutated in place by compress_a_impl — hence non-const)
    /// @param bVec B operand
    /// @param cVec input accumulator
    CK_TILE_DEVICE static auto
    exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
    {
        static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
        using AVecCompressed                    = ext_vector_t<fp16_t, CompressedSize>;
        static_assert(CompressedSize == 4);
        // TODO: Compressing A on-the-fly should be OK for now, but we need to validate
        // and evaluate changing this to a transform at a higher level.
        // aVec not being const can cause problems when running multiple intrinsics.
        // NOTE(review): idx presumably packs the 2:4 sparsity metadata produced by
        // compress_a_impl and aVec[0..3] hold the surviving elements afterwards —
        // confirm against compress_a_impl's contract.
        const int32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);

        const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]};

        using namespace sparse::detail;
        // Translate the compile-time compression index into the CBSZ/ABID pair
        // expected by the smfmac builtin.
        static constexpr BuiltinParams PARAMS = getBuiltinParams<CtrlFlags::CompressionIndex>();
        return {__builtin_amdgcn_smfmac_f32_16x16x32_f16(
            a_vec_pruned, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)};
    }
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
68
include/ck_tile/core/arch/mma/sparse/sparse.hpp
Normal file
68
include/ck_tile/core/arch/mma/sparse/sparse.hpp
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
 * @enum SparseCompressionIndex
 * @brief Selects which set of sparse-index bits within the VGPR at srcC is used
 * for a lane: 8 bits per set for 16-bit source data, 16 bits per set for 8-bit
 * source data. \see DefaultSparseMfmaCtrlFlags
 */
enum struct SparseCompressionIndex : int
{
    FIRST  = 0, // Uses bits [7:0] or [15..0], for 16 and 8 bit data respectively
    SECOND = 1, // Uses bits [15:8] or [31:16], for 16 and 8 bit data respectively
    THIRD  = 2, // Uses bits [23:16]
    FOURTH = 3, // Uses bits [31:24]
};

namespace sparse::detail {

/**
 * @struct BuiltinParams
 * @brief CBSZ/ABID pair passed to the sparse MMA builtins.
 *
 * Behavior depends on the source data width:
 * - 16-bit data: when CBSZ=0, ABID picks one of four 8-bit index sets within
 *   the VGPR at srcC; when CBSZ!=0 the first set (VGPR[srcC][7..0]) is used.
 * - 8-bit data: when CBSZ=0, ABID picks one of two 16-bit index sets; when
 *   CBSZ!=0 the first set (VGPR[srcC][15..0]) is used.
 */
struct BuiltinParams
{
    int UseFirstIndex;       // CBSZ
    int ByteIndexToOverride; // ABID
};

/// @brief Maps a SparseCompressionIndex to the matching CBSZ/ABID pair.
/// FIRST uses the CBSZ!=0 fast path; all other indices go through ABID.
template <SparseCompressionIndex Idx>
static constexpr BuiltinParams getBuiltinParams()
{
    if constexpr(Idx == SparseCompressionIndex::FIRST)
    {
        return BuiltinParams{/*UseFirstIndex=*/1, /*ByteIndexToOverride=*/0};
    }
    else
    {
        return BuiltinParams{/*UseFirstIndex=*/0, static_cast<int>(Idx)};
    }
}

} // namespace sparse::detail
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
// Include sparse MFMA traits and architecture-specific implementations
|
||||
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
|
||||
7
include/ck_tile/core/arch/mma/sparse/sparse_selector.hpp
Normal file
7
include/ck_tile/core/arch/mma/sparse/sparse_selector.hpp
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"
|
||||
48
include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp
Normal file
48
include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
 * @struct MmaDefaultTransformsSparse
 * @brief Implements the default transforms for Sparse
 *
 * For 2:4 structured sparsity with inline register metadata:
 * - ATransform: Pass-through (sparse operands formatted in Exec) TODO!
 * - BTransform: Pass-through (sparse operands already formatted)
 * - CTransform: Pass-through (input accumulator)
 * - DTransform: Pass-through (output accumulator as-is)
 */
struct MmaDefaultTransformsSparse
{
    using ATransform = PassThroughTransform;
    using BTransform = PassThroughTransform;
    using CTransform = PassThroughTransform;
    using DTransform = PassThroughTransform;
};
|
||||
|
||||
/**
 * @class MmaTransformsDefaultSelector
 * @brief Specialization for Sparse MFMA transforms
 * Provides default transform selection for sparse operations
 *
 * @tparam MmaOp Sparse MMA operation
 * @tparam CompilerTarget The compiler target
 * @note Enabled for any MmaOp whose OpFamily is SPARSE; always yields the
 *       all-pass-through transform set.
 */
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target CompilerTarget>
// TODO: c++20 requires(is_mma_op_sparse(MmaOp))
template <typename MmaOp, typename CompilerTarget>
struct MmaTransformsDefaultSelector<MmaOp,
                                    CompilerTarget,
                                    std::enable_if_t<MmaOp::OpFamily == MmaOpFamily::SPARSE>>
{
    using SelectedTransforms = MmaDefaultTransformsSparse;
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
134
include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp
Normal file
134
include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
 * @class SparseWmmaDefaultSelector
 * @brief Implements a default sparse WMMA selector strategy. The SelectedOp can be unsupported.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam BlockM Size of the M dimension
 * @tparam BlockN Size of the N dimension
 * @tparam BlockKTest Size of the K dimension
 * @tparam CompilerTarget The compiler target
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t BlockM,
          uint32_t BlockN,
          uint32_t BlockKTest,
          typename CompilerTarget>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires(is_target_arch_rdna(CompilerTarget) && is_power_of_two_integer(BlockKTest))
struct SparseWmmaDefaultSelector
{
    private:
    // Define our candidate WMMA implementation for the current parameters
    using CandidateOp = amdgcn_mma<ADataType,
                                   BDataType,
                                   CDataType,
                                   BlockM,
                                   BlockN,
                                   BlockKTest,
                                   DefaultSparseWmmaCtrlFlags,
                                   CompilerTarget,
                                   MmaOpFamily::SPARSE>;

    using CandidateTraits = MmaOpTraits<CandidateOp>;

    public:
    // If the candidate is supported (e.g., a backend implementation exists), then select it.
    // Otherwise, fall back to the unsupported pass-through implementation.
    using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
                                          CandidateOp,
                                          amdgcn_mma<ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     BlockM,
                                                     BlockN,
                                                     BlockKTest,
                                                     void,
                                                     amdgcn_target<>,
                                                     MmaOpFamily::UNDEFINED>>;
};
|
||||
|
||||
/**
 * @struct MmaDefaultSelector
 * @brief Implements the RDNA default MMA selector strategy for sparse WMMA.
 * If no supported instruction is found, falls back to an unsupported pass-through implementation.
 * @tparam ADataType Data type of matrix A
 * @tparam BDataType Data type of matrix B
 * @tparam CDataType Data type of the accumulator
 * @tparam FragM Size of the M dimension of the fragment to decompose
 * @tparam FragN Size of the N dimension of the fragment to decompose
 * @tparam FragK Size of the K dimension of the fragment to decompose
 * @tparam CompilerTarget The compiler target
 * @tparam OpFamily The MMA operation family
 * @note Enabled only for GFX12-family targets with OpFamily == SPARSE.
 */
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          uint32_t FragM,
          uint32_t FragN,
          uint32_t FragK,
          typename CompilerTarget,
          MmaOpFamily OpFamily>
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
// TODO: c++20 requires
struct MmaDefaultSelector<ADataType,
                          BDataType,
                          CDataType,
                          FragM,
                          FragN,
                          FragK,
                          CompilerTarget,
                          OpFamily,
                          enable_if_all<enable_if_target_family_gfx12_t<CompilerTarget>,
                                        std::enable_if_t<OpFamily == MmaOpFamily::SPARSE>>>
{
    private:
    // Provide the default depth-K search strategy for each class of common WMMA shapes.
    // Start searching from the largest K dimension WMMA shape down to the smallest.
    using CandidateOp16x16 = typename SparseWmmaDefaultSelector<ADataType,
                                                                BDataType,
                                                                CDataType,
                                                                16u,
                                                                16u,
                                                                32u,
                                                                CompilerTarget>::SelectedOp;

    // Default operation triggers pass-through (1x1x1 has no backend, so it is Unsupported)
    using DefaultOp = typename SparseWmmaDefaultSelector<ADataType,
                                                         BDataType,
                                                         CDataType,
                                                         1u,
                                                         1u,
                                                         1u,
                                                         CompilerTarget>::SelectedOp;

    // Traits for each candidate
    using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;

    // Check if each candidate is supported for the given fragment sizes
    // For this case, we require the fragment sizes to be multiples of the WMMA shape
    static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
                                             (FragM % CandidateTraits16x16::BlockM == 0u) &&
                                             (FragN % CandidateTraits16x16::BlockN == 0u) &&
                                             (FragK % CandidateTraits16x16::BlockK == 0u);

    public:
    // Select the largest supported WMMA operation for the given fragment shape
    using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
73
include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp
Normal file
73
include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
 * @struct DefaultSparseWmmaCtrlFlags
 * @brief Default control flags for sparse WMMA; intentionally empty — the
 * GFX12 swmmac builtin takes no CBSZ/ABID-style controls here.
 */
struct DefaultSparseWmmaCtrlFlags
{
};
|
||||
|
||||
/**
 * @struct amdgcn_mma
 * @brief Specialization of amdgcn_mma for Sparse WMMA (SWMMAC) on GFX12-family targets.
 *
 * Implements the 16x16x32 sparse WMMA instruction for fp16_t A and B matrices
 * with structured sparsity and fp32_t accumulator (wave32 variant).
 *
 * @tparam CtrlFlags Control flags for the Sparse WMMA operation
 * @tparam CompilerTarget Current compiler target
 */
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
struct amdgcn_mma<fp16_t,
                  fp16_t,
                  fp32_t,
                  16u,
                  16u,
                  32u,
                  CtrlFlags,
                  CompilerTarget,
                  MmaOpFamily::SPARSE,
                  enable_if_target_family_gfx12_t<CompilerTarget>>
{
    using OpType                          = WmmaOp;
    static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE;

    // Logical A/B vector width per lane (before 2:4 compression)
    static constexpr index_t ABVecN = 16;

    using AVecType = ext_vector_t<fp16_t, ABVecN>;
    using BVecType = ext_vector_t<fp16_t, ABVecN>;
    using CVecType = ext_vector_t<fp32_t, 8>;

    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;

    static constexpr index_t kAMLane     = 16;
    static constexpr index_t kBNLane     = 16;
    static constexpr index_t kABKLane    = 4;
    static constexpr index_t kABKPerLane = 8;

    static constexpr index_t kCMLane     = 4;
    static constexpr index_t kCNLane     = 16;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 4;

    // 2:4 structured sparsity halves the effective A operand width
    static constexpr index_t kCompressionRatio = 2;

    /// @brief Executes one 16x16x32 sparse WMMA; returns the updated accumulator.
    /// @param aVec A operand (mutated in place by compress_a_impl — hence non-const)
    /// @param bVec B operand
    /// @param cVec input accumulator
    CK_TILE_DEVICE static auto
    exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
    {
        static constexpr index_t CompressedSize = ABVecN / kCompressionRatio;
        using AVecCompressed                    = ext_vector_t<fp16_t, CompressedSize>;
        static_assert(CompressedSize == 8);
        // TODO: Compressing A on-the-fly should be OK for now, but we need to validate
        // and evaluate changing this to a transform at a higher level.
        // aVec not being const can cause problems when running multiple intrinsics.
        // NOTE(review): idx presumably packs the 2:4 sparsity metadata produced by
        // compress_a_impl and aVec[0..7] hold the surviving elements afterwards —
        // confirm against compress_a_impl's contract.
        const int32_t idx = ck_tile::compress_a_impl<fp16_t, CompressedSize>(aVec);

        const AVecCompressed a_vec_pruned = {
            aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]};

        return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(a_vec_pruned, bVec, cVec, idx)};
    }
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
@@ -0,0 +1,175 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* @file tile_distribution_encoding_register_mapper.hpp
|
||||
* @brief Utility for register / matrix coordinate mapping from TileDistributionEncoding
|
||||
* @details Defines TileDistrEncRegMap, which takes a TileDistributionEncoding and provides
|
||||
* functions for mapping matrix fragment coordinates to register coordinates (lane, vector item) and
|
||||
* vice versa. This is only meant for tile distributions encodings that describe register mappings.
|
||||
*
|
||||
* A repeat dimension is allowed in which case multiple (lane, vector item) pairs are mapped to the
|
||||
* same matrix coordinates. The inverse map takes a "repeat index" to distinguish between them.
|
||||
*
|
||||
* print() functions are included for printing dimensions and formatted forward and backwards
|
||||
* mappings similar to the AMD Matrix Calculator.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdio.h>
|
||||
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
// Utility to calculate register mappings from a Tile Distribution Encoding.
|
||||
template <typename TileDistrEnc>
|
||||
struct TileDistrEncRegMap
|
||||
{
|
||||
// Make sure this is a proper Tile Distr Encoding for Lane Vector mapping.
|
||||
static_assert(TileDistrEnc::NDimR <= 1);
|
||||
static_assert(TileDistrEnc::NDimX == 2);
|
||||
static_assert(TileDistrEnc::NDimP == 1);
|
||||
|
||||
static constexpr auto ps_ys_to_xs_adaptor =
|
||||
make_static_tile_distribution(TileDistrEnc{}).get_ps_ys_to_xs_adaptor();
|
||||
|
||||
static constexpr index_t mat_major_size =
|
||||
container_reduce(typename TileDistrEnc::HsLengthss{}[number<0>{}], multiplies<>{}, 1);
|
||||
static constexpr index_t mat_minor_size =
|
||||
container_reduce(typename TileDistrEnc::HsLengthss{}[number<1>{}], multiplies<>{}, 1);
|
||||
static constexpr index_t num_repeat = [] {
|
||||
if constexpr(TileDistrEnc::NDimR > 0)
|
||||
{
|
||||
return typename TileDistrEnc::RsLengths{}[number<0>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1; // Necessary to deal with empty "repeat" sequences.
|
||||
}
|
||||
}();
|
||||
static constexpr index_t num_lanes = ps_ys_to_xs_adaptor.get_top_dimension_length(number<0>{});
|
||||
static constexpr index_t num_vector_items =
|
||||
container_reduce(TileDistrEnc::detail::ys_lengths_, multiplies<>{}, 1);
|
||||
|
||||
// Check for 0 dims (will break things much earlier but let's have an extra check).
|
||||
static_assert(mat_major_size > 0);
|
||||
static_assert(mat_minor_size > 0);
|
||||
static_assert(num_repeat > 0);
|
||||
static_assert(num_lanes > 0);
|
||||
static_assert(num_vector_items > 0);
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
calc_matrix_indices_from_lane_vector(index_t lane_inx, index_t vector_inx)
|
||||
{
|
||||
// For some reason the Y dimension is not treated the same as the P dimension and we need to
|
||||
// manually unmerge the Y dimension index into its hidden indices before being able to use
|
||||
// it...
|
||||
array<index_t, TileDistrEnc::NDimY> y_hidden_inx;
|
||||
for(index_t i = TileDistrEnc::NDimY - 1; i >= 0; --i)
|
||||
{
|
||||
y_hidden_inx[i] = vector_inx % TileDistrEnc::detail::ys_lengths_[i];
|
||||
vector_inx /= TileDistrEnc::detail::ys_lengths_[i];
|
||||
}
|
||||
|
||||
const auto ps_ys_idx = container_concat(array<index_t, 1>{lane_inx}, y_hidden_inx);
|
||||
return ps_ys_to_xs_adaptor.calculate_bottom_index(ps_ys_idx);
|
||||
}
|
||||
|
||||
struct LaneVec
|
||||
{
|
||||
index_t lane = -1; // Sentinel for invalid pairs
|
||||
index_t vec = -1;
|
||||
};
|
||||
|
||||
using InverseMap =
|
||||
std::array<std::array<std::array<LaneVec, num_repeat>, mat_minor_size>, mat_major_size>;
|
||||
|
||||
// TODO: In theory this could be done with inverted merge unmerge operations.
|
||||
CK_TILE_HOST_DEVICE static constexpr InverseMap calc_inverse_map()
|
||||
{
|
||||
InverseMap im{};
|
||||
for(index_t l = 0; l < num_lanes; ++l)
|
||||
{
|
||||
for(index_t v = 0; v < num_vector_items; ++v)
|
||||
{
|
||||
auto res = calc_matrix_indices_from_lane_vector(l, v); // Matrix major, minor inx;
|
||||
|
||||
// We assume that repeated matrix elements appear at increasing L and V indices.
|
||||
for(index_t r = 0; r < num_repeat; r++)
|
||||
{
|
||||
auto& lv = im[res[0]][res[1]][r];
|
||||
if(lv.lane < 0)
|
||||
{
|
||||
lv.lane = l; // TODO: c++20 designated initializers
|
||||
lv.vec = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return im;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static void print_dims()
|
||||
{
|
||||
printf("Matrix dims major, minor, repeat = %d %d %d\n",
|
||||
mat_major_size,
|
||||
mat_minor_size,
|
||||
num_repeat);
|
||||
printf("Num lanes, vector items = %d %d\n", num_lanes, num_vector_items);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static void print_mapping()
|
||||
{
|
||||
printf("(lane, vector) item to matrix element\n L | ");
|
||||
for(index_t v = 0; v < num_vector_items; v++)
|
||||
{
|
||||
printf("vec%2d | ", v);
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
for(index_t l = 0; l < num_lanes; l++)
|
||||
{
|
||||
printf("%2d | ", l);
|
||||
for(index_t v = 0; v < num_vector_items; v++)
|
||||
{
|
||||
auto res = calc_matrix_indices_from_lane_vector(l, v);
|
||||
printf("%2d %2d | ", res[0], res[1]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static void print_inverse_mapping()
|
||||
{
|
||||
InverseMap im = calc_inverse_map();
|
||||
printf("Matrix element to (lane, vector item). Elements are replicated an additional %d "
|
||||
"time(s) in higher lanes. \n",
|
||||
num_repeat - 1);
|
||||
printf("Mat| ");
|
||||
for(index_t k = 0; k < mat_minor_size; k++)
|
||||
{
|
||||
printf(" %2d | ", k);
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
for(index_t m = 0; m < mat_major_size; m++)
|
||||
{
|
||||
printf("%2d | ", m);
|
||||
for(index_t k = 0; k < mat_minor_size; k++)
|
||||
{
|
||||
printf("%2d %2d | ", im[m][k][0].lane, im[m][k][0].vec);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static void print()
|
||||
{
|
||||
print_dims();
|
||||
print_mapping();
|
||||
print_inverse_mapping();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
34
include/ck_tile/core/arch/mma/wmma/wmma.hpp
Normal file
34
include/ck_tile/core/arch/mma/wmma/wmma.hpp
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @enum WmmaCtrlFlags
|
||||
* @brief Common wmma control flags for gfx11 and gfx12
|
||||
*/
|
||||
enum struct WmmaCtrlFlags : bool
|
||||
{
|
||||
// Only has an effect on gfx11 when the accumulator is 16-bit
|
||||
// Determines which half of the 32-bit accum register to use
|
||||
// Low = bits [15:0]
|
||||
// High = bits[31:16]
|
||||
LOW = false,
|
||||
HIGH = true,
|
||||
|
||||
// Only has an effect on gfx11 / 12 when the input is 8-bit int
|
||||
// Signage indicator of inputs / accum
|
||||
UNSIGNED = false,
|
||||
SIGNED = true
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
|
||||
// Include the architecture-specific WMMA implementations and traits
|
||||
#include "wmma_gfx11.hpp"
|
||||
#include "wmma_gfx12.hpp"
|
||||
#include "wmma_selector.hpp"
|
||||
#include "wmma_traits.hpp"
|
||||
#include "wmma_transforms.hpp"
|
||||
112
include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp
Normal file
112
include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "wmma_traits.hpp"
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
// TODO: Specifically for gfx11 wmma, we need to deal with quirks such as:
|
||||
// - Duplicating A and B inputs
|
||||
// - Handling C / D is always in b32, even for f16 accumulation.
|
||||
// NOTE: Two suggestions:
|
||||
// 1) We could do it here in the wrappers by accepting packed inputs, then swizzling them to
|
||||
// duplicate the inputs as needed before calling the actual built-in. This may introduce
|
||||
// some instruction overhead and violate single responsibility clauses, but keeps the logic
|
||||
// contained within the backend wrapper.
|
||||
// 2) We could do it at a higher level, e.g. in the Mma interface (workflow) by introducing
|
||||
// pre-mma, mma and post-mma steps. The pre-mma step could handle input duplication transform
|
||||
// post-mma could implement D-shuffle transform. This may be cleaner and more flexible than
|
||||
// trying to handle everything in the backend wrappers.
|
||||
//
|
||||
// This current example assumes duplication has already been done, and that C data shuffles have
|
||||
// already been completed. (e.g. option 2 above). These expect duplicated inputs and pre-shuffled
|
||||
// data in C.
|
||||
|
||||
// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed.
|
||||
// This is because some built-ins are only available on certain target ids.
|
||||
// We can also do things, such add some padding specializations for when we need to use
|
||||
// smaller values of K that aren't directly supported by the built-ins.
|
||||
// For flexibility, it is recommended that for each backend wrapper it supports at least
|
||||
// one packed register for each input to be able to process smaller K values by padding.
|
||||
|
||||
/**
|
||||
* @class DefaultWmmaFlags
|
||||
* @brief Generates default WMMA control flags based on data types.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
*/
|
||||
template <typename ADataType, typename BDataType, typename CDataType>
|
||||
struct DefaultWmmaCtrlFlags
|
||||
{
|
||||
// Generate default flags for signage
|
||||
// Only used currently for integer inputs / accum in gfx11 / gfx12
|
||||
constexpr static WmmaCtrlFlags InputSignA =
|
||||
std::is_signed_v<ADataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
|
||||
constexpr static WmmaCtrlFlags InputSignB =
|
||||
std::is_signed_v<BDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
|
||||
constexpr static WmmaCtrlFlags AccumSign =
|
||||
std::is_signed_v<CDataType> ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED;
|
||||
|
||||
// Generate default flags for accumulator destination bits.
|
||||
// Only used if accumulation size is 16-bit in gfx11
|
||||
constexpr static WmmaCtrlFlags AccumBits = WmmaCtrlFlags::LOW;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX11
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx11I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
std::enable_if_t<is_target_family_gfx11<CompilerTarget>()>>
|
||||
{
|
||||
// Wmma operation type
|
||||
using OpType = WmmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types (duplicated input / b32 accum)
|
||||
using AVecType = ext_vector_t<fp16_t, 16>;
|
||||
using BVecType = ext_vector_t<fp16_t, 16>;
|
||||
using CVecType = ext_vector_t<fp32_t, 8>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 8;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 2;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 1;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aVec, bVec, cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
72
include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp
Normal file
72
include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "wmma_traits.hpp"
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
// NOTE: At this point forward, we are specializing amdgcn_mma for each target id as needed.
|
||||
// This is because some built-ins are only available on certain target ids.
|
||||
// We can also do things, such add some padding specializations for when we need to use
|
||||
// smaller values of K that aren't directly supported by the built-ins.
|
||||
// For flexibility, it is recommended that for each backend wrapper it supports at least
|
||||
// one packed register for each input to be able to process smaller K values by padding.
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_wmma for fp16_t, fp16_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
struct amdgcn_mma<fp16_t,
|
||||
fp16_t,
|
||||
fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE,
|
||||
enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
{
|
||||
// Wmma operation type
|
||||
using OpType = WmmaOp;
|
||||
static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE;
|
||||
|
||||
// Register types
|
||||
using AVecType = ext_vector_t<fp16_t, 8>;
|
||||
using BVecType = ext_vector_t<fp16_t, 8>;
|
||||
using CVecType = ext_vector_t<fp32_t, 8>;
|
||||
|
||||
// Layout constants
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 8;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 2;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 1;
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType
|
||||
{
|
||||
return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
173
include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp
Normal file
173
include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @class WmmaDefaultSelector
|
||||
* @brief Implements a default WMMA selector strategy for gfx11/12 target architectures.
|
||||
* This implements the K dimension search strategy to find the largest supported WMMA
|
||||
* instruction for the given M/N block sizes and datatypes.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam BlockKTest Size of the K dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
uint32_t BlockKTest,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires(is_rdna_arch_id(CompilerTarget) && is_power_of_two_integer(BlockKTest))
|
||||
struct WmmaDefaultSelector
|
||||
{
|
||||
private:
|
||||
// By default, let's assume no special flags for WMMA
|
||||
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
|
||||
|
||||
// Define our candidate WMMA implementation for the current parameters
|
||||
using CandidateOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
|
||||
using CandidateTraits = MmaOpTraits<CandidateOp>;
|
||||
|
||||
public:
|
||||
// If the candidate is supported (e.g., a backend implementation exists), then select it.
|
||||
// Otherwise, test another smaller BlockK. If no existing implementations, we will get BlockK=0u
|
||||
// and fall back to the unsupported pass-through implementation.
|
||||
using SelectedOp = std::conditional_t<CandidateTraits::IsSupported,
|
||||
CandidateOp,
|
||||
typename WmmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockKTest / 2u,
|
||||
CompilerTarget>::SelectedOp>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct WmmaDefaultSelector
|
||||
* @brief Implements a default WMMA selector strategy for gfx11/12 target architectures.
|
||||
* This implements the K dimension == 1, which is the base case for the recursive K dimension
|
||||
* search. If no supported instruction is found, falls back to an unsupported pass-through
|
||||
* implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam BlockM Size of the M dimension
|
||||
* @tparam BlockN Size of the N dimension
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t BlockM,
|
||||
uint32_t BlockN,
|
||||
typename CompilerTarget>
|
||||
// TODO: c++20 amdgcn_target_arch_id GfxTargetId>
|
||||
struct WmmaDefaultSelector<ADataType, BDataType, CDataType, BlockM, BlockN, 1u, CompilerTarget>
|
||||
{
|
||||
// By default, let's assume no special flags for WMMA
|
||||
using CtrlFlags = DefaultWmmaCtrlFlags<ADataType, BDataType, CDataType>;
|
||||
|
||||
// Default unsupported pass-through if no instruction is found
|
||||
using SelectedOp = amdgcn_mma<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BlockM,
|
||||
BlockN,
|
||||
1u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
MmaOpFamily::DENSE>;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultSelector
|
||||
* @brief Implements the rdna default MMA selector strategy for wave-wise MMA decomposition.
|
||||
* This implements the M/N block size search strategy to find the largest supported WMMA
|
||||
* instruction for the given datatypes.
|
||||
* If no supported instruction is found, falls back to an unsupported pass-through implementation.
|
||||
* @tparam ADataType Data type of matrix A
|
||||
* @tparam BDataType Data type of matrix B
|
||||
* @tparam CDataType Data type of the accumulator
|
||||
* @tparam FragM Size of the M dimension of the fragment to decompose
|
||||
* @tparam FragN Size of the N dimension of the fragment to decompose
|
||||
* @tparam FragK Size of the K dimension of the fragment to decompose
|
||||
* @tparam CompilerTarget The compiler target
|
||||
* @tparam OpFamily The MMA operation family
|
||||
*/
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
uint32_t FragM,
|
||||
uint32_t FragN,
|
||||
uint32_t FragK,
|
||||
typename CompilerTarget,
|
||||
MmaOpFamily OpFamily>
|
||||
// TODO: c++20 amdgcn_target_arch_id CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
struct MmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
FragM,
|
||||
FragN,
|
||||
FragK,
|
||||
CompilerTarget,
|
||||
OpFamily,
|
||||
enable_if_all<enable_if_target_arch_rdna_t<CompilerTarget>,
|
||||
std::enable_if_t<OpFamily == MmaOpFamily::DENSE>>>
|
||||
{
|
||||
private:
|
||||
// Provide the default depth-K search strategy for each class of common WMMA shapes.
|
||||
// Start searching from the largest K dimension MFMA shape down to the smallest.
|
||||
using CandidateOp16x16 = typename WmmaDefaultSelector<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
16u,
|
||||
16u,
|
||||
128u,
|
||||
CompilerTarget>::SelectedOp;
|
||||
|
||||
// Default operation triggers pass-through
|
||||
using DefaultOp =
|
||||
typename WmmaDefaultSelector<ADataType, BDataType, CDataType, 1u, 1u, 1u, CompilerTarget>::
|
||||
SelectedOp;
|
||||
|
||||
// Traits for each candidate
|
||||
using CandidateTraits16x16 = MmaOpTraits<CandidateOp16x16>;
|
||||
|
||||
// Check if each candidate is supported for the given fragment sizes
|
||||
// For this case, we require the fragment sizes to be multiples of the WMMA shape
|
||||
static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported &&
|
||||
(FragM % CandidateTraits16x16::BlockM == 0u) &&
|
||||
(FragN % CandidateTraits16x16::BlockN == 0u) &&
|
||||
(FragK % CandidateTraits16x16::BlockK == 0u);
|
||||
|
||||
public:
|
||||
// Select the largest supported WMMA operation for the given fragment shape
|
||||
using SelectedOp = std::conditional_t<IsSupported16x16, CandidateOp16x16, DefaultOp>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
44
include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp
Normal file
44
include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct WmmaOp
|
||||
* @brief Meta-tag for the WMMA operation. This will be used in the MmaOp struct to
|
||||
* identify the operation as an WMMA operation.
|
||||
*/
|
||||
struct WmmaOp;
|
||||
|
||||
/**
|
||||
* @class is_mma_op_wmma
|
||||
* @brief Trait to check if MmaOp is an WMMA operation
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
template <typename MmaOp, typename = void>
|
||||
struct is_mma_op_wmma : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct is_mma_op_wmma
|
||||
* @brief MmaOp specialization for WMMA operations, confirming the OpType matches WmmaOp
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
// TODO: c++20 requires
|
||||
struct is_mma_op_wmma<MmaOp, std::enable_if_t<std::is_same_v<typename MmaOp::OpType, WmmaOp>>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Convenience evaluator for is_mma_op_wmma trait
|
||||
* @tparam MmaOp The matrix multiply-accumulate operation type to check
|
||||
*/
|
||||
template <typename MmaOp>
|
||||
static constexpr bool is_mma_op_wmma_v = is_mma_op_wmma<MmaOp>::value;
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
112
include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp
Normal file
112
include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct DuplicateTransform
|
||||
* @brief Transform to duplicate low register elements to high register elements
|
||||
*/
|
||||
struct DuplicateTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
// TODO: Implement duplication logic to broadcast low
|
||||
// register elements to high elements [0 - (N/2 -1)] -> [N/2 - (N-1)]
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct PadTransform
|
||||
* @brief Transform to pad data from original type to b32 type
|
||||
*/
|
||||
struct PadTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
// TODO: Implement b32 padding logic.
|
||||
// E.g., for fp16, pad each 16-bit element with 16 bits of 0 to make 32-bit elements
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct UnpadTransform
|
||||
* @brief Transform to unpad data from b32 type to original type
|
||||
*/
|
||||
struct UnpadTransform
|
||||
{
|
||||
template <typename VecType>
|
||||
CK_TILE_DEVICE static decltype(auto) exec(VecType&& v)
|
||||
{
|
||||
// TODO: Implement b32 logic to unpad 32 to original data type.
|
||||
return std::forward<VecType>(v);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsGfx11
|
||||
* @brief Default MMA transforms for GFX11 architecture
|
||||
*/
|
||||
struct MmaDefaultTransformsGfx11
|
||||
{
|
||||
using ATransform = DuplicateTransform;
|
||||
using BTransform = DuplicateTransform;
|
||||
using CTransform = PadTransform;
|
||||
using DTransform = UnpadTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaDefaultTransformsGfx12
|
||||
* @brief Default MMA transforms for GFX12 architecture
|
||||
*/
|
||||
struct MmaDefaultTransformsGfx12
|
||||
{
|
||||
using ATransform = PassThroughTransform;
|
||||
using BTransform = PassThroughTransform;
|
||||
using CTransform = PassThroughTransform;
|
||||
using DTransform = PassThroughTransform;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaTransformsDefaultSelector
|
||||
* @brief Implements the default MMA transforms selection for gfx11 targets
|
||||
* @tparam MmaOp Mma operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
|
||||
// TODO: c++20 requires
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx11_t<CompilerTarget>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsGfx11;
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct MmaTransformsDefaultSelector
|
||||
* @brief Implements the default MMA transforms selection for gfx12 targets
|
||||
* @tparam MmaOp Mma operation
|
||||
* @tparam CompilerTarget The compiler target
|
||||
*/
|
||||
template <typename MmaOp, typename CompilerTarget>
|
||||
// TODO: c++20 template <MmaOpI MmaOp, amdgcn_target_arch_id GfxTargetId>
|
||||
// TODO: c++20 requires
|
||||
struct MmaTransformsDefaultSelector<MmaOp,
|
||||
CompilerTarget,
|
||||
enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
{
|
||||
using SelectedTransforms = MmaDefaultTransformsGfx12;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::core::arch::mma
|
||||
144
include/ck_tile/core/arch/utility.hpp
Normal file
144
include/ck_tile/core/arch/utility.hpp
Normal file
@@ -0,0 +1,144 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// Address Space for AMDGCN
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#address-space
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: we have "memory" clobber here because this inline asm is used for async copy
|
||||
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
|
||||
{
|
||||
asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory");
|
||||
}
|
||||
|
||||
// NOTE: this is an immediate value
|
||||
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
|
||||
{
|
||||
asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
|
||||
{
|
||||
#if 0
|
||||
return __shfl_up(v_local, lane_delta);
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const uint32_t wrap_around_lane_delta = get_warp_size() - lane_delta;
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
|
||||
{
|
||||
#if 0
|
||||
return __shfl_down(v_local, lane_delta);
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local)
|
||||
{
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const int32x2_t x = __builtin_amdgcn_permlane32_swap(
|
||||
bit_cast<int32_t>(v_local), bit_cast<int32_t>(v_local), false, false);
|
||||
|
||||
thread_buffer<T, 2> v;
|
||||
v(0) = bit_cast<T>(x[0]);
|
||||
v(1) = bit_cast<T>(x[1]);
|
||||
|
||||
return v;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
|
||||
{
|
||||
#if 0
|
||||
return __shfl(v_local, src_lane);
|
||||
#elif 1
|
||||
if constexpr(sizeof(int32_t) > sizeof(T))
|
||||
{
|
||||
union packet
|
||||
{
|
||||
int32_t x;
|
||||
T v;
|
||||
};
|
||||
packet p;
|
||||
p.v = v_local;
|
||||
packet p_remote;
|
||||
p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));
|
||||
|
||||
return p_remote.v;
|
||||
}
|
||||
else if constexpr(sizeof(int32_t) == sizeof(T))
|
||||
{
|
||||
const int32_t v_remote_tmp =
|
||||
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
|
||||
|
||||
return bit_cast<T>(v_remote_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
|
||||
constexpr index_t elm = sizeof(T) / sizeof(int32_t);
|
||||
using vector_type = thread_buffer<int32_t, elm>;
|
||||
auto vs = bit_cast<vector_type>(v_local);
|
||||
auto vs_remote = vector_type{};
|
||||
static_for<0, elm, 1>{}([&](auto i_e) {
|
||||
int32_t tmp = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
|
||||
vs_remote(i_e) = tmp;
|
||||
});
|
||||
return bit_cast<T>(vs_remote);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
// Build a wave-wide lane mask from a per-lane 32-bit flag: the v_cmp writes,
// for every lane i, the result of (v_flag >= 1) into bit i of a scalar mask.
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
{
    static_assert(sizeof(T) == 4);
    // per-thread v_flag store into 2x sgpr
    // NOTE(review): the "=s" output is declared as uint32x2_t (a 64-bit SGPR
    // pair), which presumably targets wave64 — confirm behavior for wave32
    // builds where v_cmp produces a 32-bit mask.
    uint32x2_t exec_flag;
    asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
                 : [s_exec_flag] "=s"(exec_flag)
                 : [v_flag] "v"(v_flag));
    return exec_flag;
}
|
||||
|
||||
template <typename X, typename Y>
// Build a wave-wide lane mask from an unsigned compare: for every lane i,
// bit i of the result is set iff that lane's x < y (u32 comparison).
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
{
    static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
    // per-thread cmp store into 2x sgpr
    // NOTE(review): as in flag_to_exec, the 64-bit (2x SGPR) destination
    // presumably assumes wave64 — confirm for wave32 builds.
    uint32x2_t exec_flag;
    asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
                 : [s_exec_flag] "=s"(exec_flag)
                 : [v_x] "v"(x), [v_y] "v"(y));
    return exec_flag;
}
|
||||
|
||||
} // namespace ck_tile
|
||||
95
include/ck_tile/core/arch/workgroup_barrier.hpp
Normal file
95
include/ck_tile/core/arch/workgroup_barrier.hpp
Normal file
@@ -0,0 +1,95 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct workgroup_barrier
|
||||
{
|
||||
CK_TILE_DEVICE workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}
|
||||
|
||||
CK_TILE_DEVICE uint32_t ld(uint32_t offset = 0)
|
||||
{
|
||||
return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void wait_eq(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(ld(offset) != value) {}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Reduces power consumption during polling by leveraging wave-level sleep instructions
|
||||
CK_TILE_DEVICE void wait_eq_wave(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
// Limit active polling to first wave to reduce memory traffic and power
|
||||
const uint32_t wave_size = static_cast<uint32_t>(warpSize);
|
||||
if(threadIdx.x < wave_size)
|
||||
{
|
||||
uint32_t loaded_value = 0;
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
loaded_value = ld(offset);
|
||||
}
|
||||
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
|
||||
|
||||
while(loaded_value != value)
|
||||
{
|
||||
// s_sleep reduces power draw while waiting, as scalar sleep is cheaper than
|
||||
// busy-wait
|
||||
__builtin_amdgcn_s_sleep(1);
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
loaded_value = ld(offset);
|
||||
}
|
||||
loaded_value = __shfl(loaded_value, 0 /*src_lane*/);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(ld(offset) < value) {}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void wait_set(uint32_t compare, uint32_t value, uint32_t offset = 0)
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// enter critical zoon, assume buffer is zero when launch kernel
|
||||
CK_TILE_DEVICE void aquire(uint32_t offset = 0) { wait_set(offset, 0, 1); }
|
||||
|
||||
// exit critical zoon, assume buffer is zero when launch kernel
|
||||
CK_TILE_DEVICE void release(uint32_t offset = 0) { wait_set(offset, 1, 0); }
|
||||
|
||||
CK_TILE_DEVICE void inc(uint32_t offset = 0)
|
||||
{
|
||||
__syncthreads();
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
atomicAdd(base_ptr + offset, 1);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t* base_ptr;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user