mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Refactor for MIOpen integration (#4)
Refactor, so can bring multi-index transformation and padding support into MIOpen
This commit is contained in:
@@ -1,84 +1,31 @@
|
||||
#ifndef CK_AMD_INLINE_ASM_HPP
|
||||
#define CK_AMD_INLINE_ASM_HPP
|
||||
|
||||
#include "vector_type.hpp"
|
||||
#include "float_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// cast a pointer of LDS to its address
|
||||
extern "C" __attribute__((address_space(3))) __device__ void* __to_local(void* p);
|
||||
|
||||
__device__ void vmcnt(index_t cnt)
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void __outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
|
||||
{
|
||||
if(cnt == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(0) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 1)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(1) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 2)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(2) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 4)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(2) \n \
|
||||
" ::);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
// disable inline asm due to the compiler issue: SWDEV-202749
|
||||
///\to-do: enable the inline asm after the compiler fix
|
||||
#if CK_WORKAROUND_SWDEV_202749
|
||||
c0 += a * b0;
|
||||
c1 += a * b1;
|
||||
#else
|
||||
asm volatile("\n \
|
||||
v_mac_f32 %0, %2, %3 \n \
|
||||
v_mac_f32 %1, %2, %4 \n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1)
|
||||
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void lgkmcnt(index_t cnt)
|
||||
{
|
||||
if(cnt == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 1)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(1) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 2)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(2) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 3)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(3) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 4)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(4) \n \
|
||||
" ::);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void outerProduct1x4(const float* a, const float* b, float* c)
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void __outer_product_1x4(
|
||||
float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
|
||||
{
|
||||
asm volatile("\n \
|
||||
v_mac_f32 %0, %4, %5 \n \
|
||||
@@ -86,596 +33,122 @@ __device__ void outerProduct1x4(const float* a, const float* b, float* c)
|
||||
v_mac_f32 %2, %4, %7 \n \
|
||||
v_mac_f32 %3, %4, %8 \n \
|
||||
"
|
||||
: "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
|
||||
: "v"(a[0]),
|
||||
"v"(b[0]),
|
||||
"v"(b[1]),
|
||||
"v"(b[2]),
|
||||
"v"(b[3]),
|
||||
"0"(c[0]),
|
||||
"1"(c[1]),
|
||||
"2"(c[2]),
|
||||
"3"(c[3]));
|
||||
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
|
||||
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
|
||||
}
|
||||
|
||||
__device__ void outerProduct1x4(const float& a,
|
||||
const vector_type<float, 4>::MemoryType& b,
|
||||
vector_type<float, 4>::MemoryType& c)
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void __outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
|
||||
{
|
||||
outerProduct1x4(&a, reinterpret_cast<const float*>(&b), reinterpret_cast<float*>(&c));
|
||||
}
|
||||
|
||||
__device__ void outerProduct2x4(const vector_type<float, 2>::MemoryType& a,
|
||||
const vector_type<float, 4>::MemoryType& b,
|
||||
vector_type<float, 4>::MemoryType& c0,
|
||||
vector_type<float, 4>::MemoryType& c1)
|
||||
{
|
||||
outerProduct1x4(a.x, b, c0);
|
||||
outerProduct1x4(a.y, b, c1);
|
||||
}
|
||||
|
||||
__device__ void outerProduct4x4(const vector_type<float, 4>::MemoryType& a,
|
||||
const vector_type<float, 4>::MemoryType& b,
|
||||
vector_type<float, 4>::MemoryType& c0,
|
||||
vector_type<float, 4>::MemoryType& c1,
|
||||
vector_type<float, 4>::MemoryType& c2,
|
||||
vector_type<float, 4>::MemoryType& c3)
|
||||
{
|
||||
outerProduct1x4(a.x, b, c0);
|
||||
outerProduct1x4(a.y, b, c1);
|
||||
outerProduct1x4(a.z, b, c2);
|
||||
outerProduct1x4(a.w, b, c3);
|
||||
}
|
||||
|
||||
__device__ void outerProduct8x8(const vector_type<float, 4>::MemoryType* a,
|
||||
const vector_type<float, 4>::MemoryType* b,
|
||||
vector_type<float, 4>::MemoryType* c)
|
||||
{
|
||||
outerProduct4x4(a[0], b[0], c[0], c[2], c[4], c[6]);
|
||||
outerProduct4x4(a[0], b[1], c[1], c[3], c[5], c[7]);
|
||||
outerProduct4x4(a[1], b[0], c[8], c[10], c[12], c[14]);
|
||||
outerProduct4x4(a[1], b[1], c[9], c[11], c[13], c[15]);
|
||||
}
|
||||
|
||||
__device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, index_t offset = 0)
|
||||
{
|
||||
if(offset == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:0\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 64)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:64\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 128)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:128\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 192)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:192\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 256)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:256\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 320)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:320\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 384)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:384\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 448)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:448\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 512)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:512\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 576)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:576\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 640)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:640\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 704)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:704\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 768)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:768\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 832)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:832\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 896)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:896\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 960)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:960\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1024)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1024\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1088)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1088\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1152)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1152\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1216)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1216\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1280)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1280\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1344)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1344\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1408)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1408\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1472)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1472\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1536)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1536\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1600)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1600\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1664)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1664\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1728)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1728\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1792)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1792\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1856)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1856\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1920)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1920\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1984)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1984\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2048)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2048\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2112)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2112\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2176)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2176\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2240)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2240\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2304)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2304\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2368)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2368\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2432)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2432\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2496)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2496\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2560)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2560\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2624)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2624\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2688)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2688\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2752)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2752\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2816)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2816\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2880)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2880\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2944)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2944\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3008)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3008\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3072)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3072\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3136)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3136\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3200)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3200\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3264)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3264\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3328)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3328\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3392)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3392\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3456)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3456\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3520)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3520\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3584)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3584\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3648)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3648\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3712)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3712\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3776)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3776\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3840)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3840\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3904)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3904\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3968)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3968\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 4032)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:4032\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 4096)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:4096\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void
|
||||
ds_write_b128(const vector_type<float, 4>::MemoryType& r, void* lds, index_t offset = 0)
|
||||
{
|
||||
if(offset == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_write_b128 %0, %1 \n \
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %2, %3 %0\n \
|
||||
v_dot2_f32_f16 %1, %2, %4 %1\n \
|
||||
"
|
||||
:
|
||||
: "v"(__to_local(lds)), "v"(r));
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
: "=v"(c0), "=v"(c1) // Dest registers
|
||||
: "v"(a), // 1st Src register for 1 half2 registers
|
||||
"v"(b0), // 2nd Src register
|
||||
"v"(b1),
|
||||
"0"(c0), // 3rd Src register
|
||||
"1"(c1));
|
||||
}
|
||||
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void __outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
|
||||
{
|
||||
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
|
||||
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
|
||||
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
|
||||
|
||||
// do dot2 two times
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %2, %4 %0\n \
|
||||
v_dot2_f32_f16 %1, %2, %6 %1\n \
|
||||
v_dot2_f32_f16 %0, %3, %5 %0\n \
|
||||
v_dot2_f32_f16 %1, %3, %7 %1\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1) // Dest registers
|
||||
: "v"(p_a_half2[0]),
|
||||
"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers
|
||||
"v"(p_b0_half2[0]),
|
||||
"v"(p_b0_half2[1]),
|
||||
"v"(p_b1_half2[0]),
|
||||
"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers
|
||||
"0"(c0),
|
||||
"1"(c1)); // 3rd Src Acc registers for 2 half2 registers
|
||||
}
|
||||
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void __outer_product_1x4(half2_t a,
|
||||
half2_t b0,
|
||||
half2_t b1,
|
||||
half2_t b2,
|
||||
half2_t b3,
|
||||
float& c0,
|
||||
float& c1,
|
||||
float& c2,
|
||||
float& c3)
|
||||
{
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %4, %5 %0\n \
|
||||
v_dot2_f32_f16 %1, %4, %6 %1\n \
|
||||
v_dot2_f32_f16 %2, %4, %7 %2\n \
|
||||
v_dot2_f32_f16 %3, %4, %8 %3\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
|
||||
: "v"(a), // 1st Src register for 1 half2 registers
|
||||
"v"(b0), // 2nd Src register
|
||||
"v"(b1),
|
||||
"v"(b2),
|
||||
"v"(b3),
|
||||
"0"(c0), // 3rd Src register
|
||||
"1"(c1),
|
||||
"2"(c2),
|
||||
"3"(c3));
|
||||
}
|
||||
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void __outer_product_1x4(half4_t a,
|
||||
half4_t b0,
|
||||
half4_t b1,
|
||||
half4_t b2,
|
||||
half4_t b3,
|
||||
float& c0,
|
||||
float& c1,
|
||||
float& c2,
|
||||
float& c3)
|
||||
{
|
||||
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
|
||||
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
|
||||
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
|
||||
const half2_t* p_b2_half2 = reinterpret_cast<const half2_t*>(&b2);
|
||||
const half2_t* p_b3_half2 = reinterpret_cast<const half2_t*>(&b3);
|
||||
|
||||
// do dot2 two times
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %4, %6 %0\n \
|
||||
v_dot2_f32_f16 %1, %4, %8 %1\n \
|
||||
v_dot2_f32_f16 %2, %4, %10 %2\n \
|
||||
v_dot2_f32_f16 %3, %4, %12 %3\n \
|
||||
v_dot2_f32_f16 %0, %5, %7 %0\n \
|
||||
v_dot2_f32_f16 %1, %5, %9 %1\n \
|
||||
v_dot2_f32_f16 %2, %5, %11 %2\n \
|
||||
v_dot2_f32_f16 %3, %5, %13 %3\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
|
||||
: "v"(p_a_half2[0]),
|
||||
"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers
|
||||
"v"(p_b0_half2[0]),
|
||||
"v"(p_b0_half2[1]),
|
||||
"v"(p_b1_half2[0]),
|
||||
"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers
|
||||
"v"(p_b2_half2[0]),
|
||||
"v"(p_b2_half2[1]),
|
||||
"v"(p_b3_half2[0]),
|
||||
"v"(p_b3_half2[1]), // 2nd Src registers for 2 half2 registers
|
||||
"0"(c0),
|
||||
"1"(c1),
|
||||
"2"(c2),
|
||||
"3"(c3)); // 3rd Src Acc registers for 2 half2 registers
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user