mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
refactor
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_xdl_cshuffle.hpp"
|
||||
#include "device_gemm_xdl_cshuffle_v2.hpp"
|
||||
#include "device_gemm_xdl_producer_consumer_cshuffle.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
@@ -57,24 +57,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// // 2-stage prefetch
|
||||
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
#elif 0
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2
|
||||
#elif 1
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_ProducerConsumer_CShuffle
|
||||
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// all thread
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 0, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 0, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
#elif 0
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2
|
||||
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// producer & consumer
|
||||
< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
|
||||
// < Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
|
||||
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
|
||||
< Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
|
||||
#elif 1
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
|
||||
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdl_cshuffle_v2.hpp"
|
||||
#include "tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "gridwise_gemm_xdl_producer_consumer_cshuffle.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -56,10 +56,10 @@ template <typename ALayout,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
|
||||
struct DeviceGemm_Xdl_CShuffle_v2
|
||||
struct DeviceGemm_Xdl_ProducerConsumer_CShuffle
|
||||
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGemm_Xdl_CShuffle_v2;
|
||||
using DeviceOp = DeviceGemm_Xdl_ProducerConsumer_CShuffle;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -334,7 +334,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2<
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_producer_consumer_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
@@ -471,7 +471,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v2<
|
||||
const auto kernel = kernel_gemm_xdl_producer_consumer_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
@@ -523,7 +523,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v2<
|
||||
const auto kernel = kernel_gemm_xdl_producer_consumer_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
@@ -672,7 +672,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemm_Xdl_CShuffle_v2"
|
||||
str << "DeviceGemm_Xdl_ProducerConsumer_CShuffle"
|
||||
<< "<"
|
||||
<< ABBlockTransferThreadGroupSize << ", "
|
||||
<< BlockGemmThreadGroupSize << ", "
|
||||
@@ -0,0 +1,234 @@
|
||||
#pragma once
|
||||
|
||||
#include "common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename ABBlockTransferThreadGroup,
|
||||
typename BlockGemmThreadGroup,
|
||||
index_t NumGemmKPrefetchStage>
|
||||
struct GridwiseGemmPipelineProducerConsumer;
|
||||
|
||||
// 1-stage prefetch
|
||||
template <typename ABBlockTransferThreadGroup, typename BlockGemmThreadGroup>
|
||||
struct GridwiseGemmPipelineProducerConsumer<ABBlockTransferThreadGroup, BlockGemmThreadGroup, 1>
|
||||
{
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
|
||||
{
|
||||
// TODO: improve applicability
|
||||
return num_loop % 2 == 0;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
|
||||
{
|
||||
return num_loop / 2 > 1;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep>
|
||||
static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_block_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_block_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
index_t num_loop)
|
||||
{
|
||||
// global read 0
|
||||
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
// move to 1
|
||||
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// LDS write 0
|
||||
a_block_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
// global Read 1
|
||||
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
// LDS write 0
|
||||
b_block_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
// global Read 1
|
||||
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// LDS write i + 1
|
||||
a_block_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
// global read i + 2
|
||||
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
// LDS write i + 1
|
||||
b_block_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
// global read i + 2
|
||||
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
++i;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 2
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
a_block_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_block_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename ABlockBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BlockwiseGemm& block_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
|
||||
// LDS write i + 1
|
||||
// global read i + 2
|
||||
|
||||
// LDS write i + 1
|
||||
// global read i + 2
|
||||
|
||||
++i;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 2
|
||||
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
static __device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_block_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_block_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& block_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
if(ABBlockTransferThreadGroup::IsBelong())
|
||||
{
|
||||
RunABBlockTransferPipeline<HasMainLoop>(a_grid_desc,
|
||||
a_block_desc,
|
||||
a_block_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_copy_step,
|
||||
b_grid_desc,
|
||||
b_block_desc,
|
||||
b_block_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_copy_step,
|
||||
num_loop);
|
||||
}
|
||||
else if(BlockGemmThreadGroup::IsBelong())
|
||||
{
|
||||
RunBlockGemmPipeline<HasMainLoop>(
|
||||
a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -51,7 +51,6 @@ struct GridwiseGemmPipeline_v1<1>
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
#if 1
|
||||
// preload data into LDS
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
@@ -98,80 +97,6 @@ struct GridwiseGemmPipeline_v1<1>
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
#elif 1
|
||||
// global read 0
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
// move to 1
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// LDS write 0
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
// global Read 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
// LDS write 0
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
// global Read 1
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
// main body
|
||||
// FIXME: HasMainLoop = (num_loop) > 2
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// LDS write i + 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
// global read i + 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
// LDS write i + 1
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
// global read i + 2
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
++i;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 2
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,20 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename ABBlockTransferThreadGroup, typename BlockGemmThreadGroup>
|
||||
struct GridwiseGemmPipeline_v2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__device__ constexpr GridwiseGemmPipeline_v2()
|
||||
{
|
||||
// TODO static assert
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
|
||||
{
|
||||
// TODO: improve applicability
|
||||
@@ -23,161 +13,7 @@ struct GridwiseGemmPipeline_v2
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
|
||||
{
|
||||
return num_loop / 2 > 1;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_block_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_block_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
index_t num_loop)
|
||||
{
|
||||
// global read 0
|
||||
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
// move to 1
|
||||
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// LDS write 0
|
||||
a_block_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
// global Read 1
|
||||
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
// LDS write 0
|
||||
b_block_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
// global Read 1
|
||||
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
// main body
|
||||
// FIXME: HasMainLoop = (num_loop) > 2
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// LDS write i + 1
|
||||
a_block_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
// global read i + 2
|
||||
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
// LDS write i + 1
|
||||
b_block_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
// global read i + 2
|
||||
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
++i;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 2
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
a_block_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_block_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename ABlockBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BlockwiseGemm& block_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
|
||||
// LDS write i + 1
|
||||
// global read i + 2
|
||||
|
||||
// LDS write i + 1
|
||||
// global read i + 2
|
||||
|
||||
++i;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 2
|
||||
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
return (num_loop / 2) > 1;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
@@ -195,42 +31,92 @@ struct GridwiseGemmPipeline_v2
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
static __device__ void Run(const AGridDesc& a_grid_desc,
|
||||
__device__ static void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_block_copy,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_block_copy,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& block_gemm,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
if(ABBlockTransferThreadGroup::IsBelong())
|
||||
// global read 0
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
// move to 1
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// LDS write 0
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
// global Read 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
// LDS write 0
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
// global Read 1
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
RunABBlockTransferPipeline<HasMainLoop>(a_grid_desc,
|
||||
a_block_desc,
|
||||
a_block_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_copy_step,
|
||||
b_grid_desc,
|
||||
b_block_desc,
|
||||
b_block_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_copy_step,
|
||||
num_loop);
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// LDS write i + 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
// global read i + 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
// LDS write i + 1
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
// global read i + 2
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
++i;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
else if(BlockGemmThreadGroup::IsBelong())
|
||||
|
||||
// tail
|
||||
{
|
||||
RunBlockGemmPipeline<HasMainLoop>(
|
||||
a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 2
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "gridwise_gemm_pipeline_v2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -127,7 +128,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
#if 1
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
|
||||
#else
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v2;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
|
||||
@@ -7,8 +7,7 @@
|
||||
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "gridwise_gemm_pipeline_v2.hpp"
|
||||
#include "gridwise_gemm_pipeline_producer_consumer.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -27,18 +26,20 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdl_cshuffle_v2(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
kernel_gemm_xdl_producer_consumer_cshuffle(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -52,6 +53,18 @@ __global__ void
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_ctile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif // end of #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
template <typename FloatAB,
|
||||
@@ -97,7 +110,7 @@ template <typename FloatAB,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_producer_consumer_cshuffle
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -114,10 +127,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
static constexpr auto AK1 = Number<AK1Value>{};
|
||||
static constexpr auto BK1 = Number<BK1Value>{};
|
||||
|
||||
using ThisThreadBlock =
|
||||
ThisThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;
|
||||
|
||||
#if 0
|
||||
struct ABBlockTransferThreadGroup
|
||||
{
|
||||
__device__ static constexpr index_t GetNumOfThread()
|
||||
@@ -151,22 +160,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
}
|
||||
};
|
||||
|
||||
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
|
||||
#else
|
||||
using ABBlockTransferThreadGroup = ThisThreadBlock;
|
||||
using BlockGemmThreadGroup = ThisThreadBlock;
|
||||
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
|
||||
#endif
|
||||
using CShuffleBlockTransferThreadGroup =
|
||||
ThisThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;
|
||||
|
||||
#if 1
|
||||
// gridwise GEMM pipeline
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
|
||||
#else
|
||||
// gridwise GEMM pipeline
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
|
||||
BlockGemmThreadGroup,
|
||||
NumGemmKPrefetchStage>;
|
||||
#endif
|
||||
using GridwiseGemmPipe = GridwiseGemmPipelineProducerConsumer<ABBlockTransferThreadGroup,
|
||||
BlockGemmThreadGroup,
|
||||
NumGemmKPrefetchStage>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
@@ -37,16 +37,16 @@ std::size_t get_flops(ck::index_t N,
|
||||
}
|
||||
|
||||
ConvParams::ConvParams()
|
||||
: num_dim_spatial(2),
|
||||
N(128),
|
||||
K(256),
|
||||
C(192),
|
||||
filter_spatial_lengths(2, 3),
|
||||
input_spatial_lengths(2, 71),
|
||||
conv_filter_strides(2, 2),
|
||||
conv_filter_dilations(2, 1),
|
||||
input_left_pads(2, 1),
|
||||
input_right_pads(2, 1)
|
||||
: num_dim_spatial(2),
|
||||
N(128),
|
||||
K(256),
|
||||
C(192),
|
||||
filter_spatial_lengths(2, 3),
|
||||
input_spatial_lengths(2, 71),
|
||||
conv_filter_strides(2, 2),
|
||||
conv_filter_dilations(2, 1),
|
||||
input_left_pads(2, 1),
|
||||
input_right_pads(2, 1)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -77,9 +77,9 @@ ConvParams::ConvParams(ck::index_t n_dim,
|
||||
conv_filter_dilations.size() != num_dim_spatial ||
|
||||
input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial)
|
||||
{
|
||||
throw(std::runtime_error(
|
||||
"ConvParams::GetOutputSpatialLengths: "
|
||||
"parameter size is different from number of declared dimensions!"));
|
||||
throw(
|
||||
std::runtime_error("ConvParams::GetOutputSpatialLengths: "
|
||||
"parameter size is different from number of declared dimensions!"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,9 +91,9 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
|
||||
conv_filter_dilations.size() != num_dim_spatial ||
|
||||
input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial)
|
||||
{
|
||||
throw(std::runtime_error(
|
||||
"ConvParams::GetOutputSpatialLengths: "
|
||||
"parameter size is different from number of declared dimensions!"));
|
||||
throw(
|
||||
std::runtime_error("ConvParams::GetOutputSpatialLengths: "
|
||||
"parameter size is different from number of declared dimensions!"));
|
||||
}
|
||||
|
||||
std::vector<ck::index_t> out_spatial_len(num_dim_spatial, 0);
|
||||
@@ -101,8 +101,7 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
|
||||
{
|
||||
// XEff = (X - 1) * conv_dilation_w + 1;
|
||||
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
const ck::index_t idx_eff =
|
||||
(filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1;
|
||||
const ck::index_t idx_eff = (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1;
|
||||
out_spatial_len[i] =
|
||||
(input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) /
|
||||
conv_filter_strides[i] +
|
||||
|
||||
Reference in New Issue
Block a user