mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
refactor
This commit is contained in:
@@ -77,7 +77,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
|
||||
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
|
||||
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// for 3x3, 34x34
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
@@ -105,6 +105,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
@@ -145,7 +147,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// 3x3 58x58, NKC = 64, 64, 256
|
||||
// 3x3 58x58
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
@@ -166,21 +168,6 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// 3x3 58x58, NKC = 16,256,128
|
||||
constexpr index_t NPerBlock = 8;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 2;
|
||||
constexpr index_t HoPerBlock = 4;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
// for 7x7, 38x38
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
@@ -210,9 +197,42 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 4;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 3x3, 56x56, v1, Pacal
|
||||
constexpr index_t NPerBlock = 32;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
// for 3x3, 56x56
|
||||
// for 3x3, 56x56, v1r2, Pascal
|
||||
// for 3x3, 34x34, v1r2, Pascal
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
@@ -231,6 +251,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 1;
|
||||
constexpr index_t GemmDataPerReadB = 1;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
@@ -317,7 +339,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
constexpr auto gridwise_conv =
|
||||
#if 0
|
||||
#if 1
|
||||
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
|
||||
#elif 1
|
||||
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
@@ -346,6 +368,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
Sequence<InBlockCopy_ThreadPerDimC,
|
||||
InBlockCopy_ThreadPerDimH,
|
||||
InBlockCopy_ThreadPerDimW,
|
||||
|
||||
@@ -205,6 +205,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t InBlockCopyThreadPerDim0 = 4;
|
||||
constexpr index_t InBlockCopyThreadPerDim1 = 16;
|
||||
@@ -233,6 +235,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t InBlockCopyThreadPerDim0 = 4;
|
||||
constexpr index_t InBlockCopyThreadPerDim1 = 16;
|
||||
@@ -289,6 +293,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1,
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
|
||||
@@ -409,7 +409,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
@@ -454,7 +454,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 7x7, 38x38
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
|
||||
@@ -16,7 +16,9 @@ template <index_t BlockSize,
|
||||
index_t MLevel1Cluster,
|
||||
index_t NLevel1Cluster,
|
||||
index_t KPerThreadLoop,
|
||||
index_t BatchPerThread>
|
||||
index_t BatchPerThread,
|
||||
index_t DataPerReadA,
|
||||
index_t DataPerReadB>
|
||||
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
{
|
||||
index_t mMyThreadOffsetA = 0;
|
||||
@@ -220,7 +222,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
// copy B-sub to form B
|
||||
@@ -233,7 +236,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
mMyThreadOffsetB,
|
||||
b_thread_mtx,
|
||||
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
// loop over batch
|
||||
@@ -264,7 +268,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
(ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,7 +285,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
(ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB,
|
||||
b_thread_mtx,
|
||||
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,9 @@ template <index_t BlockSize,
|
||||
index_t NLevel0Cluster,
|
||||
index_t MLevel1Cluster,
|
||||
index_t NLevel1Cluster,
|
||||
index_t KPerThreadLoop>
|
||||
index_t KPerThreadLoop,
|
||||
index_t DataPerReadA,
|
||||
index_t DataPerReadB>
|
||||
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
{
|
||||
struct MatrixIndex
|
||||
@@ -276,7 +278,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -289,7 +292,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
mMyThreadOffsetB,
|
||||
b_thread_mtx,
|
||||
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
@@ -359,7 +363,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
|
||||
a_thread_sub_mtx,
|
||||
p_a_thread_0 + m_repeat * MPerThreadSubC,
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -369,7 +374,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
|
||||
b_thread_sub_mtx,
|
||||
p_b_thread_0 + n_repeat * NPerThreadSubC,
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
bool even_loop = true;
|
||||
@@ -394,7 +400,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
m_repeat * MPerLevel1Cluster,
|
||||
a_thread_sub_mtx,
|
||||
p_a_thread_next + m_repeat * MPerThreadSubC,
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -406,7 +413,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
n_repeat * NPerLevel1Cluster,
|
||||
b_thread_sub_mtx,
|
||||
p_b_thread_next + n_repeat * NPerThreadSubC,
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
|
||||
@@ -30,6 +30,8 @@ template <index_t GridSize,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopyThreadPerDims,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
@@ -169,7 +171,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread>{};
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
@@ -30,6 +30,8 @@ template <index_t GridSize,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopyThreadPerDims,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
@@ -172,7 +174,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread>{};
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
@@ -30,6 +30,8 @@ template <index_t GridSize,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopyThreadPerDims,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
@@ -173,7 +175,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread>{};
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
|
||||
@@ -185,7 +189,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
// register
|
||||
Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc");
|
||||
|
||||
@@ -26,6 +26,8 @@ template <index_t GridSize,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
index_t InBlockCopyThreadPerDim0,
|
||||
index_t InBlockCopyThreadPerDim1,
|
||||
index_t WeiBlockCopyThreadPerDim0,
|
||||
@@ -174,7 +176,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop>{};
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t max_align =
|
||||
|
||||
@@ -27,6 +27,8 @@ template <index_t GridSize,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
index_t InBlockCopyThreadPerDim0,
|
||||
index_t InBlockCopyThreadPerDim1,
|
||||
index_t WeiBlockCopyThreadPerDim0,
|
||||
@@ -178,7 +180,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop>{};
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t max_align =
|
||||
|
||||
@@ -1,23 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol>
|
||||
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol, index_t DataPerRead>
|
||||
__device__ void threadwise_matrix_copy(SrcMatrix,
|
||||
const Float* __restrict__ p_src,
|
||||
DstMatrix,
|
||||
Float* __restrict__ p_dst,
|
||||
Sequence<NRow, NCol>)
|
||||
Sequence<NRow, NCol>,
|
||||
Number<DataPerRead>)
|
||||
{
|
||||
static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % == DataPerRead == 0");
|
||||
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
constexpr auto src_mtx = SrcMatrix{};
|
||||
constexpr auto dst_mtx = DstMatrix{};
|
||||
|
||||
for(index_t i = 0; i < NRow; ++i)
|
||||
{
|
||||
for(index_t j = 0; j < NCol; ++j)
|
||||
for(index_t j = 0; j < NCol; j += DataPerRead)
|
||||
{
|
||||
const index_t src_index = src_mtx.Get1dIndex(i, j);
|
||||
const index_t dst_index = dst_mtx.Get1dIndex(i, j);
|
||||
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user