experimenting

[ROCm/composable_kernel commit: 766b0a9eaf]
This commit is contained in:
Chao Liu
2019-03-24 12:09:57 -05:00
parent 6f1651f8a7
commit cd883e7581
33 changed files with 1886 additions and 1822 deletions

View File

@@ -1,26 +1,26 @@
#pragma once
#include "threadwise_gemm.hip.hpp"
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
bool TransA,
bool TransB,
bool TransC,
unsigned KPerThreadLoop,
unsigned MThreadPerCluster,
unsigned NThreadPerCluster,
index_t KPerThreadLoop,
index_t MThreadPerCluster,
index_t NThreadPerCluster,
bool DistributeThreadAlongColumnFirst>
struct BlockwiseGemmBlockABlockBThreadC
{
unsigned mMyThreadOffsetA = 0;
unsigned mMyThreadOffsetB = 0;
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;
struct MatrixIndex
{
unsigned row;
unsigned col;
index_t row;
index_t col;
};
__device__ BlockwiseGemmBlockABlockBThreadC()
@@ -55,7 +55,7 @@ struct BlockwiseGemmBlockABlockBThreadC
#endif
}
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{
if(TransA && (!TransB) && (!TransC))
@@ -66,14 +66,14 @@ struct BlockwiseGemmBlockABlockBThreadC
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! k dimension not consistent!");
constexpr unsigned MPerBlock = a_block_mtx.NCol();
constexpr unsigned NPerBlock = b_block_mtx.NCol();
constexpr index_t MPerBlock = a_block_mtx.NCol();
constexpr index_t NPerBlock = b_block_mtx.NCol();
constexpr auto c_thread_mtx = ThreadMatrixC{};
// divide thread work
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0,
"MPerBlock % (MPerThread * MThreadPerCluster) != 0");
@@ -81,10 +81,10 @@ struct BlockwiseGemmBlockABlockBThreadC
static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0,
"NPerBlock % (NPerThread * NThreadPerCluster) != 0");
constexpr unsigned MClusterWork =
constexpr index_t MClusterWork =
(MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster);
constexpr unsigned NClusterWork =
constexpr index_t NClusterWork =
(NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster);
static_assert(BlockSize ==
@@ -94,19 +94,18 @@ struct BlockwiseGemmBlockABlockBThreadC
if(DistributeThreadAlongColumnFirst)
{
const unsigned cluster_work_block_id =
const index_t cluster_work_block_id =
thread_id / (MThreadPerCluster * NThreadPerCluster);
const unsigned thread_work_cluster_id =
const index_t thread_work_cluster_id =
thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster);
const unsigned m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
const unsigned n_cluster_work_block_id =
const index_t m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
const index_t n_cluster_work_block_id =
cluster_work_block_id - m_cluster_work_block_id * NClusterWork;
const unsigned m_thread_work_cluster_id =
thread_work_cluster_id / NThreadPerCluster;
const unsigned n_thread_work_cluster_id =
const index_t m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster;
const index_t n_thread_work_cluster_id =
thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster;
#if 0
@@ -143,8 +142,8 @@ struct BlockwiseGemmBlockABlockBThreadC
}
// this should be optimized away if input is known
// Packs the two arguments into a MatrixIndex as {row = m_in_c, col = n_in_c};
// no arithmetic is applied to either value in this variant.
// NOTE(review): the signature is listed twice because this text is a diff
// rendering -- the `unsigned` form is the pre-change line, the `index_t`
// form is its replacement.
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
unsigned n_in_c)
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
index_t n_in_c)
{
return MatrixIndex{m_in_c, n_in_c};
}
@@ -164,10 +163,10 @@ struct BlockwiseGemmBlockABlockBThreadC
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// a is transposed, b is not
constexpr auto a_thread_mtx =
@@ -180,7 +179,7 @@ struct BlockwiseGemmBlockABlockBThreadC
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
// loop over k
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
@@ -213,31 +212,31 @@ struct BlockwiseGemmBlockABlockBThreadC
// if the following numbers are powers of 2, index calculation can be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
template <unsigned BlockSize,
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
unsigned MPerThreadSubC,
unsigned NPerThreadSubC,
unsigned MLevel0Cluster,
unsigned NLevel0Cluster,
unsigned MLevel1Cluster,
unsigned NLevel1Cluster,
unsigned KPerThreadLoop>
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
struct MatrixIndex
{
unsigned row;
unsigned col;
index_t row;
index_t col;
};
unsigned mMyThreadOffsetA;
unsigned mMyThreadOffsetB;
index_t mMyThreadOffsetA;
index_t mMyThreadOffsetB;
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
{
constexpr unsigned ThreadPerLevel1Cluster =
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
@@ -249,31 +248,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n");
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol(); // A is transposed
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n");
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n");
constexpr unsigned MPerLevel1Cluster = M / MRepeat;
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
constexpr index_t MPerLevel1Cluster = M / MRepeat;
constexpr index_t NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0),
@@ -289,45 +288,45 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
}
// Converts a flat thread_id into a MatrixIndex {row, col} using a two-level
// decomposition driven by the template parameters M/NLevel0Cluster,
// M/NLevel1Cluster and M/NPerThreadSubC.
// NOTE(review): every declaration below is listed twice because this text is
// a diff rendering -- the `unsigned` line is the pre-change form and the
// `index_t` line is its replacement.
__device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id)
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
// number of thread ids covered by one level-0 cluster
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
// level-1 id: which level-0 cluster the thread falls in, then split into
// an (m, n) pair with NLevel1Cluster as the divisor (n varies fastest)
unsigned level1_id = thread_id / ThreadPerLevel0Cluster;
unsigned level1_m_id = level1_id / NLevel1Cluster;
unsigned level1_n_id = level1_id % NLevel1Cluster;
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1Cluster;
index_t level1_n_id = level1_id % NLevel1Cluster;
// level-0 id: position of the thread inside its level-0 cluster, split
// into an (m, n) pair with NLevel0Cluster as the divisor (n varies fastest)
unsigned level0_id = thread_id % ThreadPerLevel0Cluster;
unsigned level0_m_id = level0_id / NLevel0Cluster;
unsigned level0_n_id = level0_id % NLevel0Cluster;
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0Cluster;
index_t level0_n_id = level0_id % NLevel0Cluster;
// extent, in matrix elements, spanned by one level-0 cluster along m and n
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
// row/col = level-1 offset (whole level-0 clusters) + level-0 offset
// (whole per-thread sub-blocks)
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
// this should be optimized away if input is known
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
unsigned n_in_c)
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
unsigned m_repeat = m_in_c / MPerThreadSubC;
unsigned n_repeat = n_in_c / NPerThreadSubC;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;
unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c};
@@ -346,12 +345,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
@@ -370,19 +369,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
// copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
@@ -395,7 +394,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll
// copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -433,12 +432,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
@@ -457,19 +456,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
//#pragma unroll
// copy A-sub to form A
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
@@ -480,9 +479,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
a_thread_sub_mtx.GetLengths());
}
#pragma unroll
//#pragma unroll
// copy B-sub to form B
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -505,19 +504,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
False,
p_c_thread,
f_accum);
#else
#elif 0
// inline asm
static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8,
"asm is only for 8x8");
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
{
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
asm volatile("\n \
v_mac_f32 %0, %8, %9 \n \
@@ -573,12 +572,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
@@ -601,15 +600,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// preload A, B
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
@@ -619,7 +618,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
@@ -631,7 +630,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
bool even_loop = true;
#pragma unroll
for(unsigned k_begin = 0; k_begin + KPerThreadLoop < K;
for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
k_begin += KPerThreadLoop, even_loop = !even_loop)
{ // loop over k
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
@@ -642,7 +641,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
// preload next A, B
#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
@@ -654,7 +653,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB +
@@ -710,12 +709,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned M = a_block_mtx.NCol();
constexpr unsigned N = b_block_mtx.NCol();
constexpr unsigned K = a_block_mtx.NRow();
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A-sub, B-sub, C-sub
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
@@ -737,15 +736,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
#pragma unroll
// loop over k
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
// C-sub(s) in first row-wise subblock of C
{
@@ -779,7 +778,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll
// copy next B-sub, and do GEMM
for(unsigned n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
@@ -805,7 +804,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll
// loop over rest of row-wise subblock
// all B-sub(s) have been copied, so only A-sub(s) need to be copied
for(unsigned m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
for(index_t m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
{
// copy a A-sub
threadwise_matrix_copy(
@@ -817,7 +816,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
a_thread_sub_mtx.GetLengths());
// do some GEMMs
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_gemm(
a_thread_sub_mtx,