mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
add more assertion
This commit is contained in:
@@ -277,37 +277,6 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
// for 1x1, 14x14, Pascal, try
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 1;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 1;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 4;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#endif
|
||||
|
||||
|
||||
@@ -580,7 +580,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 1x1 filter, 14x14 image, C = 2048
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 2048;
|
||||
@@ -661,9 +661,9 @@ int main(int argc, char* argv[])
|
||||
device_direct_convolution_2_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
#endif
|
||||
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
|
||||
@@ -229,7 +229,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
constexpr auto desc = TDesc{};
|
||||
constexpr index_t ndim = desc.GetDimension();
|
||||
|
||||
static_assert(ndim >= 2 && ndim <= 8, "wrong!");
|
||||
static_assert(ndim >= 2 && ndim <= 10, "wrong!");
|
||||
|
||||
if(ndim == 2)
|
||||
{
|
||||
@@ -369,4 +369,75 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetStride(I6),
|
||||
desc.GetStride(I7));
|
||||
}
|
||||
else if(ndim == 9)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
desc.GetLength(I3),
|
||||
desc.GetLength(I4),
|
||||
desc.GetLength(I5),
|
||||
desc.GetLength(I6),
|
||||
desc.GetLength(I7),
|
||||
desc.GetLength(I8),
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3),
|
||||
desc.GetStride(I4),
|
||||
desc.GetStride(I5),
|
||||
desc.GetStride(I6),
|
||||
desc.GetStride(I7),
|
||||
desc.GetStride(I8));
|
||||
}
|
||||
else if(ndim == 10)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
constexpr auto I9 = Number<9>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
desc.GetLength(I3),
|
||||
desc.GetLength(I4),
|
||||
desc.GetLength(I5),
|
||||
desc.GetLength(I6),
|
||||
desc.GetLength(I7),
|
||||
desc.GetLength(I8),
|
||||
desc.GetLength(I9),
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3),
|
||||
desc.GetStride(I4),
|
||||
desc.GetStride(I5),
|
||||
desc.GetStride(I6),
|
||||
desc.GetStride(I7),
|
||||
desc.GetStride(I8),
|
||||
desc.GetStride(I9));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
|
||||
}
|
||||
|
||||
// this should be optimized away if input is known
|
||||
// this should be optimized away because input will be known at compile time
|
||||
__device__ static MatrixIndex
|
||||
GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
|
||||
{
|
||||
|
||||
@@ -41,7 +41,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
|
||||
static_assert(NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
|
||||
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -66,6 +67,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
|
||||
@@ -218,39 +222,39 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
|
||||
|
||||
// output: register to global mem,
|
||||
#if 0
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
|
||||
for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
|
||||
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
|
||||
{
|
||||
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
|
||||
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
|
||||
{
|
||||
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
|
||||
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
|
||||
{
|
||||
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
|
||||
|
||||
const auto c_thread_mtx_distance =
|
||||
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
|
||||
const auto c_thread_mtx_distance =
|
||||
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
|
||||
|
||||
const index_t ho_thread =
|
||||
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
|
||||
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
|
||||
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
|
||||
const index_t ho_thread =
|
||||
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
|
||||
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
|
||||
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
|
||||
|
||||
const index_t wo_thread = b_thread / NPerBlock;
|
||||
const index_t n_thread = b_thread % NPerBlock;
|
||||
const index_t wo_thread = b_thread / NPerBlock;
|
||||
const index_t n_thread = b_thread % NPerBlock;
|
||||
|
||||
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
|
||||
ho_block_data_begin + ho_thread,
|
||||
wo_block_data_begin + wo_thread,
|
||||
n_block_data_begin + n_thread)] =
|
||||
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
|
||||
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
|
||||
ho_block_data_begin + ho_thread,
|
||||
wo_block_data_begin + wo_thread,
|
||||
n_block_data_begin + n_thread)] =
|
||||
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif 1
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
@@ -261,63 +265,54 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
|
||||
const index_t n_thread_data_begin =
|
||||
c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
|
||||
|
||||
// this is for v2 GEMM
|
||||
// output is a 10d tensor
|
||||
if(NPerThread <= NPerBlock)
|
||||
{
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / (N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / (N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_8d_thread_desc, "out_8d_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
|
||||
print_ConstantTensorDescriptor(out_8d_global_desc, "out_8d_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_10d_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// no implemented yet
|
||||
assert(false);
|
||||
}
|
||||
threadwise_10d_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user