mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
60 lines
1.7 KiB
C++
60 lines
1.7 KiB
C++
#ifndef CK_CONFIG_NVIDIA_HPP
|
|
#define CK_CONFIG_NVIDIA_HPP
|
|
|
|
#include "cuda_runtime.h"
|
|
#include "cuda_fp16.h"
|
|
#include "nvToolsExt.h"
|
|
#include "helper_cuda.h"
|
|
|
|
// Build-configuration switches for this NVIDIA configuration header.
// NOTE(review): the code that tests these macros is not visible in this file;
// the per-macro notes below are inferred from the macro names only -- confirm
// against the consumers before relying on them.

// Defined as 1 here; presumably selects the NVIDIA device backend.
#define CK_DEVICE_BACKEND_NVIDIA 1
|
|
// Defined as 0 here; presumably turns off AMD-specific inline assembly paths.
#define CK_USE_AMD_INLINE_ASM 0
|
|
// Experimental toggles, all defined as 0 in this configuration.
// Presumably each one switches a blockwise/threadwise slice-copy variant to a
// "more compile-static" implementation -- TODO confirm.
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
|
|
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
|
|
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
|
|
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
|
|
|
|
namespace ck {
|
|
|
|
// For some reason, CUDA needs these definitions; otherwise the
|
|
// compiler won't generate optimal load and store instructions, and the
|
|
// kernel produces wrong results, which indicates that the compiler fails to generate correct
|
|
// instructions.
|
|
// float2_t / float4_t are plain aliases of the float2 / float4 vector types.
using float2_t = float2;
|
|
using float4_t = float4;
|
|
|
|
// Index type used throughout namespace ck in this configuration: uint32_t.
// NOTE(review): because this is unsigned, subtracting a larger index from a
// smaller one wraps around instead of producing a negative value.
// NOTE(review): <cstdint> is not included directly in the lines visible here;
// uint32_t presumably arrives through the headers included above -- confirm.
using index_t = uint32_t;
|
|
|
|
// Generic multiply-accumulate for device code.
//
// Adds the product s0 * s1 onto d in place (d += s0 * s1), using whatever
// operator* and operator+= the type T provides.
//
//   d  - accumulator; read and then overwritten with the updated value
//   s0 - left operand of the product
//   s1 - right operand of the product
//
// The multiply and the accumulate are deliberately kept in one expression.
// Whether that becomes a single fused instruction is left to the compiler.
template <typename T>
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1) { d += s0 * s1; }
|
|
|
|
// Disabled block: type-specific overloads of fused_multiply_accumulate.
// Everything between this #if 0 and the matching #endif is skipped by the
// preprocessor, so none of these overloads is currently compiled.
#if 0
|
|
// half operands, half accumulator: d += s0 * s1.
__device__ void fused_multiply_accumulate(half& d, const half& s0, const half& s1) { d += s0 * s1; }
|
|
|
|
// half2 operands, half accumulator: adds s0.x * s1.x and then s0.y * s1.y
// to d as two separate updates.
__device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2& s1)
|
|
{
|
|
d += s0.x * s1.x;
|
|
d += s0.y * s1.y;
|
|
}
|
|
|
|
// half2 operands, float accumulator: the two lane products are summed in a
// single expression and that sum is added to d.
__device__ void fused_multiply_accumulate(float& d, const half2& s0, const half2& s1)
|
|
{
|
|
d += s0.x * s1.x + s0.y * s1.y;
|
|
}
|
|
|
|
// char operands, char accumulator: d += s0 * s1.
__device__ void fused_multiply_accumulate(char& d, const char& s0, const char& s1) { d += s0 * s1; }
|
|
|
|
// TODO: this interface is misleading; s0 and s1 are actually int8x4.
|
|
// A better interface is needed.
|
|
// Passes s0, s1 and the current d to __dp4a and stores the result back in d.
// NOTE(review): per the CUDA documentation __dp4a is only available on
// compute capability 6.1 and newer -- confirm the target architecture before
// re-enabling this overload.
__device__ void fused_multiply_accumulate(int32_t& d, const int32_t& s0, const int32_t& s1)
|
|
{
|
|
d = __dp4a(s0, s1, d);
|
|
}
|
|
#endif
|
|
|
|
} // namespace ck
|
|
|
|
#endif
|