mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
added (1x4)x(2x4) threadwise gemm
This commit is contained in:
@@ -12,17 +12,18 @@
|
||||
namespace ck {
|
||||
|
||||
// if following number are power of 2, index calculation shall be greatly reduced:
|
||||
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
|
||||
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
|
||||
// MLevel1ThreadCluster, NLevel1ThreadCluster
|
||||
template <index_t BlockSize,
|
||||
class BlockMatrixA,
|
||||
class BlockMatrixB,
|
||||
class ThreadMatrixC,
|
||||
index_t MPerThreadSubC,
|
||||
index_t NPerThreadSubC,
|
||||
index_t MLevel0Cluster,
|
||||
index_t NLevel0Cluster,
|
||||
index_t MLevel1Cluster,
|
||||
index_t NLevel1Cluster,
|
||||
index_t MLevel0ThreadCluster,
|
||||
index_t NLevel0ThreadCluster,
|
||||
index_t MLevel1ThreadCluster,
|
||||
index_t NLevel1ThreadCluster,
|
||||
index_t KPerThreadLoop,
|
||||
index_t DataPerReadA,
|
||||
index_t DataPerReadB>
|
||||
@@ -39,8 +40,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
|
||||
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
|
||||
{
|
||||
constexpr index_t ThreadPerLevel1Cluster =
|
||||
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
|
||||
constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
|
||||
MLevel1ThreadCluster * NLevel1ThreadCluster;
|
||||
|
||||
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
|
||||
|
||||
@@ -50,8 +51,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
|
||||
constexpr index_t N = BlockMatrixB::NCol();
|
||||
|
||||
static_assert(M % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
|
||||
N % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
|
||||
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
|
||||
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
|
||||
"wrong! Cannot evenly divide work among\n");
|
||||
|
||||
static_assert(
|
||||
@@ -69,26 +70,28 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
|
||||
constexpr index_t N = BlockMatrixB::NCol();
|
||||
|
||||
constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
|
||||
constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
|
||||
constexpr index_t MRepeat =
|
||||
M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
|
||||
constexpr index_t NRepeat =
|
||||
N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
|
||||
|
||||
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
|
||||
}
|
||||
|
||||
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
|
||||
{
|
||||
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
|
||||
constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
|
||||
|
||||
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
|
||||
index_t level1_m_id = level1_id / NLevel1Cluster;
|
||||
index_t level1_n_id = level1_id % NLevel1Cluster;
|
||||
index_t level1_m_id = level1_id / NLevel1ThreadCluster;
|
||||
index_t level1_n_id = level1_id % NLevel1ThreadCluster;
|
||||
|
||||
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
|
||||
index_t level0_m_id = level0_id / NLevel0Cluster;
|
||||
index_t level0_n_id = level0_id % NLevel0Cluster;
|
||||
index_t level0_m_id = level0_id / NLevel0ThreadCluster;
|
||||
index_t level0_n_id = level0_id % NLevel0ThreadCluster;
|
||||
|
||||
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
|
||||
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
|
||||
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
|
||||
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
|
||||
|
||||
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
|
||||
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
|
||||
@@ -99,8 +102,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
{
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
constexpr index_t MPerLevel1Cluster =
|
||||
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
|
||||
constexpr index_t NPerLevel1Cluster =
|
||||
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
|
||||
|
||||
index_t m_repeat = m_in_c / MPerThreadSubC;
|
||||
index_t n_repeat = n_in_c / NPerThreadSubC;
|
||||
@@ -139,8 +144,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
constexpr index_t MPerLevel1Cluster =
|
||||
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
|
||||
constexpr index_t NPerLevel1Cluster =
|
||||
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
|
||||
|
||||
// assertion for inline asm
|
||||
static_assert(is_same<FloatA, float>{} && is_same<FloatB, float>{} &&
|
||||
@@ -184,6 +191,123 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
}
|
||||
|
||||
__device__ void Run_amd_asm_v2(const float* __restrict__ p_a_block,
|
||||
const float* __restrict__ p_b_block,
|
||||
float* __restrict__ p_c_thread) const
|
||||
{
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t M = a_block_mtx.NCol();
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
float p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
float p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MThreadCluster = MLevel0ThreadCluster * MLevel1ThreadCluster;
|
||||
constexpr index_t NThreadCluster = NLevel0ThreadCluster * NLevel1ThreadCluster;
|
||||
|
||||
constexpr index_t MDataCluster = M / MPerThreadSubC;
|
||||
constexpr index_t NDataCluster = N / NPerThreadSubC;
|
||||
|
||||
constexpr index_t MRepeat = MDataCluster / MThreadCluster;
|
||||
constexpr index_t NRepeat = NDataCluster / NThreadCluster;
|
||||
|
||||
// assertion for inline asm
|
||||
static_assert((MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 &&
|
||||
KPerThreadLoop == 1) ||
|
||||
(MPerThreadSubC == 2 && NPerThreadSubC == 4 && MRepeat == 2 &&
|
||||
NRepeat == 2 && KPerThreadLoop == 1),
|
||||
"Run_amd_asm cannot deal with this GEMM shape yet");
|
||||
|
||||
static_assert(DataPerReadA == MPerThreadSubC && DataPerReadB == NPerThreadSubC,
|
||||
"wrong! Run_amd_asm doesn't support this config");
|
||||
|
||||
if(MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 &&
|
||||
KPerThreadLoop == 1)
|
||||
{
|
||||
using float4_type = vector_type<float, 4>::MemoryType;
|
||||
|
||||
float4_type* reg_a = reinterpret_cast<float4_type*>(p_a_thread);
|
||||
float4_type* reg_b = reinterpret_cast<float4_type*>(p_b_thread);
|
||||
float4_type* reg_c = reinterpret_cast<float4_type*>(p_c_thread);
|
||||
|
||||
const float4_type* p_a =
|
||||
reinterpret_cast<const float4_type*>(&p_a_block[mMyThreadOffsetA]);
|
||||
const float4_type* p_b =
|
||||
reinterpret_cast<const float4_type*>(&p_b_block[mMyThreadOffsetB]);
|
||||
|
||||
reg_a[0] = p_a[0];
|
||||
reg_b[0] = p_b[0];
|
||||
reg_b[1] = p_b[NThreadCluster];
|
||||
reg_a[1] = p_a[MThreadCluster];
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
#pragma unroll
|
||||
for(index_t k = 1; k < K; ++k)
|
||||
{
|
||||
reg_a[0] = p_a[k * MDataCluster];
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
reg_b[0] = p_b[k * NDataCluster];
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
reg_b[1] = p_b[k * NDataCluster + NThreadCluster];
|
||||
reg_a[1] = p_a[k * MDataCluster + MThreadCluster];
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
}
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
}
|
||||
else if(MPerThreadSubC == 2 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 &&
|
||||
KPerThreadLoop == 1)
|
||||
{
|
||||
using float2_type = vector_type<float, 2>::MemoryType;
|
||||
using float4_type = vector_type<float, 4>::MemoryType;
|
||||
|
||||
float2_type* reg_a = reinterpret_cast<float2_type*>(p_a_thread);
|
||||
float4_type* reg_b = reinterpret_cast<float4_type*>(p_b_thread);
|
||||
float4_type* reg_c = reinterpret_cast<float4_type*>(p_c_thread);
|
||||
|
||||
const float2_type* p_a =
|
||||
reinterpret_cast<const float2_type*>(&p_a_block[mMyThreadOffsetA]);
|
||||
const float4_type* p_b =
|
||||
reinterpret_cast<const float4_type*>(&p_b_block[mMyThreadOffsetB]);
|
||||
|
||||
reg_a[0] = p_a[0];
|
||||
reg_b[0] = p_b[0];
|
||||
reg_b[1] = p_b[NThreadCluster];
|
||||
reg_a[1] = p_a[MThreadCluster];
|
||||
outerProduct2x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2]);
|
||||
outerProduct2x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3]);
|
||||
#pragma unroll
|
||||
for(index_t k = 1; k < K; ++k)
|
||||
{
|
||||
reg_a[0] = p_a[k * MDataCluster];
|
||||
outerProduct2x4(reg_a[1], reg_b[0], reg_c[4], reg_c[6]);
|
||||
reg_b[0] = p_b[k * NDataCluster];
|
||||
outerProduct2x4(reg_a[1], reg_b[1], reg_c[5], reg_c[7]);
|
||||
reg_b[1] = p_b[k * NDataCluster + NThreadCluster];
|
||||
reg_a[1] = p_a[k * MDataCluster + MThreadCluster];
|
||||
outerProduct2x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2]);
|
||||
outerProduct2x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3]);
|
||||
}
|
||||
outerProduct2x4(reg_a[1], reg_b[0], reg_c[4], reg_c[6]);
|
||||
outerProduct2x4(reg_a[1], reg_b[1], reg_c[5], reg_c[7]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
@@ -220,8 +344,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
constexpr index_t MPerLevel1Cluster =
|
||||
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
|
||||
constexpr index_t NPerLevel1Cluster =
|
||||
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
@@ -273,141 +399,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
}
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void RunRegisterDoubleBuffer_source(FloatA* const p_a_block,
|
||||
FloatB* const p_b_block,
|
||||
FloatC* p_c_thread) const
|
||||
{
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr auto False = integral_constant<bool, false>{};
|
||||
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
// thread A-sub, B-sub for copy
|
||||
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
// register
|
||||
FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
// preload A, B
|
||||
#pragma unroll
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{ // copy A-sub to form A
|
||||
threadwise_matrix_copy(a_block_mtx,
|
||||
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
|
||||
a_thread_sub_mtx,
|
||||
p_a_thread_0 + m_repeat * MPerThreadSubC,
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{ // copy B-sub to form B
|
||||
threadwise_matrix_copy(b_block_mtx,
|
||||
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
|
||||
b_thread_sub_mtx,
|
||||
p_b_thread_0 + n_repeat * NPerThreadSubC,
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
bool even_loop = true;
|
||||
|
||||
#pragma unroll
|
||||
for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
|
||||
k_begin += KPerThreadLoop, even_loop = !even_loop)
|
||||
{ // loop over k
|
||||
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
|
||||
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
|
||||
|
||||
FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
|
||||
FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;
|
||||
|
||||
// preload next A, B
|
||||
#pragma unroll
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{ // copy A-sub to form A
|
||||
threadwise_matrix_copy(a_block_mtx,
|
||||
p_a_block + mMyThreadOffsetA +
|
||||
(k_begin + 1) * a_block_mtx.RowStride() +
|
||||
m_repeat * MPerLevel1Cluster,
|
||||
a_thread_sub_mtx,
|
||||
p_a_thread_next + m_repeat * MPerThreadSubC,
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{ // copy B-sub to form B
|
||||
threadwise_matrix_copy(b_block_mtx,
|
||||
p_b_block + mMyThreadOffsetB +
|
||||
(k_begin + 1) * b_block_mtx.RowStride() +
|
||||
n_repeat * NPerLevel1Cluster,
|
||||
b_thread_sub_mtx,
|
||||
p_b_thread_next + n_repeat * NPerThreadSubC,
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
threadwise_gemm(a_thread_mtx,
|
||||
True,
|
||||
p_a_thread_now,
|
||||
b_thread_mtx,
|
||||
False,
|
||||
p_b_thread_now,
|
||||
c_thread_mtx,
|
||||
False,
|
||||
p_c_thread);
|
||||
}
|
||||
|
||||
// last loop
|
||||
{
|
||||
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
|
||||
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
|
||||
|
||||
// C = A * B
|
||||
threadwise_gemm(a_thread_mtx,
|
||||
True,
|
||||
p_a_thread_now,
|
||||
b_thread_mtx,
|
||||
False,
|
||||
p_b_thread_now,
|
||||
c_thread_mtx,
|
||||
False,
|
||||
p_c_thread);
|
||||
}
|
||||
}
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA* __restrict__ p_a_block,
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
@@ -415,7 +406,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
|
||||
{
|
||||
#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
|
||||
Run_amd_asm(p_a_block, p_b_block, p_c_thread);
|
||||
Run_amd_asm_v2(p_a_block, p_b_block, p_c_thread);
|
||||
#else
|
||||
Run_source(p_a_block, p_b_block, p_c_thread);
|
||||
#endif
|
||||
|
||||
@@ -105,6 +105,15 @@ __device__ void outerProduct1x4(const float& a,
|
||||
outerProduct1x4(&a, reinterpret_cast<const float*>(&b), reinterpret_cast<float*>(&c));
|
||||
}
|
||||
|
||||
__device__ void outerProduct2x4(const vector_type<float, 2>::MemoryType& a,
|
||||
const vector_type<float, 4>::MemoryType& b,
|
||||
vector_type<float, 4>::MemoryType& c0,
|
||||
vector_type<float, 4>::MemoryType& c1)
|
||||
{
|
||||
outerProduct1x4(a.x, b, c0);
|
||||
outerProduct1x4(a.y, b, c1);
|
||||
}
|
||||
|
||||
__device__ void outerProduct4x4(const vector_type<float, 4>::MemoryType& a,
|
||||
const vector_type<float, 4>::MemoryType& b,
|
||||
vector_type<float, 4>::MemoryType& c0,
|
||||
|
||||
@@ -60,6 +60,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
|
||||
|
||||
#if 0
|
||||
// each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 16;
|
||||
@@ -94,20 +95,21 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 1
|
||||
// each thread hold 32 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmMPerThreadSubC = 2;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadA = 2;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
|
||||
@@ -127,74 +129,6 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 4, 1>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 4, 4>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 4;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 1
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 2, 2>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 2>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 2;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t GridSize =
|
||||
|
||||
@@ -72,11 +72,11 @@ int main(int argc, char* argv[])
|
||||
using namespace ck;
|
||||
|
||||
#if 0
|
||||
constexpr index_t N = 256;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 1536;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user