mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
adding int8 direct that reads pre-vectorized data
This commit is contained in:
@@ -133,14 +133,14 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
// 3x3, 34x34, 128 thread, int8, vector = 4
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned CPerBlock = 8;
|
||||
constexpr unsigned HoPerBlock = 4;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned NPerThread = 2;
|
||||
constexpr unsigned KPerThread = 4;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned NPerThread = 1;
|
||||
constexpr unsigned KPerThread = 8;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
constexpr unsigned HoPerThread = 4;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 2;
|
||||
@@ -149,16 +149,16 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
// 1x1, 32x32, 128 thread, int8, vector = 4
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned NPerBlock = 1;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 16;
|
||||
constexpr unsigned HoPerBlock = 4;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned NPerThread = 2;
|
||||
constexpr unsigned KPerThread = 4;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned NPerThread = 1;
|
||||
constexpr unsigned KPerThread = 8;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
constexpr unsigned HoPerThread = 4;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 2;
|
||||
|
||||
@@ -120,7 +120,7 @@ struct vector_type<char, 1>
|
||||
template <>
|
||||
struct vector_type<char, 2>
|
||||
{
|
||||
using MemoryType = char2;
|
||||
using MemoryType = int16_t;
|
||||
|
||||
__host__ __device__ static MemoryType Pack(char s0, char s1)
|
||||
{
|
||||
@@ -139,7 +139,7 @@ struct vector_type<char, 2>
|
||||
template <>
|
||||
struct vector_type<char, 4>
|
||||
{
|
||||
using MemoryType = char4;
|
||||
using MemoryType = int32_t;
|
||||
|
||||
__host__ __device__ static MemoryType Pack(char s0, char s1, char s2, char s3)
|
||||
{
|
||||
@@ -163,6 +163,13 @@ struct vector_type<char, 8>
|
||||
using MemoryType = int64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<int32_t, 2>
|
||||
{
|
||||
using MemoryType = int64_t;
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <>
|
||||
struct vector_type<char2, 2>
|
||||
{
|
||||
@@ -175,34 +182,30 @@ struct vector_type<char2, 4>
|
||||
using MemoryType = int64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char4, 1>
|
||||
{
|
||||
using MemoryType = int;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char4, 2>
|
||||
{
|
||||
using MemoryType = int64_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <class TDst, class TSrc0, class TSrc1>
|
||||
__device__ void fused_multiply_accumulate(TDst& d, const TSrc0& s0, const TSrc1& s1)
|
||||
{
|
||||
// static_assert(false, "should not call into base");
|
||||
printf("should not call into base");
|
||||
assert(false);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_accumulate(float& d, const float& s0, const float& s1)
|
||||
{
|
||||
d += s0 * s1;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_accumulate(float& d, const float2& s0, const float2& s1)
|
||||
{
|
||||
d += s0.x * s1.x;
|
||||
d += s0.y * s1.y;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_accumulate(float& d, const float4& s0, const float4& s1)
|
||||
{
|
||||
d += s0.x * s1.x;
|
||||
@@ -211,13 +214,8 @@ __device__ void fused_multiply_accumulate(float& d, const float4& s0, const floa
|
||||
d += s0.w * s1.w;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_accumulate(half& d, const half& s0, const half& s1)
|
||||
{
|
||||
d += s0 * s1;
|
||||
}
|
||||
__device__ void fused_multiply_accumulate(half& d, const half& s0, const half& s1) { d += s0 * s1; }
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2& s1)
|
||||
{
|
||||
d += s0.x * s1.x;
|
||||
@@ -225,25 +223,25 @@ __device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2&
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <>
|
||||
__device__ void fused_multiply_accumulate(float& d, const half2& s0, const half2& s1)
|
||||
{
|
||||
d += s0.x * s1.x + s0.y * s1.y;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_accumulate(char& d, const char& s0, const char& s1)
|
||||
{
|
||||
d += s0 * s1;
|
||||
}
|
||||
__device__ void fused_multiply_accumulate(char& d, const char& s0, const char& s1) { d += s0 * s1; }
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_accumulate(int32_t& d, const char4& s0, const char4& s1)
|
||||
// TODO:: this interface is misleading, int32 is actually int8x4
|
||||
// need to make a better interface
|
||||
__device__ void fused_multiply_accumulate(int32_t& d, const int32_t& s0, const int32_t& s1)
|
||||
{
|
||||
#if DEVICE_BACKEND_CUDA
|
||||
#if 1 // debug
|
||||
d = __dp4a(s0, s1, d);
|
||||
#else
|
||||
d += s0.x * s1.x + s0.y * s1.y + s0.z * s1.z + s0.w * s1.w;
|
||||
#elif 1
|
||||
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" : "=r"(d) : "r"(s0), "r"(s1), "r"(d));
|
||||
#elif 0 // this is wrong! just for debugging
|
||||
d += (*reinterpret_cast<const int32_t*>(&s0)) * (*reinterpret_cast<const int32_t*>(&s1));
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user