experimenting

[ROCm/composable_kernel commit: 766b0a9eaf]
This commit is contained in:
Chao Liu
2019-03-24 12:09:57 -05:00
parent 6f1651f8a7
commit cd883e7581
33 changed files with 1886 additions and 1822 deletions

View File

@@ -1,30 +1,30 @@
#pragma once
#include "threadwise_gemm.hip.hpp"
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
bool TransA,
bool TransB,
bool TransC,
unsigned BlockMatrixStrideA,
unsigned BlockMatrixStrideB,
unsigned ThreadMatrixStrideC,
unsigned BatchSize,
unsigned BatchPerThread,
unsigned KPerThreadLoop,
index_t BlockMatrixStrideA,
index_t BlockMatrixStrideB,
index_t ThreadMatrixStrideC,
index_t BatchSize,
index_t BatchPerThread,
index_t KPerThreadLoop,
bool DistributeThreadAlongColumnFirst>
struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;
struct MatrixIndex
{
unsigned batch;
unsigned row;
unsigned col;
index_t batch;
index_t row;
index_t col;
};
__device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
@@ -61,7 +61,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
#endif
}
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{
if(TransA && (!TransB) && (!TransC))
@@ -72,22 +72,22 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! k dimension not consistent!");
constexpr unsigned MPerBlock = a_block_mtx.NCol();
constexpr unsigned NPerBlock = b_block_mtx.NCol();
constexpr index_t MPerBlock = a_block_mtx.NCol();
constexpr index_t NPerBlock = b_block_mtx.NCol();
constexpr auto c_thread_mtx = ThreadMatrixC{};
// divide thread work
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0");
static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");
constexpr unsigned BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr index_t BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
constexpr index_t MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork,
"wrong! wrong BlockSize");
@@ -95,10 +95,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
if(DistributeThreadAlongColumnFirst)
{
// num of operations can be reduced
const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork);
unsigned itmp = thread_id - b_work_id * (MThreadWork * NThreadWork);
const unsigned m_work_id = itmp / NThreadWork;
const unsigned n_work_id = itmp - m_work_id * NThreadWork;
const index_t b_work_id = thread_id / (MThreadWork * NThreadWork);
index_t itmp = thread_id - b_work_id * (MThreadWork * NThreadWork);
const index_t m_work_id = itmp / NThreadWork;
const index_t n_work_id = itmp - m_work_id * NThreadWork;
return MatrixIndex{
b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
@@ -118,7 +118,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
// this should be optimized away if input is known
__device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
{
return MatrixIndex{batch_in_c, m_in_c, n_in_c};
}
@@ -138,10 +138,10 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// a is transposed, b is not
constexpr auto a_thread_mtx =
@@ -154,7 +154,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
// loop over k
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of a, b
threadwise_matrix_copy(a_block_mtx,
@@ -172,7 +172,7 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
b_thread_mtx.GetLengths());
// loop over batch
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
threadwise_gemm(a_thread_mtx,
@@ -226,32 +226,32 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
}
};
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
unsigned BlockMatrixStrideA,
unsigned BlockMatrixStrideB,
unsigned ThreadMatrixStrideC,
unsigned BatchSize,
unsigned MPerThreadSubC,
unsigned NPerThreadSubC,
unsigned MLevel0Cluster,
unsigned NLevel0Cluster,
unsigned MLevel1Cluster,
unsigned NLevel1Cluster,
unsigned KPerThreadLoop,
unsigned BatchPerThread>
index_t BlockMatrixStrideA,
index_t BlockMatrixStrideB,
index_t ThreadMatrixStrideC,
index_t BatchSize,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t BatchPerThread>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;
struct MatrixIndex
{
unsigned batch;
unsigned row;
unsigned col;
index_t batch;
index_t row;
index_t col;
};
__device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
@@ -259,9 +259,9 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(BatchSize % BatchPerThread == 0,
"wrong! BatchSize is not dividable by BatchPerThread");
constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
constexpr unsigned ThreadPerLevel1Cluster =
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
@@ -274,31 +274,31 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n");
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol(); // A is transposed
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n");
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n");
constexpr unsigned MPerLevel1Cluster = M / MRepeat;
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
constexpr index_t MPerLevel1Cluster = M / MRepeat;
constexpr index_t NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0),
@@ -335,28 +335,28 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
#endif
}
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{
constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread;
constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
constexpr unsigned ThreadPerLevel1Cluster =
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster;
unsigned cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
index_t cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
unsigned level1_id = cluster_id / ThreadPerLevel0Cluster;
unsigned level1_m_id = level1_id / NLevel1Cluster;
unsigned level1_n_id = level1_id % NLevel1Cluster;
index_t level1_id = cluster_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1Cluster;
index_t level1_n_id = level1_id % NLevel1Cluster;
unsigned level0_id = cluster_id % ThreadPerLevel0Cluster;
unsigned level0_m_id = level0_id / NLevel0Cluster;
unsigned level0_n_id = level0_id % NLevel0Cluster;
index_t level0_id = cluster_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0Cluster;
index_t level0_n_id = level0_id % NLevel0Cluster;
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{batch_work_id * BatchPerThread,
level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
@@ -365,24 +365,24 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// this should be optimized away if input is known
__device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
unsigned m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;
unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{batch_in_c,
m_repeat * MPerLevel1Cluster + m_in_sub_c,
@@ -402,10 +402,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, b is not
@@ -425,20 +425,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// loop over k
#pragma unroll
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of A, B
// copy A-sub to form A
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
@@ -451,7 +451,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// copy B-sub to form B
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -464,7 +464,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// loop over batch
#pragma unroll
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
threadwise_gemm(a_thread_mtx,
@@ -482,7 +482,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideA != 0)
{
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
@@ -498,7 +498,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideB != 0)
{
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -539,10 +539,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, b is not
@@ -562,25 +562,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// loop over k
//#pragma unroll
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of A, B
// copy A-sub to form A
//#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
{
#if 1
for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
{
p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
@@ -596,11 +596,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// copy B-sub to form B
//#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
{
p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
@@ -612,20 +612,20 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
// loop over batch
//#pragma unroll
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
{
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex =
const index_t aindex =
a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex =
const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
const index_t cindex =
c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
@@ -635,11 +635,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
"asm is only for 16x4");
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
asm volatile("\n \
v_mac_f32 %0, %4, %5 \n \
@@ -668,11 +668,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideA != 0)
{
//#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < a_thread_sub_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < a_thread_sub_mtx.NCol(); ++j)
{
p_a_thread[a_thread_mtx.Get1dIndex(i,
m_repeat * MPerThreadSubC + j)] =
@@ -687,11 +687,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
if(BlockMatrixStrideB != 0)
{
//#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
for(index_t i = 0; i < b_thread_sub_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
for(index_t j = 0; j < b_thread_sub_mtx.NCol(); ++j)
{
p_b_thread[b_thread_mtx.Get1dIndex(i,
n_repeat * NPerThreadSubC + j)] =
@@ -705,16 +705,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
// do last batch of gemm
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k)
{
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
for(index_t j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) +
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t bindex = b_thread_mtx.Get1dIndex(k, j);
const index_t cindex = c_thread_mtx.Get1dIndex(i, j) +
(BatchPerThread - 1) * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
@@ -724,11 +724,11 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
"asm is only for 16x4");
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex =
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t cindex =
c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC;
asm volatile("\n \
@@ -756,34 +756,34 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}
template <class BlockMatrixC, unsigned BlockMatrixStrideC, class FloatC>
template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
FloatC* __restrict__ p_c_block) const
{
constexpr auto c_block_mtx = BlockMatrixC{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned c_thread_offset =
const index_t c_thread_offset =
c_thread_mtx_begin.batch * BlockMatrixStrideC +
c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
c_thread_sub_mtx,