mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
@@ -1,26 +1,26 @@
|
||||
#pragma once
|
||||
#include "threadwise_gemm.hip.hpp"
|
||||
|
||||
template <unsigned BlockSize,
|
||||
template <index_t BlockSize,
|
||||
class BlockMatrixA,
|
||||
class BlockMatrixB,
|
||||
class ThreadMatrixC,
|
||||
bool TransA,
|
||||
bool TransB,
|
||||
bool TransC,
|
||||
unsigned KPerThreadLoop,
|
||||
unsigned MThreadPerCluster,
|
||||
unsigned NThreadPerCluster,
|
||||
index_t KPerThreadLoop,
|
||||
index_t MThreadPerCluster,
|
||||
index_t NThreadPerCluster,
|
||||
bool DistributeThreadAlongColumnFirst>
|
||||
struct BlockwiseGemmBlockABlockBThreadC
|
||||
{
|
||||
unsigned mMyThreadOffsetA = 0;
|
||||
unsigned mMyThreadOffsetB = 0;
|
||||
index_t mMyThreadOffsetA = 0;
|
||||
index_t mMyThreadOffsetB = 0;
|
||||
|
||||
struct MatrixIndex
|
||||
{
|
||||
unsigned row;
|
||||
unsigned col;
|
||||
index_t row;
|
||||
index_t col;
|
||||
};
|
||||
|
||||
__device__ BlockwiseGemmBlockABlockBThreadC()
|
||||
@@ -55,7 +55,7 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
|
||||
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
|
||||
{
|
||||
|
||||
if(TransA && (!TransB) && (!TransC))
|
||||
@@ -66,14 +66,14 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
|
||||
"wrong! k dimension not consistent!");
|
||||
|
||||
constexpr unsigned MPerBlock = a_block_mtx.NCol();
|
||||
constexpr unsigned NPerBlock = b_block_mtx.NCol();
|
||||
constexpr index_t MPerBlock = a_block_mtx.NCol();
|
||||
constexpr index_t NPerBlock = b_block_mtx.NCol();
|
||||
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
// divide thread work
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0,
|
||||
"MPerBlock % (MPerThread * MThreadPerCluster) != 0");
|
||||
@@ -81,10 +81,10 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0,
|
||||
"NPerBlock % (NPerThread * NThreadPerCluster) != 0");
|
||||
|
||||
constexpr unsigned MClusterWork =
|
||||
constexpr index_t MClusterWork =
|
||||
(MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster);
|
||||
|
||||
constexpr unsigned NClusterWork =
|
||||
constexpr index_t NClusterWork =
|
||||
(NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster);
|
||||
|
||||
static_assert(BlockSize ==
|
||||
@@ -94,19 +94,18 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
|
||||
if(DistributeThreadAlongColumnFirst)
|
||||
{
|
||||
const unsigned cluster_work_block_id =
|
||||
const index_t cluster_work_block_id =
|
||||
thread_id / (MThreadPerCluster * NThreadPerCluster);
|
||||
|
||||
const unsigned thread_work_cluster_id =
|
||||
const index_t thread_work_cluster_id =
|
||||
thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster);
|
||||
|
||||
const unsigned m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
|
||||
const unsigned n_cluster_work_block_id =
|
||||
const index_t m_cluster_work_block_id = cluster_work_block_id / NClusterWork;
|
||||
const index_t n_cluster_work_block_id =
|
||||
cluster_work_block_id - m_cluster_work_block_id * NClusterWork;
|
||||
|
||||
const unsigned m_thread_work_cluster_id =
|
||||
thread_work_cluster_id / NThreadPerCluster;
|
||||
const unsigned n_thread_work_cluster_id =
|
||||
const index_t m_thread_work_cluster_id = thread_work_cluster_id / NThreadPerCluster;
|
||||
const index_t n_thread_work_cluster_id =
|
||||
thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster;
|
||||
|
||||
#if 0
|
||||
@@ -143,8 +142,8 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
}
|
||||
|
||||
// this should be optimized away if input is known
|
||||
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
|
||||
unsigned n_in_c)
|
||||
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
|
||||
index_t n_in_c)
|
||||
{
|
||||
return MatrixIndex{m_in_c, n_in_c};
|
||||
}
|
||||
@@ -164,10 +163,10 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
|
||||
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// a is transposed, b is not
|
||||
constexpr auto a_thread_mtx =
|
||||
@@ -180,7 +179,7 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
// loop over k
|
||||
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
|
||||
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
|
||||
{
|
||||
threadwise_matrix_copy(a_block_mtx,
|
||||
p_a_block + mMyThreadOffsetA +
|
||||
@@ -213,31 +212,31 @@ struct BlockwiseGemmBlockABlockBThreadC
|
||||
|
||||
// if following number are power of 2, index calculation shall be greatly reduced:
|
||||
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
|
||||
template <unsigned BlockSize,
|
||||
template <index_t BlockSize,
|
||||
class BlockMatrixA,
|
||||
class BlockMatrixB,
|
||||
class ThreadMatrixC,
|
||||
unsigned MPerThreadSubC,
|
||||
unsigned NPerThreadSubC,
|
||||
unsigned MLevel0Cluster,
|
||||
unsigned NLevel0Cluster,
|
||||
unsigned MLevel1Cluster,
|
||||
unsigned NLevel1Cluster,
|
||||
unsigned KPerThreadLoop>
|
||||
index_t MPerThreadSubC,
|
||||
index_t NPerThreadSubC,
|
||||
index_t MLevel0Cluster,
|
||||
index_t NLevel0Cluster,
|
||||
index_t MLevel1Cluster,
|
||||
index_t NLevel1Cluster,
|
||||
index_t KPerThreadLoop>
|
||||
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
{
|
||||
struct MatrixIndex
|
||||
{
|
||||
unsigned row;
|
||||
unsigned col;
|
||||
index_t row;
|
||||
index_t col;
|
||||
};
|
||||
|
||||
unsigned mMyThreadOffsetA;
|
||||
unsigned mMyThreadOffsetB;
|
||||
index_t mMyThreadOffsetA;
|
||||
index_t mMyThreadOffsetB;
|
||||
|
||||
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
|
||||
{
|
||||
constexpr unsigned ThreadPerLevel1Cluster =
|
||||
constexpr index_t ThreadPerLevel1Cluster =
|
||||
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
|
||||
|
||||
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
|
||||
@@ -249,31 +248,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
|
||||
"wrong! K dimension not consistent\n");
|
||||
|
||||
constexpr unsigned M = a_block_mtx.NCol(); // A is transposed
|
||||
constexpr unsigned N = b_block_mtx.NCol();
|
||||
constexpr unsigned K = a_block_mtx.NRow();
|
||||
constexpr index_t M = a_block_mtx.NCol(); // A is transposed
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
|
||||
"wrong! Cannot evenly divide thread work among repeat \n");
|
||||
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
|
||||
"wrong! Cannot evenly divide work among repeat\n");
|
||||
|
||||
constexpr unsigned MPerLevel1Cluster = M / MRepeat;
|
||||
constexpr unsigned NPerLevel1Cluster = N / NRepeat;
|
||||
constexpr index_t MPerLevel1Cluster = M / MRepeat;
|
||||
constexpr index_t NPerLevel1Cluster = N / NRepeat;
|
||||
|
||||
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
|
||||
(NPerLevel1Cluster % NLevel1Cluster == 0),
|
||||
"wrong! Cannot evenly divide work among Level1Cluster\n");
|
||||
|
||||
constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
|
||||
constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
|
||||
constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
|
||||
constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
|
||||
|
||||
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
|
||||
(NPerLevel0Cluster % NLevel0Cluster == 0),
|
||||
@@ -289,45 +288,45 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
|
||||
}
|
||||
|
||||
__device__ static MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id)
|
||||
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
|
||||
{
|
||||
constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
|
||||
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
|
||||
|
||||
unsigned level1_id = thread_id / ThreadPerLevel0Cluster;
|
||||
unsigned level1_m_id = level1_id / NLevel1Cluster;
|
||||
unsigned level1_n_id = level1_id % NLevel1Cluster;
|
||||
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
|
||||
index_t level1_m_id = level1_id / NLevel1Cluster;
|
||||
index_t level1_n_id = level1_id % NLevel1Cluster;
|
||||
|
||||
unsigned level0_id = thread_id % ThreadPerLevel0Cluster;
|
||||
unsigned level0_m_id = level0_id / NLevel0Cluster;
|
||||
unsigned level0_n_id = level0_id % NLevel0Cluster;
|
||||
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
|
||||
index_t level0_m_id = level0_id / NLevel0Cluster;
|
||||
index_t level0_n_id = level0_id % NLevel0Cluster;
|
||||
|
||||
constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
|
||||
constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
|
||||
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
|
||||
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
|
||||
|
||||
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
|
||||
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
|
||||
}
|
||||
|
||||
// this should be optimized away if input is known
|
||||
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(unsigned m_in_c,
|
||||
unsigned n_in_c)
|
||||
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
|
||||
index_t n_in_c)
|
||||
{
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
unsigned m_repeat = m_in_c / MPerThreadSubC;
|
||||
unsigned n_repeat = n_in_c / NPerThreadSubC;
|
||||
index_t m_repeat = m_in_c / MPerThreadSubC;
|
||||
index_t n_repeat = n_in_c / NPerThreadSubC;
|
||||
|
||||
unsigned m_in_sub_c = m_in_c % MPerThreadSubC;
|
||||
unsigned n_in_sub_c = n_in_c % NPerThreadSubC;
|
||||
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
|
||||
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
|
||||
|
||||
return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
|
||||
n_repeat * NPerLevel1Cluster + n_in_sub_c};
|
||||
@@ -346,12 +345,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr unsigned M = a_block_mtx.NCol();
|
||||
constexpr unsigned N = b_block_mtx.NCol();
|
||||
constexpr unsigned K = a_block_mtx.NRow();
|
||||
constexpr index_t M = a_block_mtx.NCol();
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx =
|
||||
@@ -370,19 +369,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
#pragma unroll
|
||||
// loop over k
|
||||
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
{
|
||||
#pragma unroll
|
||||
// copy A-sub to form A
|
||||
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
a_block_mtx,
|
||||
@@ -395,7 +394,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
|
||||
#pragma unroll
|
||||
// copy B-sub to form B
|
||||
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
b_block_mtx,
|
||||
@@ -433,12 +432,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr unsigned M = a_block_mtx.NCol();
|
||||
constexpr unsigned N = b_block_mtx.NCol();
|
||||
constexpr unsigned K = a_block_mtx.NRow();
|
||||
constexpr index_t M = a_block_mtx.NCol();
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx =
|
||||
@@ -457,19 +456,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
#pragma unroll
|
||||
// loop over k
|
||||
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
{
|
||||
#pragma unroll
|
||||
//#pragma unroll
|
||||
// copy A-sub to form A
|
||||
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
a_block_mtx,
|
||||
@@ -480,9 +479,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
//#pragma unroll
|
||||
// copy B-sub to form B
|
||||
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
b_block_mtx,
|
||||
@@ -505,19 +504,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
False,
|
||||
p_c_thread,
|
||||
f_accum);
|
||||
#else
|
||||
#elif 0
|
||||
// inline asm
|
||||
static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8,
|
||||
"asm is only for 8x8");
|
||||
|
||||
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
|
||||
for(index_t k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
|
||||
{
|
||||
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
|
||||
const index_t bindex = b_thread_mtx.Get1dIndex(k, 0);
|
||||
|
||||
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
|
||||
for(index_t i = 0; i < c_thread_mtx.NRow(); ++i)
|
||||
{
|
||||
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
|
||||
const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);
|
||||
const index_t aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
|
||||
const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
|
||||
|
||||
asm volatile("\n \
|
||||
v_mac_f32 %0, %8, %9 \n \
|
||||
@@ -573,12 +572,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr unsigned M = a_block_mtx.NCol();
|
||||
constexpr unsigned N = b_block_mtx.NCol();
|
||||
constexpr unsigned K = a_block_mtx.NRow();
|
||||
constexpr index_t M = a_block_mtx.NCol();
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx =
|
||||
@@ -601,15 +600,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
// preload A, B
|
||||
#pragma unroll
|
||||
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{ // copy A-sub to form A
|
||||
threadwise_matrix_copy(a_block_mtx,
|
||||
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
|
||||
@@ -619,7 +618,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{ // copy B-sub to form B
|
||||
threadwise_matrix_copy(b_block_mtx,
|
||||
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
|
||||
@@ -631,7 +630,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
bool even_loop = true;
|
||||
|
||||
#pragma unroll
|
||||
for(unsigned k_begin = 0; k_begin + KPerThreadLoop < K;
|
||||
for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
|
||||
k_begin += KPerThreadLoop, even_loop = !even_loop)
|
||||
{ // loop over k
|
||||
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
|
||||
@@ -642,7 +641,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
|
||||
// preload next A, B
|
||||
#pragma unroll
|
||||
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{ // copy A-sub to form A
|
||||
threadwise_matrix_copy(a_block_mtx,
|
||||
p_a_block + mMyThreadOffsetA +
|
||||
@@ -654,7 +653,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{ // copy B-sub to form B
|
||||
threadwise_matrix_copy(b_block_mtx,
|
||||
p_b_block + mMyThreadOffsetB +
|
||||
@@ -710,12 +709,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr unsigned M = a_block_mtx.NCol();
|
||||
constexpr unsigned N = b_block_mtx.NCol();
|
||||
constexpr unsigned K = a_block_mtx.NRow();
|
||||
constexpr index_t M = a_block_mtx.NCol();
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A-sub, B-sub, C-sub
|
||||
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
@@ -737,15 +736,15 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
#pragma unroll
|
||||
// loop over k
|
||||
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
{
|
||||
// C-sub(s) in first row-wise subblock of C
|
||||
{
|
||||
@@ -779,7 +778,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
|
||||
#pragma unroll
|
||||
// copy next B-sub, and do GEMM
|
||||
for(unsigned n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
|
||||
for(index_t n_repeat = 1; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
b_block_mtx,
|
||||
@@ -805,7 +804,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
#pragma unroll
|
||||
// loop over rest of row-wise subblock
|
||||
// all B-sub(s) has been copied, so only A-sub(s) need to be copied
|
||||
for(unsigned m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
|
||||
for(index_t m_repeat = 1; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
// copy a A-sub
|
||||
threadwise_matrix_copy(
|
||||
@@ -817,7 +816,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
|
||||
// do some GEMMs
|
||||
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_gemm(
|
||||
a_thread_sub_mtx,
|
||||
|
||||
Reference in New Issue
Block a user