[ROCm/composable_kernel commit: 1bd880a674]
This commit is contained in:
Chao Liu
2019-04-09 16:13:14 -05:00
parent e99510c4db
commit 4fba215ede
11 changed files with 113 additions and 43 deletions

View File

@@ -16,7 +16,9 @@ template <index_t BlockSize,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t BatchPerThread>
index_t BatchPerThread,
index_t DataPerReadA,
index_t DataPerReadB>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
index_t mMyThreadOffsetA = 0;
@@ -220,7 +222,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths());
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
// copy B-sub to form B
@@ -233,7 +236,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths());
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
// loop over batch
@@ -264,7 +268,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
(ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths());
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
}
@@ -280,7 +285,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
(ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths());
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
}
}

View File

@@ -14,7 +14,9 @@ template <index_t BlockSize,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop>
index_t KPerThreadLoop,
index_t DataPerReadA,
index_t DataPerReadB>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
struct MatrixIndex
@@ -276,7 +278,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths());
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
#pragma unroll
@@ -289,7 +292,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths());
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
// C = A * B
@@ -359,7 +363,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx,
p_a_thread_0 + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths());
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
#pragma unroll
@@ -369,7 +374,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx,
p_b_thread_0 + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths());
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
bool even_loop = true;
@@ -394,7 +400,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx,
p_a_thread_next + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths());
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
#pragma unroll
@@ -406,7 +413,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx,
p_b_thread_next + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths());
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
// C = A * B

View File

@@ -30,6 +30,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyThreadPerDims,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
@@ -169,7 +171,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread>{};
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});

View File

@@ -30,6 +30,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyThreadPerDims,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
@@ -172,7 +174,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread>{};
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});

View File

@@ -30,6 +30,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyThreadPerDims,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
@@ -173,7 +175,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread>{};
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
@@ -185,7 +189,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
// register
Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
#if 1
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc");

View File

@@ -26,6 +26,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0,
@@ -174,7 +176,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop>{};
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t max_align =

View File

@@ -27,6 +27,8 @@ template <index_t GridSize,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
index_t InBlockCopyThreadPerDim0,
index_t InBlockCopyThreadPerDim1,
index_t WeiBlockCopyThreadPerDim0,
@@ -178,7 +180,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop>{};
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS: be careful of alignment
constexpr index_t max_align =

View File

@@ -1,23 +1,29 @@
#pragma once
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol>
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol, index_t DataPerRead>
__device__ void threadwise_matrix_copy(SrcMatrix,
const Float* __restrict__ p_src,
DstMatrix,
Float* __restrict__ p_dst,
Sequence<NRow, NCol>)
Sequence<NRow, NCol>,
Number<DataPerRead>)
{
static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % == DataPerRead == 0");
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
constexpr auto src_mtx = SrcMatrix{};
constexpr auto dst_mtx = DstMatrix{};
for(index_t i = 0; i < NRow; ++i)
{
for(index_t j = 0; j < NCol; ++j)
for(index_t j = 0; j < NCol; j += DataPerRead)
{
const index_t src_index = src_mtx.Get1dIndex(i, j);
const index_t dst_index = dst_mtx.Get1dIndex(i, j);
p_dst[dst_index] = p_src[src_index];
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
}
}
}