mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
refactor
This commit is contained in:
@@ -227,9 +227,6 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
|
||||
|
||||
constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
|
||||
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
|
||||
|
||||
@@ -491,7 +491,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 1x1 filter, 28x28 image
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 256;
|
||||
|
||||
@@ -94,8 +94,8 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
|
||||
// tensor view of blockwise input and weight in LDS
|
||||
// be careful of alignment
|
||||
constexpr auto in_chwn_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
|
||||
constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock * S * R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
@@ -164,7 +164,9 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
HoPerThread>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
|
||||
constexpr unsigned in_block_size =
|
||||
in_chwn_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr unsigned wei_block_size =
|
||||
wei_csrk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user