mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
tuned implicit gemm v1 for 3x3 on AMD to 82%. Fixed a bug in 4d tensor blockwise copy.
This commit is contained in:
@@ -646,6 +646,9 @@ struct Blockwise4dTensorCopy3
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
@@ -664,13 +667,10 @@ struct Blockwise4dTensorCopy3
|
||||
iloop_d2 * thread_per_d2,
|
||||
iloop_d3 * thread_per_d3 * DataPerRead);
|
||||
|
||||
const index_t dst_offset =
|
||||
DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2,
|
||||
iloop_d3 * thread_per_d3 * DataPerRead);
|
||||
const index_t clipboard_offset = clipboard_desc.Get1dIndex(
|
||||
iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_clipboard[dst_offset])) =
|
||||
*(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) =
|
||||
*(reinterpret_cast<const vector_t*>(
|
||||
&p_src[src_offset + mSrcMyThreadOffset]));
|
||||
}
|
||||
@@ -713,6 +713,9 @@ struct Blockwise4dTensorCopy3
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
@@ -725,11 +728,8 @@ struct Blockwise4dTensorCopy3
|
||||
#pragma unroll
|
||||
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
|
||||
{
|
||||
const index_t src_offset =
|
||||
SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2,
|
||||
iloop_d3 * thread_per_d3 * DataPerRead);
|
||||
const index_t clipboard_offset = clipboard_desc.Get1dIndex(
|
||||
iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);
|
||||
|
||||
const index_t dst_offset =
|
||||
DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
|
||||
@@ -738,7 +738,7 @@ struct Blockwise4dTensorCopy3
|
||||
iloop_d3 * thread_per_d3 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_clipboard[src_offset]));
|
||||
*(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -263,6 +263,94 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
}
|
||||
}
|
||||
|
||||
#if DEVICE_BACKEND_HIP
    // Hand-scheduled 8x8 per-thread GEMM tile: accumulates C += transpose(A) * B.
    // The k-loop is software-pipelined: loads of the next k-slice (float4 each from
    // the A and B LDS tiles, at two cluster offsets) are interleaved between the
    // four 4x4 outer-product quadrants so LDS-read latency is hidden behind MACs.
    // Only the exact shape pinned by the static_asserts below is supported
    // (4x4 sub-tiles, KPerThreadLoop == 1, 8x8 per-thread C, float data,
    // BlockMatrixStrideA == 0, BatchPerThread == 1).
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
                            FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow(); // A is stored transposed (K rows)

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // per-thread A/B staging buffers for one k-slice; A is transposed, B is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        // distance (in elements) between the two 4-wide fragments each thread owns
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // preconditions for the hand-written schedule below
        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
                          is_same<FloatC, float>::value,
                      "Run_asm only deal with float\n");

        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_asm cannot deal with this GEMM shape yet\n");

        static_assert(
            BlockMatrixStrideA == 0 && BatchPerThread == 1,
            "Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");

        using Float4 = vector_type<float, 4>::MemoryType;

        // View the scalar buffers as float4 registers.
        // NOTE(review): assumes p_c_thread (and the LDS reads below) are 16-byte
        // aligned -- confirm at the call site.
        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);

        // Prologue: load the k = 0 slice and compute its upper two C quadrants.
        reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
        reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
        reg_b[1] =
            *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
        reg_a[1] =
            *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);

#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
            // Loads for slice k are interleaved with the lower two quadrants of
            // slice k-1 (which only need reg_a[1]/reg_b[*] from the previous loads).
            reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
            reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
            reg_b[1] = *reinterpret_cast<const Float4*>(
                &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
            reg_a[1] = *reinterpret_cast<const Float4*>(
                &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        }

        // Epilogue: lower two quadrants of the final k-slice.
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
    }
#endif
|
||||
|
||||
template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
|
||||
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
|
||||
FloatC* __restrict__ p_c_block) const
|
||||
|
||||
@@ -127,6 +127,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
}
|
||||
|
||||
#if DEVICE_BACKEND_HIP
|
||||
// TODO: this is not working correctly
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run_asm(const FloatA* __restrict__ p_a_block,
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
|
||||
@@ -204,23 +204,38 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
|
||||
__syncthreads())
|
||||
{
|
||||
// input: global mem to LDS
|
||||
#if 1
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
|
||||
|
||||
// weight: global mem to LDS
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
|
||||
#else
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// a series of batched GEMM
|
||||
#pragma unroll
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
blockwise_batch_gemm.Run(p_wei_block +
|
||||
wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_out_thread);
|
||||
#if 1
|
||||
blockwise_batch_gemm.Run
|
||||
#else
|
||||
blockwise_batch_gemm.Run_asm
|
||||
#endif
|
||||
(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user