mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
ckProfiler and device-level XDL GEMM operator (#48)
* add DeviceGemmXdl * update script * fix naming issue * fix comment * output HostTensorDescriptor * rename * padded GEMM for fwd v4r4r4 nhwc * refactor * refactor * refactor * adding ckProfiler * adding ckProfiler * refactor * fix tuning parameter bug * add more gemm instances * add more fp16 GEMM instances * fix profiler driver * fix bug in tuning parameter * add fp32 gemm instances * small fix * refactor * rename * refactor gemm profiler; adding DeviceConv and conv profiler * refactor * fix * add conv profiler * refactor * adding more GEMM and Conv instance * Create README.md Add build instruction for ckProfiler * Create README.md Add Readme for gemm_xdl example * Update README.md Remove build instruction from top most folder * Update README.md * clean up
This commit is contained in:
@@ -40,6 +40,7 @@ message(STATUS "Build with HIP ${hip_VERSION}")
|
||||
|
||||
## half
|
||||
#find_path(HALF_INCLUDE_DIR half.hpp)
|
||||
set(HALF_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/external/half/include")
|
||||
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
|
||||
|
||||
# CMAKE_CXX_FLAGS
|
||||
@@ -185,6 +186,7 @@ enable_cppcheck(
|
||||
composable_kernel/src/kernel_wrapper
|
||||
INCLUDE
|
||||
host/host_tensor/include
|
||||
host/device/include
|
||||
host/solver/include
|
||||
host/driver_offline/include
|
||||
composable_kernel/include/*
|
||||
@@ -196,3 +198,5 @@ enable_cppcheck(
|
||||
)
|
||||
|
||||
add_subdirectory(host)
|
||||
add_subdirectory(example)
|
||||
add_subdirectory(profiler)
|
||||
|
||||
176
README.md
176
README.md
@@ -1,177 +1 @@
|
||||
# How to build and run
|
||||
|
||||
# Docker
|
||||
```
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--privileged \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
|
||||
rocm/tensorflow:rocm4.2-tf2.4-dev \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
# Install Boost for online compilation
|
||||
https://www.boost.org/doc/libs/1_66_0/more/getting_started/unix-variants.html#easy-build-and-install
|
||||
|
||||
|
||||
# Build
|
||||
Add path of Boost
|
||||
```
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
|
||||
```
|
||||
|
||||
```
|
||||
mkdir build && cd build
|
||||
```
|
||||
|
||||
cmake cmd. Need to Specify target ID, example below is gfx908
|
||||
```
|
||||
cmake \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
|
||||
-D HIP_ONLINE_COMPILER_FLAGS="-DCK_AMD_GPU_GFX908" \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
..
|
||||
```
|
||||
|
||||
Build drivers: \
|
||||
``conv_fwd_driver_offline`` is (offline compilation) driver for forward convolution, \
|
||||
``conv_bwd_driver_offline`` is (offline compilation) driver for backward-data convolution \
|
||||
``conv_fwd_driver_online`` is (online compilation) driver for forward convolution
|
||||
```
|
||||
make -j conv_fwd_driver_offline
|
||||
make -j conv_bwd_driver_offline
|
||||
make -j conv_fwd_driver_online
|
||||
```
|
||||
|
||||
# Run
|
||||
* layout: 0 = NCHW; 1 = NHWC
|
||||
* algo: algorithm
|
||||
* verify: 0 = no verification; 1 = do verification
|
||||
* init: 0 ~ 5. initialization method
|
||||
* log: 0 = no log; 1 = do log
|
||||
* repeat: number of time kernel being launched
|
||||
```
|
||||
######################################################## layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
|
||||
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
./host/driver_offline/conv_bwd_driver_offline 1 5 0 0 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
```
|
||||
|
||||
# Result
|
||||
Forward convoltuion, FP16, NCHW
|
||||
```
|
||||
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
|
||||
layout: 0
|
||||
in: dim 4, lengths {128, 192, 71, 71}, strides {967872, 5041, 71, 1}
|
||||
wei: dim 4, lengths {256, 192, 3, 3}, strides {1728, 9, 3, 1}
|
||||
out: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1296, 36, 1}
|
||||
InLeftPads size 2, {1, 1, }
|
||||
InRightPads size 2, {1, 1, }
|
||||
ConvStrides size 2, {2, 2, }
|
||||
ConvDilations size 2, {1, 1, }
|
||||
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
|
||||
a_k0_m_k1_grid_desc{216, 256, 8}
|
||||
b_k0_n_k1_grid_desc{216, 165888, 8}
|
||||
c_m_n_grid_desc{ 256, 165888}
|
||||
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
Average time : 1.4155 ms, 103.686 TFlop/s
|
||||
```
|
||||
|
||||
Forward convoltuion, FP16, NCHW
|
||||
```
|
||||
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
|
||||
layout: 0
|
||||
in: dim 4, lengths {256, 256, 14, 14}, strides {50176, 196, 14, 1}
|
||||
wei: dim 4, lengths {1024, 256, 3, 3}, strides {2304, 9, 3, 1}
|
||||
out: dim 4, lengths {256, 1024, 14, 14}, strides {200704, 196, 14, 1}
|
||||
InLeftPads size 2, {1, 1, }
|
||||
InRightPads size 2, {1, 1, }
|
||||
ConvStrides size 2, {1, 1, }
|
||||
ConvDilations size 2, {1, 1, }
|
||||
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
|
||||
a_k0_m_k1_grid_desc{288, 1024, 8}
|
||||
b_k0_n_k1_grid_desc{288, 50176, 8}
|
||||
c_m_n_grid_desc{ 1024, 50176}
|
||||
launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
Average time : 2.21357 ms, 106.959 TFlop/s
|
||||
```
|
||||
|
||||
Forward convolution, FP16, NHWC
|
||||
```
|
||||
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
|
||||
layout: 1
|
||||
in: dim 4, lengths {128, 71, 71, 192}, strides {967872, 13632, 192, 1}
|
||||
wei: dim 4, lengths {256, 3, 3, 192}, strides {1728, 576, 192, 1}
|
||||
out: dim 4, lengths {128, 36, 36, 256}, strides {331776, 9216, 256, 1}
|
||||
InLeftPads size 2, {1, 1, }
|
||||
InRightPads size 2, {1, 1, }
|
||||
ConvStrides size 2, {2, 2, }
|
||||
ConvDilations size 2, {1, 1, }
|
||||
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
|
||||
a_k0_m_k1_grid_desc{216, 165888, 8}
|
||||
b_k0_n_k1_grid_desc{216, 256, 8}
|
||||
c_m_n_grid_desc{ 165888, 256}
|
||||
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
Average time : 1.12014 ms, 131.025 TFlop/s
|
||||
```
|
||||
|
||||
Forward convolution, FP16, NHWC
|
||||
```
|
||||
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
|
||||
layout: 1
|
||||
in: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1}
|
||||
wei: dim 4, lengths {1024, 3, 3, 256}, strides {2304, 768, 256, 1}
|
||||
out: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1}
|
||||
InLeftPads size 2, {1, 1, }
|
||||
InRightPads size 2, {1, 1, }
|
||||
ConvStrides size 2, {1, 1, }
|
||||
ConvDilations size 2, {1, 1, }
|
||||
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
|
||||
a_k0_m_k1_grid_desc{288, 50176, 8}
|
||||
b_k0_n_k1_grid_desc{288, 1024, 8}
|
||||
c_m_n_grid_desc{ 50176, 1024}
|
||||
launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
Average time : 1.86877 ms, 126.693 TFlop/s
|
||||
```
|
||||
|
||||
Backward data convolution, FP16, NHWC
|
||||
```
|
||||
./host/driver_offline/conv_bwd_driver_offline 1 1 0 3 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
|
||||
layout: 1
|
||||
in: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1}
|
||||
wei: dim 4, lengths {256, 3, 3, 1024}, strides {9216, 3072, 1024, 1}
|
||||
out: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1}
|
||||
InLeftPads size 2, {1, 1, }
|
||||
InRightPads size 2, {1, 1, }
|
||||
ConvStrides size 2, {1, 1, }
|
||||
ConvDilations size 2, {1, 1, }
|
||||
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
a_k0_m_k1_grid_desc{288, 50176, 8}
|
||||
b_k0_n_k1_grid_desc{288, 1024, 8}
|
||||
c_m_n_grid_desc{ 50176, 1024}
|
||||
launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
Average time : 2.22461 ms, 106.428 TFlop/s
|
||||
```
|
||||
|
||||
@@ -21,8 +21,7 @@ template <typename... In,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
|
||||
@@ -124,7 +124,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor()
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
{
|
||||
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
|
||||
|
||||
@@ -136,9 +136,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor()
|
||||
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
{
|
||||
constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc =
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<NRepeat>{},
|
||||
Number<MWaves>{},
|
||||
@@ -146,24 +146,24 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
Number<MPerXDL>{},
|
||||
Number<NPerXDL>{}));
|
||||
|
||||
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc);
|
||||
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
template <typename CMNGridDesc>
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor(
|
||||
c_m_n_grid_desc,
|
||||
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)),
|
||||
make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
|
||||
|
||||
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc);
|
||||
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor()
|
||||
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
AK0MK1BlockDesc{},
|
||||
@@ -175,7 +175,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor()
|
||||
__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
BK0NK1BlockDesc{},
|
||||
@@ -187,8 +187,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor();
|
||||
static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor();
|
||||
static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
|
||||
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
@@ -202,7 +202,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
@@ -211,7 +211,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
@@ -256,7 +256,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k0_m0_m1_m2_k1_block_desc),
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<K0, 1, 1, 1, K1>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
@@ -266,7 +266,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k0_n0_n1_n2_k1_block_desc),
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<K0, 1, 1, 1, K1>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
|
||||
@@ -16,22 +16,23 @@ namespace ck {
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CM0N0M1N1M2M3M4N2GridDesc,
|
||||
typename CBlockClusterAdaptor,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
|
||||
const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
kernel_gemm_xdlops_v2r3(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
@@ -42,19 +43,19 @@ __global__ void
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
block_2_ctile_map);
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CM0N0M1N1M2M3M4N2GridDesc,
|
||||
typename CBlockClusterAdaptor>
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
|
||||
typename Block2CTileMap>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -62,23 +63,23 @@ __global__ void
|
||||
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
||||
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const void CONSTANT* p_c_block_cluster_adaptor)
|
||||
const void CONSTANT* p_a_grid_desc_k0_m_k1,
|
||||
const void CONSTANT* p_b_grid_desc_k0_n_k1,
|
||||
const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
const void CONSTANT* p_block_2_ctile_map)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
const auto a_k0_m_k1_grid_desc = *reinterpret_cast<const AK0MK1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
|
||||
const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
|
||||
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
|
||||
*reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
|
||||
const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
|
||||
cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
|
||||
const auto a_grid_desc_k0_m_k1 = *reinterpret_cast<const AGridDesc_K0_M_K1*>(
|
||||
cast_pointer_to_generic_address_space(p_a_grid_desc_k0_m_k1));
|
||||
const auto b_grid_desc_k0_n_k1 = *reinterpret_cast<const BGridDesc_K0_N_K1*>(
|
||||
cast_pointer_to_generic_address_space(p_b_grid_desc_k0_n_k1));
|
||||
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
*reinterpret_cast<const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2*>(
|
||||
cast_pointer_to_generic_address_space(p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2));
|
||||
const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
|
||||
cast_pointer_to_generic_address_space(p_block_2_ctile_map));
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
@@ -86,10 +87,10 @@ __global__ void
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
block_2_ctile_map);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -98,9 +99,9 @@ template <index_t BlockSize,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CMNGridDesc,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename CGridDesc_M_N,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t K0PerBlock,
|
||||
@@ -155,7 +156,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
constexpr auto a_block_desc_k0_m_k1 = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
@@ -170,7 +171,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
constexpr auto b_block_desc_k0_n_k1 = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
@@ -186,19 +187,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size =
|
||||
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01,
|
||||
index_t N01)
|
||||
{
|
||||
@@ -209,13 +210,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
(NPerBlock % (NRepeat * NPerXDL)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)))
|
||||
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
|
||||
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
|
||||
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
|
||||
return false;
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
|
||||
@@ -236,10 +237,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
|
||||
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
|
||||
|
||||
@@ -254,12 +255,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
constexpr auto a_block_desc_k0_m_k1 = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
@@ -274,7 +275,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
constexpr auto b_block_desc_k0_n_k1 = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
@@ -292,23 +293,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
decltype(a_block_desc_k0_m_k1),
|
||||
decltype(b_block_desc_k0_n_k1),
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
K1>;
|
||||
|
||||
return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
|
||||
return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01)
|
||||
MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
@@ -339,31 +340,33 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
return c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1));
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
|
||||
decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
|
||||
using Block2CTileMap = decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
|
||||
|
||||
template <bool HasMainKBlockLoop>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
const CBlockClusterAdaptor& c_block_cluster_adaptor)
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
|
||||
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
|
||||
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
|
||||
p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
|
||||
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
@@ -376,7 +379,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k0_m_k1_block_desc = [&]() {
|
||||
constexpr auto a_block_desc_k0_m_k1 = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
@@ -391,7 +394,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
}();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k0_n_k1_block_desc = [&]() {
|
||||
constexpr auto b_block_desc_k0_n_k1 = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
@@ -415,8 +418,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(a_grid_desc_k0_m_k1),
|
||||
decltype(a_block_desc_k0_m_k1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
@@ -426,9 +429,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(a_k0_m_k1_grid_desc,
|
||||
true>(a_grid_desc_k0_m_k1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_k0_m_k1_block_desc,
|
||||
a_block_desc_k0_m_k1,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
@@ -441,8 +444,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
decltype(b_grid_desc_k0_n_k1),
|
||||
decltype(b_block_desc_k0_n_k1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
@@ -452,9 +455,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(b_k0_n_k1_grid_desc,
|
||||
true>(b_grid_desc_k0_n_k1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_k0_n_k1_block_desc,
|
||||
b_block_desc_k0_n_k1,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
@@ -469,8 +472,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
decltype(a_block_desc_k0_m_k1),
|
||||
decltype(b_block_desc_k0_n_k1),
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
@@ -481,7 +484,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block = p_shared_block;
|
||||
FloatAB* p_b_block = p_shared_block + a_block_space_size;
|
||||
@@ -499,17 +502,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
|
||||
|
||||
// preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
|
||||
}
|
||||
|
||||
// main body
|
||||
@@ -519,27 +522,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
{
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hack);
|
||||
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
a_grid_desc_k0_m_k1, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
b_grid_desc_k0_n_k1, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
|
||||
|
||||
k0_block_data_begin += K0PerBlock;
|
||||
} while(k0_block_data_begin < (K0 - K0PerBlock));
|
||||
@@ -554,19 +557,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
|
||||
blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
|
||||
constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
|
||||
constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
|
||||
constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
|
||||
constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
|
||||
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
|
||||
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
|
||||
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
|
||||
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
|
||||
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
|
||||
|
||||
@@ -605,8 +608,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
auto c_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
@@ -615,7 +618,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
1,
|
||||
true>{
|
||||
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(m_thread_data_on_grid_idx[I0],
|
||||
n_thread_data_on_grid_idx[I0],
|
||||
m_thread_data_on_grid_idx[I1],
|
||||
@@ -625,10 +628,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
m_thread_data_on_grid_idx[I4],
|
||||
n_thread_data_on_grid_idx[I2])};
|
||||
|
||||
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
|
||||
c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_grid_buf,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
|
||||
}
|
||||
|
||||
@@ -304,7 +304,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
|
||||
NRepeat,
|
||||
K1>;
|
||||
|
||||
return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
|
||||
return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_m_n_grid_desc);
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
@@ -596,7 +596,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
|
||||
blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
|
||||
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
|
||||
constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
|
||||
constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
|
||||
|
||||
@@ -644,17 +644,17 @@ struct XdlopsGemm
|
||||
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
|
||||
}
|
||||
|
||||
template <typename CM0N0M1N1M2N2Desc>
|
||||
template <typename CDesc_M0_N0_M1_N1_M2_N2>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0N0M1N1M2M3M4N2Descriptor(const CM0N0M1N1M2N2Desc& c_m0_n0_m1_n1_m2_n2_desc)
|
||||
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
|
||||
{
|
||||
const auto M0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I0);
|
||||
const auto N0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I1);
|
||||
const auto M1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I2);
|
||||
const auto N1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I3);
|
||||
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_m0_n0_m1_n1_m2_n2_desc,
|
||||
c_desc_m0_n0_m1_n1_m2_n2,
|
||||
make_tuple(make_pass_through_transform(M0),
|
||||
make_pass_through_transform(N0),
|
||||
make_pass_through_transform(M1),
|
||||
|
||||
@@ -94,7 +94,7 @@
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
|
||||
|
||||
// merge transformation use magic number division
|
||||
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
|
||||
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1
|
||||
|
||||
// hack: have underlying assumption that need to be satsified, otherwise it's a bug
|
||||
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
|
||||
|
||||
@@ -16,6 +16,9 @@ struct is_same<X, X> : public integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename X, typename Y>
|
||||
inline constexpr bool is_same_v = is_same<X, Y>::value;
|
||||
|
||||
template <typename T>
|
||||
using remove_reference_t = typename std::remove_reference<T>::type;
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(make_tuple(k, y, x, c));
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(make_tuple(n, ho, wo, k));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
||||
const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(
|
||||
in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
@@ -230,14 +230,14 @@ extern "C" __global__ void
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
Number<K1>{});
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
Number<K1>{});
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
|
||||
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_conv_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using NHWC = ck::tensor_layout::convolution::NHWC;
|
||||
using KYXC = ck::tensor_layout::convolution::KYXC;
|
||||
using NHWK = ck::tensor_layout::convolution::NHWK;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk = std::tuple<
|
||||
// clang-format off
|
||||
//##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//##############| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_conv_fwd_instance<2, F16, F16, F16, NHWC, KYXC, NHWK>(
|
||||
std::vector<DeviceConvFwdPtr>& device_conv_instances)
|
||||
{
|
||||
using DeviceConvs = device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk;
|
||||
|
||||
const auto device_convs = DeviceConvs{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceConvs>, 1>{}([&](auto i) {
|
||||
using Conv = remove_cvref_t<decltype(std::get<i>(device_convs))>;
|
||||
|
||||
auto conv = Conv{};
|
||||
|
||||
device_conv_instances.push_back(std::make_unique<Conv>(conv));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace device_conv_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,64 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_conv_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using NHWC = ck::tensor_layout::convolution::NHWC;
|
||||
using KYXC = ck::tensor_layout::convolution::KYXC;
|
||||
using NHWK = ck::tensor_layout::convolution::NHWK;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk = std::tuple<
|
||||
// clang-format off
|
||||
//##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//##############| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_conv_fwd_instance<2, F32, F32, F32, NHWC, KYXC, NHWK>(
|
||||
std::vector<DeviceConvFwdPtr>& device_conv_instances)
|
||||
{
|
||||
using DeviceConvs = device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk;
|
||||
|
||||
const auto device_convs = DeviceConvs{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceConvs>, 1>{}([&](auto i) {
|
||||
using Conv = remove_cvref_t<decltype(std::get<i>(device_convs))>;
|
||||
|
||||
auto conv = Conv{};
|
||||
|
||||
device_conv_instances.push_back(std::make_unique<Conv>(conv));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace device_conv_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F16, F16, F16, Col, Row, Row>(
|
||||
std::vector<DeviceGemmPtr>& device_op_instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
|
||||
using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F16, F16, F16, Col, Col, Row>(
|
||||
std::vector<DeviceGemmPtr>& device_op_instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F16, F16, F16, Row, Row, Row>(
|
||||
std::vector<DeviceGemmPtr>& device_op_instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
|
||||
using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F16, F16, F16, Row, Col, Row>(
|
||||
std::vector<DeviceGemmPtr>& device_op_instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F32, F32, F32, Col, Row, Row>(
|
||||
std::vector<DeviceGemmPtr>& device_op_instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
|
||||
using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F32, F32, F32, Col, Col, Row>(
|
||||
std::vector<DeviceGemmPtr>& device_op_instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F32, F32, F32, Row, Row, Row>(
|
||||
std::vector<DeviceGemmPtr>& device_op_instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
|
||||
using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple<
|
||||
// clang-format off
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
|
||||
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<F32, F32, F32, Row, Col, Row>(
|
||||
std::vector<DeviceGemmPtr>& device_op_instances)
|
||||
{
|
||||
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn;
|
||||
|
||||
const auto device_gemms = DeviceGemms{};
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
|
||||
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
|
||||
|
||||
auto gemm = Gemm{};
|
||||
|
||||
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
42
device_operation/include/device_base.hpp
Normal file
42
device_operation/include/device_base.hpp
Normal file
@@ -0,0 +1,42 @@
|
||||
#ifndef DEVICE_BASE_HPP
|
||||
#define DEVICE_BASE_HPP
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
struct BaseArgument
|
||||
{
|
||||
BaseArgument() = default;
|
||||
BaseArgument(const BaseArgument&) = default;
|
||||
BaseArgument& operator=(const BaseArgument&) = default;
|
||||
|
||||
virtual ~BaseArgument() {}
|
||||
};
|
||||
|
||||
struct BaseInvoker
|
||||
{
|
||||
BaseInvoker() = default;
|
||||
BaseInvoker(const BaseInvoker&) = default;
|
||||
BaseInvoker& operator=(const BaseInvoker&) = default;
|
||||
|
||||
virtual float Run(const BaseArgument*, int = 1) = 0;
|
||||
|
||||
virtual ~BaseInvoker() {}
|
||||
};
|
||||
|
||||
struct BaseOperator
|
||||
{
|
||||
BaseOperator() = default;
|
||||
BaseOperator(const BaseOperator&) = default;
|
||||
BaseOperator& operator=(const BaseOperator&) = default;
|
||||
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) = 0;
|
||||
|
||||
virtual ~BaseOperator() {}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
78
device_operation/include/device_conv.hpp
Normal file
78
device_operation/include/device_conv.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
#ifndef DEVICE_CONV_HPP
|
||||
#define DEVICE_CONV_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
struct DeviceConvFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
const void* p_wei,
|
||||
void* p_out,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
struct DeviceConvBwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(void* p_in,
|
||||
const void* p_wei,
|
||||
const void* p_out,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
struct DeviceConvWrw : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
void* p_wei,
|
||||
const void* p_out,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
using DeviceConvFwdPtr = std::unique_ptr<DeviceConvFwd>;
|
||||
using DeviceConvBwdPtr = std::unique_ptr<DeviceConvBwd>;
|
||||
using DeviceConvWrwPtr = std::unique_ptr<DeviceConvWrw>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
58
device_operation/include/device_conv_fwd_xdl.hpp
Normal file
58
device_operation/include/device_conv_fwd_xdl.hpp
Normal file
@@ -0,0 +1,58 @@
|
||||
#ifndef DEVICE_CONV_FWD_XDL_HPP
|
||||
#define DEVICE_CONV_FWD_XDL_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_conv.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
bool ABlockLdsAddExtraM,
|
||||
bool BBlockLdsAddExtraN>
|
||||
struct DeviceConvFwdXdl;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
601
device_operation/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp
Normal file
601
device_operation/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp
Normal file
@@ -0,0 +1,601 @@
|
||||
#ifndef DEVICE_CONV_FWD_XDL_NHWC_KYXC_NHWK_HPP
|
||||
#define DEVICE_CONV_FWD_XDL_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_conv.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "device_conv.hpp"
|
||||
#include "device_conv_fwd_xdl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// specialization for 2D conv: in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
template <typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
bool ABlockLdsAddExtraM,
|
||||
bool BBlockLdsAddExtraN>
|
||||
struct DeviceConvFwdXdl<
|
||||
2, // ck::index_t NDimSpatial,
|
||||
InDataType, // typename InDataType,
|
||||
WeiDataType, // typename WeiDataType,
|
||||
OutDataType, // typename OutDataType,
|
||||
AccDataType, // typename AccDataType,
|
||||
ck::tensor_layout::convolution::NHWC, // typename InLayout,
|
||||
ck::tensor_layout::convolution::KYXC, // typename WeiLayout,
|
||||
ck::tensor_layout::convolution::NHWK, // typename OutLayout,
|
||||
BlockSize, // ck::index_t BlockSize,
|
||||
MPerBlock, // ck::index_t MPerBlock,
|
||||
NPerBlock, // ck::index_t NPerBlock,
|
||||
K0PerBlock, // ck::index_t K0PerBlock,
|
||||
K1, // ck::index_t K1,
|
||||
MPerXDL, // ck::index_t MPerXDL,
|
||||
NPerXDL, // ck::index_t NPerXDL,
|
||||
MXdlPerWave, // ck::index_t MXdlPerWave,
|
||||
NXdlPerWave, // ck::index_t NXdlPerWave,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1, // typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1, // typename
|
||||
// ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder, // typename ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder, // typename ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim, // ck::index_t ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector, // ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1, // ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1, // typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1, // typename
|
||||
// BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder, // typename BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder, // typename BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim, // ck::index_t BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector, // ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1, // ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
CThreadTransferSrcDstVectorDim, // ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector, // ck::index_t CThreadTransferDstScalarPerVector,
|
||||
ABlockLdsAddExtraM, // bool ABlockLdsAddExtraM,
|
||||
BBlockLdsAddExtraN // bool BBlockLdsAddExtraN>
|
||||
> : public DeviceConvFwd
|
||||
{
|
||||
using ADataType = InDataType;
|
||||
using BDataType = WeiDataType;
|
||||
using CDataType = OutDataType;
|
||||
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
// TODO make it support any # of spatial dimensions
|
||||
static constexpr index_t NDimSpatial = 2;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t GemmMRaw = N * Ho * Wo;
|
||||
const index_t GemmN = K;
|
||||
const index_t GemmK = Y * X * C;
|
||||
|
||||
const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw;
|
||||
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmMRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
// TODO remove these hacks
|
||||
static constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: M
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: M
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: K1
|
||||
|
||||
static constexpr auto b_k0_n_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: N
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: N
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: K1
|
||||
|
||||
static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
|
||||
|
||||
static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
|
||||
BlockSize,
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
|
||||
2, // ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
|
||||
2, // BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
|
||||
7, // CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks,
|
||||
decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks,
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks,
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks,
|
||||
false, // CAccessOrderMRepeatNRepeat,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockLdsAddExtraN>;
|
||||
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
|
||||
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
|
||||
|
||||
using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01)
|
||||
: p_a_grid_{p_in_grid},
|
||||
p_b_grid_{p_wei_grid},
|
||||
p_c_grid_{p_out_grid},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01}
|
||||
{
|
||||
const auto descs = DeviceConvFwdXdl::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
|
||||
{
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
|
||||
|
||||
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceConvFwdXdl::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceConvFwdXdl::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceConvFwdXdl::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceConvFwdXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<DeviceConvFwdXdl::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceConvFwdXdl::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceConvFwdXdl::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceConvFwdXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<DeviceConvFwdXdl::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
const void* p_wei_grid,
|
||||
void* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<const WeiDataType*>(p_wei_grid),
|
||||
static_cast<OutDataType*>(p_out_grid),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
42
device_operation/include/device_conv_instance.hpp
Normal file
42
device_operation/include/device_conv_instance.hpp
Normal file
@@ -0,0 +1,42 @@
|
||||
#ifndef DEVICE_CONV_INSTANTCE_HPP
|
||||
#define DEVICE_CONV_INSTANTCE_HPP
|
||||
|
||||
#include "device_conv.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv_instance {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
void add_device_conv_fwd_instance(std::vector<DeviceConvFwdPtr>&);
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
void add_device_conv_bwd_instance(std::vector<DeviceConvBwdPtr>&);
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
void add_device_conv_wrw_instance(std::vector<DeviceConvWrwPtr>&);
|
||||
|
||||
} // namespace device_conv_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
31
device_operation/include/device_gemm.hpp
Normal file
31
device_operation/include/device_gemm.hpp
Normal file
@@ -0,0 +1,31 @@
|
||||
#ifndef DEVICE_GEMM_HPP
|
||||
#define DEVICE_GEMM_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
struct DeviceGemm : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
using DeviceGemmPtr = std::unique_ptr<DeviceGemm>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
23
device_operation/include/device_gemm_instance.hpp
Normal file
23
device_operation/include/device_gemm_instance.hpp
Normal file
@@ -0,0 +1,23 @@
|
||||
#ifndef DEVICE_GEMM_INSTANTCE_HPP
|
||||
#define DEVICE_GEMM_INSTANTCE_HPP
|
||||
|
||||
#include "device_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
void add_device_gemm_instance(std::vector<DeviceGemmPtr>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
442
device_operation/include/device_gemm_xdl.hpp
Normal file
442
device_operation/include/device_gemm_xdl.hpp
Normal file
@@ -0,0 +1,442 @@
|
||||
#ifndef DEVICE_GEMM_XDL_HPP
|
||||
#define DEVICE_GEMM_XDL_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include "device.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_gemm.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
bool ABlockLdsAddExtraM,
|
||||
bool BBlockLdsAddExtraN>
|
||||
struct DeviceGemmXdl : public DeviceGemm
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
|
||||
{
|
||||
assert(K % K1 == 0);
|
||||
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto a_grid_desc_k0_m_k1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_k0_m_k1;
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
|
||||
{
|
||||
assert(K % K1 == 0);
|
||||
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
const auto b_grid_desc_k_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b_grid_desc_k0_n_k1 =
|
||||
transform_tensor_descriptor(b_grid_desc_k_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_k0_n_k1;
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
|
||||
{
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
|
||||
}
|
||||
}
|
||||
|
||||
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
|
||||
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
// TODO remove these hacks
|
||||
static constexpr auto a_k0_m_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: M
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: M
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
|
||||
static constexpr auto b_k0_n_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: N
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: N
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
|
||||
static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks,
|
||||
decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks,
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks,
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks,
|
||||
false, // CAccessOrderMRepeatNRepeat,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockLdsAddExtraN>;
|
||||
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
|
||||
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
|
||||
|
||||
using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t M01,
|
||||
index_t N01)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01}
|
||||
{
|
||||
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
|
||||
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
|
||||
c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
|
||||
{
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
|
||||
|
||||
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceGemmXdl::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC)
|
||||
{
|
||||
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
1,
|
||||
1);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -13,4 +13,10 @@ enum GemmMatrixLayout
|
||||
KM_NK_NM, // 7
|
||||
};
|
||||
|
||||
enum GemmDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
};
|
||||
|
||||
#endif
|
||||
52
device_operation/include/tensor_layout.hpp
Normal file
52
device_operation/include/tensor_layout.hpp
Normal file
@@ -0,0 +1,52 @@
|
||||
#ifndef TENSOR_LAYOUT_HPP
|
||||
#define TENSOR_LAYOUT_HPP
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_layout {
|
||||
|
||||
struct BaseTensorLayout
|
||||
{
|
||||
};
|
||||
|
||||
namespace gemm {
|
||||
|
||||
struct RowMajor : public BaseTensorLayout
|
||||
{
|
||||
};
|
||||
|
||||
struct ColumnMajor : public BaseTensorLayout
|
||||
{
|
||||
};
|
||||
} // namespace gemm
|
||||
|
||||
namespace convolution {
|
||||
|
||||
struct NHWC : public BaseTensorLayout
|
||||
{
|
||||
};
|
||||
|
||||
struct KYXC : public BaseTensorLayout
|
||||
{
|
||||
};
|
||||
|
||||
struct NHWK : public BaseTensorLayout
|
||||
{
|
||||
};
|
||||
|
||||
struct NCHW : public BaseTensorLayout
|
||||
{
|
||||
};
|
||||
|
||||
struct KCYX : public BaseTensorLayout
|
||||
{
|
||||
};
|
||||
|
||||
struct NKHW : public BaseTensorLayout
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace convolution
|
||||
|
||||
} // namespace tensor_layout
|
||||
} // namespace ck
|
||||
#endif
|
||||
56
example/1_gemm_xdl/README.md
Normal file
56
example/1_gemm_xdl/README.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# Instructions for ```gemm_xdl``` Example
|
||||
|
||||
## Docker script
|
||||
```bash
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--privileged \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
|
||||
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
## Build ``gemm_xdl```
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
```
|
||||
|
||||
```bash
|
||||
# Need to specify target ID, example below is gfx908
|
||||
cmake \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
..
|
||||
```
|
||||
|
||||
```bash
|
||||
make -j gemm_xdl
|
||||
```
|
||||
|
||||
## Run ```gemm_xdl```
|
||||
```bash
|
||||
#arg1: verification (0=no, 1=yes)
|
||||
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
#arg3: run kernel # of times (>1)
|
||||
./example/gemm_xdl.sh 0 1 5
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
|
||||
```
|
||||
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
|
||||
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
|
||||
arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
|
||||
arg.c_grid_desc_m_n_{ 3840, 4096}
|
||||
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 5 times...
|
||||
Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s
|
||||
```
|
||||
202
example/1_gemm_xdl/gemm_xdl.cpp
Normal file
202
example/1_gemm_xdl/gemm_xdl.cpp
Normal file
@@ -0,0 +1,202 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
#include "host_gemm.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
struct DeviceGemmInstance;
|
||||
|
||||
template <>
|
||||
struct DeviceGemmInstance<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>
|
||||
{
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Compilation parameters for NT problem
|
||||
// clang-format off
|
||||
using type =
|
||||
//########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DeviceGemmInstance<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>
|
||||
{
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Compilation parameters for NT problem
|
||||
// clang-format off
|
||||
using type =
|
||||
//########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>;
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(argc != 4)
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
const bool do_verification = std::stoi(argv[1]);
|
||||
const int init_method = std::stoi(argv[2]);
|
||||
const int nrepeat = std::stoi(argv[3]);
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
// matrix data type
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
// matrix layout
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<BDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
// do GEMM
|
||||
auto gemm =
|
||||
typename DeviceGemmInstance<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>::
|
||||
type{};
|
||||
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result);
|
||||
|
||||
check_error(c_m_n_host_result, c_m_n_device_result);
|
||||
}
|
||||
}
|
||||
18
example/CMakeLists.txt
Normal file
18
example/CMakeLists.txt
Normal file
@@ -0,0 +1,18 @@
|
||||
include_directories(BEFORE
|
||||
include
|
||||
${PROJECT_SOURCE_DIR}/host/host_tensor/include
|
||||
${PROJECT_SOURCE_DIR}/host/device/include
|
||||
${PROJECT_SOURCE_DIR}/device_operation/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
|
||||
${PROJECT_SOURCE_DIR}/external/rocm/include
|
||||
)
|
||||
|
||||
set(GEMM_XDL_SOURCE 1_gemm_xdl/gemm_xdl.cpp)
|
||||
|
||||
add_executable(gemm_xdl ${GEMM_XDL_SOURCE})
|
||||
|
||||
target_link_libraries(gemm_xdl PRIVATE host_tensor)
|
||||
5670
external/half/include/half.hpp
vendored
Normal file
5670
external/half/include/half.hpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
||||
include_directories(BEFORE
|
||||
include
|
||||
${PROJECT_SOURCE_DIR}/host/host_tensor/include
|
||||
${PROJECT_SOURCE_DIR}/host/device/include
|
||||
${PROJECT_SOURCE_DIR}/host/solver/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
|
||||
|
||||
@@ -141,14 +141,14 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
|
||||
@@ -4,6 +4,131 @@
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
#if 0
|
||||
__host__ __device__ static constexpr auto
|
||||
MakePaddedGridDescriptors(const AGridDesc_K0Raw_MRaw_K1& a_grid_desc_k0raw_mraw_k1,
|
||||
const BGridDesc_K0Raw_NRaw_K1& b_grid_desc_k0raw_nraw_k1,
|
||||
const CGridDesc_MRaw_NRaw& c_grid_desc_mraw_nraw)
|
||||
{
|
||||
const auto K0Raw = a_grid_desc_k0raw_mraw_k1.GetLength(I0);
|
||||
const auto K1 = a_grid_desc_k0raw_mraw_k1.GetLength(I2);
|
||||
const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0);
|
||||
const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1);
|
||||
|
||||
const auto K0Pad = math::integer_least_multiple(K0Raw, K0PerBlock) - K0Raw;
|
||||
const auto MPad = math::integer_least_multiple(MRaw, MPerBlock) - MRaw;
|
||||
const auto NPad = math::integer_least_multiple(NRaw, NPerBlock) - NRaw;
|
||||
|
||||
// A
|
||||
const auto a_grid_desc_k0_m_k1 = [&]() {
|
||||
if constexpr(DoPad_K0 && DoPad_M)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_k0_m_k1,
|
||||
make_tuple(make_right_pad_transform(K0Raw, K0Pad),
|
||||
make_right_pad_transform(MRaw, MPad),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
else if constexpr(DoPad_K0 && !DoPad_M)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_k0_m_k1,
|
||||
make_tuple(make_right_pad_transform(K0Raw, K0Pad),
|
||||
make_pass_through_transform(MRaw),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
else if constexpr(!DoPad_K0 && DoPad_M)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_k0_m_k1,
|
||||
make_tuple(make_pass_through_transform(K0Raw),
|
||||
make_right_pad_transform(MRaw, MPad),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return a_grid_desc_k0raw_mraw_k1;
|
||||
}
|
||||
}();
|
||||
|
||||
// B
|
||||
const auto b_grid_desc_k0_n_k1 = [&]() {
|
||||
if constexpr(DoPad_K0 && DoPad_N)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_k0_n_k1,
|
||||
make_tuple(make_right_pad_transform(K0Raw, K0Pad),
|
||||
make_right_pad_transform(NRaw, NPad),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
else if constexpr(DoPad_K0 && !DoPad_N)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_k0_n_k1,
|
||||
make_tuple(make_right_pad_transform(K0Raw, K0Pad),
|
||||
make_pass_through_transform(NRaw),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
else if constexpr(!DoPad_K0 && DoPad_N)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_k0_n_k1,
|
||||
make_tuple(make_pass_through_transform(K0Raw),
|
||||
make_right_pad_transform(NRaw, NPad),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return b_grid_desc_k0raw_nraw_k1;
|
||||
}
|
||||
}();
|
||||
|
||||
// C
|
||||
const auto c_grid_desc_m_n = [&]() {
|
||||
if constexpr(DoPad_M && DoPad_N)
|
||||
{
|
||||
return transform_tensor_descriptor(c_grid_desc_m_n,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(DoPad_M && !DoPad_N)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(!DoPad_M && DoPad_N)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
reutnr c_grid_desc_m_n;
|
||||
}
|
||||
}();
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
@@ -275,20 +400,19 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
#if 0 // debug
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
// HACK: hacks that control index calculation when iterating over A matrix
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
|
||||
@@ -297,7 +421,39 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
#else
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = descs[I0];
|
||||
|
||||
const auto GemmK0 = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I0);
|
||||
const auto GemmMRaw = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I1);
|
||||
const auto GemmMPad = math::integer_least_multiple(GemmMRaw, GemmMPerBlock) - GemmMRaw;
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A matrix
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
|
||||
#endif
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
@@ -305,6 +461,12 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
#if 0
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
@@ -322,12 +484,36 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
#else
|
||||
const auto out_gemmmraw_gemmn_grid_desc = descs[I2];
|
||||
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
const auto GemmN = out_gemmmraw_gemmn_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
#endif
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
|
||||
@@ -11,8 +11,8 @@ template <ck::index_t BlockSize,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K,
|
||||
typename CMNGridDesc,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
@@ -52,9 +52,9 @@ template <ck::index_t BlockSize,
|
||||
__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K& b_grid_desc_k0_n_k1,
|
||||
const CMNGridDesc& c_grid_desc_m_n,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
AGridStepHacks,
|
||||
@@ -63,7 +63,6 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
ck::index_t nrepeat)
|
||||
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -77,8 +76,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
@@ -117,38 +116,37 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
BBlockLdsAddExtraN>;
|
||||
|
||||
{
|
||||
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
|
||||
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
|
||||
std::cout << "a_grid_desc_k0_m_k1{" << a_grid_desc_k0_m_k1.GetLength(I0) << ", "
|
||||
<< a_grid_desc_k0_m_k1.GetLength(I1) << ", " << a_grid_desc_k0_m_k1.GetLength(I2)
|
||||
<< "}" << std::endl;
|
||||
|
||||
std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
|
||||
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
|
||||
std::cout << "b_grid_desc_k0_n_k1{" << b_grid_desc_k0_n_k1.GetLength(I0) << ", "
|
||||
<< b_grid_desc_k0_n_k1.GetLength(I1) << ", " << b_grid_desc_k0_n_k1.GetLength(I2)
|
||||
<< "}" << std::endl;
|
||||
|
||||
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
|
||||
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "c_grid_desc_m_n{ " << c_grid_desc_m_n.GetLength(I0) << ", "
|
||||
<< c_grid_desc_m_n.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01))
|
||||
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n, M01, N01))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
|
||||
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
|
||||
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
|
||||
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
|
||||
|
||||
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
|
||||
|
||||
const auto c_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01);
|
||||
const auto block_2_ctile_map = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n, M01, N01);
|
||||
|
||||
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
|
||||
using Block2CTileMap = decltype(block_2_ctile_map);
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n);
|
||||
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
@@ -157,14 +155,15 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
true>;
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<BGridDesc_K0_N_K>,
|
||||
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
@@ -174,21 +173,22 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
block_2_ctile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
false>;
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<BGridDesc_K0_N_K>,
|
||||
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
@@ -198,32 +198,34 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
block_2_ctile_map);
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
|
||||
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
|
||||
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
|
||||
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
|
||||
DeviceMem a_grid_desc_k0_m_k1_dev_buf(sizeof(AGridDesc_K0_M_K1));
|
||||
DeviceMem b_grid_desc_k0_n_k1_dev_buf(sizeof(BGridDesc_K0_N_K));
|
||||
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(
|
||||
sizeof(CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2));
|
||||
DeviceMem block_2_ctile_map_dev_buf(sizeof(Block2CTileMap));
|
||||
|
||||
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
|
||||
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
|
||||
a_grid_desc_k0_m_k1_dev_buf.ToDevice(&a_grid_desc_k0_m_k1);
|
||||
b_grid_desc_k0_n_k1_dev_buf.ToDevice(&b_grid_desc_k0_n_k1);
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
|
||||
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
|
||||
block_2_ctile_map_dev_buf.ToDevice(&block_2_ctile_map);
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
true>;
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<BGridDesc_K0_N_K>,
|
||||
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
@@ -234,23 +236,23 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(a_grid_desc_k0_m_k1_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_grid_desc_k0_n_k1_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
cast_pointer_to_constant_address_space(block_2_ctile_map_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>,
|
||||
false>;
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<BGridDesc_K0_N_K>,
|
||||
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
@@ -261,12 +263,11 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(a_grid_desc_k0_m_k1_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_grid_desc_k0_n_k1_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
cast_pointer_to_constant_address_space(block_2_ctile_map_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv_bwd_data.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
@@ -21,12 +20,153 @@
|
||||
#define USE_CONV_BWD_V4R1_XDL_NHWC 0
|
||||
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
|
||||
|
||||
enum ConvTensorLayout
|
||||
{
|
||||
NCHW,
|
||||
NHWC,
|
||||
CHWN,
|
||||
NCHWc,
|
||||
NHWCc
|
||||
};
|
||||
|
||||
enum ConvBackwardDataAlgo
|
||||
{
|
||||
V4R1XDLNHWC, // 0
|
||||
V4R1R2XDLNHWC, // 1
|
||||
};
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_convolution_backward_data(Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
const Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& /* in_right_pads */,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
|
||||
std::size_t K = wei.mDesc.GetLengths()[I0];
|
||||
std::size_t Y = wei.mDesc.GetLengths()[I2];
|
||||
std::size_t X = wei.mDesc.GetLengths()[I3];
|
||||
|
||||
std::size_t Ho = out.mDesc.GetLengths()[I2];
|
||||
std::size_t Wo = out.mDesc.GetLengths()[I3];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0];
|
||||
|
||||
if(h_tmp % conv_strides[I0] == 0)
|
||||
{
|
||||
int ho = h_tmp / conv_strides[I0];
|
||||
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1];
|
||||
|
||||
if(w_tmp % conv_strides[I1] == 0)
|
||||
{
|
||||
int wo = w_tmp / conv_strides[I1];
|
||||
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += out(n, k, ho, wo) * wei(k, c, y, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
in(n, c, hi, wi) = v;
|
||||
};
|
||||
|
||||
auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) {
|
||||
std::size_t K = wei.mDesc.GetLengths()[I0];
|
||||
std::size_t Y = wei.mDesc.GetLengths()[I1];
|
||||
std::size_t X = wei.mDesc.GetLengths()[I2];
|
||||
|
||||
std::size_t Ho = out.mDesc.GetLengths()[I1];
|
||||
std::size_t Wo = out.mDesc.GetLengths()[I2];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0];
|
||||
|
||||
if(h_tmp % conv_strides[I0] == 0)
|
||||
{
|
||||
int ho = h_tmp / conv_strides[I0];
|
||||
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1];
|
||||
|
||||
if(w_tmp % conv_strides[I1] == 0)
|
||||
{
|
||||
int wo = w_tmp / conv_strides[I1];
|
||||
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += out(n, ho, wo, k) * wei(k, y, x, c);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
in(n, hi, wi, c) = v;
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
in.mDesc.GetLengths()[0],
|
||||
in.mDesc.GetLengths()[1],
|
||||
in.mDesc.GetLengths()[2],
|
||||
in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nhwc,
|
||||
in.mDesc.GetLengths()[0],
|
||||
in.mDesc.GetLengths()[1],
|
||||
in.mDesc.GetLengths()[2],
|
||||
in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -324,14 +464,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_direct_convolution_backward_data(in_host,
|
||||
wei,
|
||||
out,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
host_convolution_backward_data(in_host,
|
||||
wei,
|
||||
out,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
|
||||
check_error(in_host, in_device);
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
|
||||
@@ -28,6 +27,15 @@
|
||||
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
|
||||
|
||||
enum ConvTensorLayout
|
||||
{
|
||||
NCHW,
|
||||
NHWC,
|
||||
CHWN,
|
||||
NCHWc,
|
||||
NHWCc
|
||||
};
|
||||
|
||||
enum ConvForwardAlgo
|
||||
{
|
||||
V4R4NCHW, // 0
|
||||
@@ -38,6 +46,93 @@ enum ConvForwardAlgo
|
||||
V4R4R4XDLNHWC // 5
|
||||
};
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_convolution_forward(const Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
double v = 0;
|
||||
for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
v += static_cast<const double>(in(n, c, hi, wi)) *
|
||||
static_cast<const double>(wei(k, c, y, x));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out(n, k, ho, wo) = v;
|
||||
};
|
||||
|
||||
auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
|
||||
double v = 0;
|
||||
for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[2])
|
||||
{
|
||||
v += static_cast<const double>(in(n, hi, wi, c)) *
|
||||
static_cast<const double>(wei(k, y, x, c));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out(n, ho, wo, k) = v;
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nhwc,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -425,14 +520,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_direct_convolution(in,
|
||||
wei,
|
||||
out_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
host_convolution_forward(in,
|
||||
wei,
|
||||
out_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
|
||||
check_error(out_host, out_device);
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv_bwd_weight.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
@@ -19,6 +18,15 @@
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
enum ConvTensorLayout
|
||||
{
|
||||
NCHW,
|
||||
NHWC,
|
||||
CHWN,
|
||||
NCHWc,
|
||||
NHWCc
|
||||
};
|
||||
|
||||
#define USE_DYNAMIC_MODE 1
|
||||
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0
|
||||
#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0
|
||||
@@ -35,6 +43,92 @@ enum ConvBackwardWeightAlgo
|
||||
V4R4R5XDLATOMICNHWC, // 4
|
||||
};
|
||||
|
||||
template <typename TOut,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_convolution_backward_weight(const Tensor<TOut>& out,
|
||||
const Tensor<TIn>& in,
|
||||
Tensor<TWei>& wei,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
|
||||
double v = 0;
|
||||
for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
|
||||
{
|
||||
for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
v += static_cast<const double>(in(n, c, hi, wi)) *
|
||||
static_cast<const double>(out(n, k, ho, wo));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
wei(k, c, y, x) = v;
|
||||
};
|
||||
|
||||
auto f_kyxc = [&](auto k, auto y, auto x, auto c) {
|
||||
double v = 0;
|
||||
for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
|
||||
{
|
||||
for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[2])
|
||||
{
|
||||
v += static_cast<const double>(in(n, hi, wi, c)) *
|
||||
static_cast<const double>(out(n, ho, wo, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
wei(k, y, x, c) = v;
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_kcyx,
|
||||
wei.mDesc.GetLengths()[0],
|
||||
wei.mDesc.GetLengths()[1],
|
||||
wei.mDesc.GetLengths()[2],
|
||||
wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_kyxc,
|
||||
wei.mDesc.GetLengths()[0],
|
||||
wei.mDesc.GetLengths()[1],
|
||||
wei.mDesc.GetLengths()[2],
|
||||
wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -414,14 +508,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_direct_convolution_backward_weights(out,
|
||||
in,
|
||||
wei_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
host_convolution_backward_weight(out,
|
||||
in,
|
||||
wei_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
|
||||
check_error(wei_host, wei_device);
|
||||
|
||||
|
||||
@@ -3,15 +3,6 @@
|
||||
|
||||
#include "tensor_descriptor.hpp"
|
||||
|
||||
enum ConvTensorLayout
|
||||
{
|
||||
NCHW,
|
||||
NHWC,
|
||||
CHWN,
|
||||
NCHWc,
|
||||
NHWCc
|
||||
};
|
||||
|
||||
template <typename... InDesc,
|
||||
typename... WeiDesc,
|
||||
typename ConvStrides,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
#include "conv_common.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
@@ -8,19 +9,16 @@ template <typename TIn,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_direct_convolution(const Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
void host_conv_nchw_kcyx_nkhw(const Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I0 = ck::Number<0>{};
|
||||
constexpr auto I1 = ck::Number<1>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
double v = 0;
|
||||
@@ -44,281 +42,9 @@ void host_direct_convolution(const Tensor<TIn>& in,
|
||||
out(n, k, ho, wo) = v;
|
||||
};
|
||||
|
||||
auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
|
||||
double v = 0;
|
||||
for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[2])
|
||||
{
|
||||
v += static_cast<const double>(in(n, hi, wi, c)) *
|
||||
static_cast<const double>(wei(k, y, x, c));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out(n, ho, wo, k) = v;
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nhwc,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TIn, typename TWei, typename TOut, typename InLeftPads, typename InRightPads>
|
||||
void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
||||
const Tensor<TWei>& wei_kcyx,
|
||||
Tensor<TOut>& out_nkhw,
|
||||
InLeftPads,
|
||||
InRightPads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr std::size_t HoPerTile = 2;
|
||||
constexpr std::size_t WoPerTile = 2;
|
||||
|
||||
std::size_t N = in_nchw.mDesc.GetLengths()[0];
|
||||
std::size_t C = in_nchw.mDesc.GetLengths()[1];
|
||||
|
||||
std::size_t K = wei_kcyx.mDesc.GetLengths()[0];
|
||||
std::size_t Y = wei_kcyx.mDesc.GetLengths()[2];
|
||||
std::size_t X = wei_kcyx.mDesc.GetLengths()[3];
|
||||
|
||||
std::size_t Ho = out_nkhw.mDesc.GetLengths()[2];
|
||||
std::size_t Wo = out_nkhw.mDesc.GetLengths()[3];
|
||||
|
||||
index_t h_pad_low = InLeftPads{}.Get(Number<0>{});
|
||||
index_t w_pad_low = InLeftPads{}.Get(Number<1>{});
|
||||
|
||||
std::size_t HiPerTile = HoPerTile + Y - 1;
|
||||
std::size_t WiPerTile = WoPerTile + X - 1;
|
||||
|
||||
std::size_t HTile = (Ho + HoPerTile - 1) / HoPerTile;
|
||||
std::size_t WTile = (Wo + WoPerTile - 1) / WoPerTile;
|
||||
|
||||
Tensor<double> in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile});
|
||||
Tensor<double> in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile});
|
||||
Tensor<double> wei_transform({K, C, HiPerTile, WiPerTile});
|
||||
Tensor<double> out_transform({N, K, HTile, WTile, HiPerTile, HiPerTile});
|
||||
Tensor<double> out_hold({N, K, HTile, WTile, HoPerTile, WoPerTile});
|
||||
|
||||
auto f_in_hold = [&](auto n, auto c, auto htile, auto wtile) {
|
||||
for(int j = 0; j < HiPerTile; ++j)
|
||||
{
|
||||
int hi = HoPerTile * htile + j - h_pad_low;
|
||||
for(int i = 0; i < WiPerTile; ++i)
|
||||
{
|
||||
int wi = WoPerTile * wtile + i - w_pad_low;
|
||||
|
||||
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in_nchw.mDesc.GetLengths()[3])
|
||||
{
|
||||
in_hold(n, c, htile, wtile, j, i) = in_nchw(n, c, hi, wi);
|
||||
}
|
||||
else
|
||||
{
|
||||
in_hold(n, c, htile, wtile, j, i) = TIn(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto f_in_transform = [&](auto n, auto c, auto htile, auto wtile) {
|
||||
in_transform(n, c, htile, wtile, 0, 0) =
|
||||
in_hold(n, c, htile, wtile, 0, 0) - in_hold(n, c, htile, wtile, 0, 2) -
|
||||
in_hold(n, c, htile, wtile, 2, 0) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 0, 1) =
|
||||
in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) -
|
||||
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 0, 2) =
|
||||
-in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 0, 3) =
|
||||
in_hold(n, c, htile, wtile, 0, 1) - in_hold(n, c, htile, wtile, 0, 3) -
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 3);
|
||||
|
||||
in_transform(n, c, htile, wtile, 1, 0) =
|
||||
in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 1, 1) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 1, 2) =
|
||||
-in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) -
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 1, 3) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3);
|
||||
|
||||
in_transform(n, c, htile, wtile, 2, 0) =
|
||||
-in_hold(n, c, htile, wtile, 1, 0) + in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 2, 1) =
|
||||
-in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 2, 2) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) -
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 2, 3) =
|
||||
-in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 3) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3);
|
||||
|
||||
in_transform(n, c, htile, wtile, 3, 0) =
|
||||
in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) -
|
||||
in_hold(n, c, htile, wtile, 3, 0) + in_hold(n, c, htile, wtile, 3, 2);
|
||||
in_transform(n, c, htile, wtile, 3, 1) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) -
|
||||
in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2);
|
||||
in_transform(n, c, htile, wtile, 3, 2) =
|
||||
-in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2);
|
||||
in_transform(n, c, htile, wtile, 3, 3) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) -
|
||||
in_hold(n, c, htile, wtile, 3, 1) + in_hold(n, c, htile, wtile, 3, 3);
|
||||
};
|
||||
|
||||
auto f_wei_transform = [&](auto k, auto c) {
|
||||
wei_transform(k, c, 0, 0) = double(wei_kcyx(k, c, 0, 0));
|
||||
wei_transform(k, c, 0, 1) = 0.5 * double(wei_kcyx(k, c, 0, 0)) +
|
||||
0.5 * double(wei_kcyx(k, c, 0, 1)) +
|
||||
0.5 * double(wei_kcyx(k, c, 0, 2));
|
||||
wei_transform(k, c, 0, 2) = 0.5 * double(wei_kcyx(k, c, 0, 0)) -
|
||||
0.5 * double(wei_kcyx(k, c, 0, 1)) +
|
||||
0.5 * double(wei_kcyx(k, c, 0, 2));
|
||||
wei_transform(k, c, 0, 3) = double(wei_kcyx(k, c, 0, 2));
|
||||
|
||||
wei_transform(k, c, 1, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) +
|
||||
0.5 * double(wei_kcyx(k, c, 1, 0)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 0));
|
||||
wei_transform(k, c, 1, 1) =
|
||||
0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) +
|
||||
0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 2));
|
||||
wei_transform(k, c, 1, 2) =
|
||||
0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) -
|
||||
0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 2));
|
||||
wei_transform(k, c, 1, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) +
|
||||
0.5 * double(wei_kcyx(k, c, 1, 2)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 2));
|
||||
|
||||
wei_transform(k, c, 2, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) -
|
||||
0.5 * double(wei_kcyx(k, c, 1, 0)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 0));
|
||||
wei_transform(k, c, 2, 1) =
|
||||
0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) -
|
||||
0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 2));
|
||||
wei_transform(k, c, 2, 2) =
|
||||
0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) +
|
||||
0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 2));
|
||||
wei_transform(k, c, 2, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) -
|
||||
0.5 * double(wei_kcyx(k, c, 1, 2)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 2));
|
||||
|
||||
wei_transform(k, c, 3, 0) = double(wei_kcyx(k, c, 2, 0));
|
||||
wei_transform(k, c, 3, 1) = 0.5 * double(wei_kcyx(k, c, 2, 0)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 1)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 2));
|
||||
wei_transform(k, c, 3, 2) = 0.5 * double(wei_kcyx(k, c, 2, 0)) -
|
||||
0.5 * double(wei_kcyx(k, c, 2, 1)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 2));
|
||||
wei_transform(k, c, 3, 3) = double(wei_kcyx(k, c, 2, 2));
|
||||
};
|
||||
|
||||
auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) {
|
||||
for(int j = 0; j < HiPerTile; ++j)
|
||||
{
|
||||
for(int i = 0; i < WiPerTile; ++i)
|
||||
{
|
||||
double v = 0;
|
||||
for(int c = 0; c < C; ++c)
|
||||
{
|
||||
v += in_transform(n, c, htile, wtile, j, i) * wei_transform(k, c, j, i);
|
||||
}
|
||||
|
||||
out_transform(n, k, htile, wtile, j, i) = v;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto f_out_hold = [&](auto n, auto k, auto htile, auto wtile) {
|
||||
out_hold(n, k, htile, wtile, 0, 0) =
|
||||
out_transform(n, k, htile, wtile, 0, 0) + out_transform(n, k, htile, wtile, 0, 1) +
|
||||
out_transform(n, k, htile, wtile, 0, 2) + out_transform(n, k, htile, wtile, 1, 0) +
|
||||
out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) +
|
||||
out_transform(n, k, htile, wtile, 2, 0) + out_transform(n, k, htile, wtile, 2, 1) +
|
||||
out_transform(n, k, htile, wtile, 2, 2);
|
||||
out_hold(n, k, htile, wtile, 0, 1) =
|
||||
out_transform(n, k, htile, wtile, 0, 1) - out_transform(n, k, htile, wtile, 0, 2) -
|
||||
out_transform(n, k, htile, wtile, 0, 3) + out_transform(n, k, htile, wtile, 1, 1) -
|
||||
out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) +
|
||||
out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) -
|
||||
out_transform(n, k, htile, wtile, 2, 3);
|
||||
out_hold(n, k, htile, wtile, 1, 0) =
|
||||
out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) +
|
||||
out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 2, 0) -
|
||||
out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) -
|
||||
out_transform(n, k, htile, wtile, 3, 0) - out_transform(n, k, htile, wtile, 3, 1) -
|
||||
out_transform(n, k, htile, wtile, 3, 2);
|
||||
out_hold(n, k, htile, wtile, 1, 1) =
|
||||
out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, htile, wtile, 1, 2) -
|
||||
out_transform(n, k, htile, wtile, 1, 3) - out_transform(n, k, htile, wtile, 2, 1) +
|
||||
out_transform(n, k, htile, wtile, 2, 2) + out_transform(n, k, htile, wtile, 2, 3) -
|
||||
out_transform(n, k, htile, wtile, 3, 1) + out_transform(n, k, htile, wtile, 3, 2) +
|
||||
out_transform(n, k, htile, wtile, 3, 3);
|
||||
};
|
||||
|
||||
auto f_out = [&](auto n, auto k, auto htile, auto wtile) {
|
||||
for(int j = 0; j < HoPerTile; ++j)
|
||||
{
|
||||
std::size_t ho = HoPerTile * htile + j;
|
||||
for(int i = 0; i < WoPerTile; ++i)
|
||||
{
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
make_ParallelTensorFunctor(f_in_hold, N, C, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_in_transform, N, C, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_wei_transform, K, C)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out_transform, N, K, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_direct_convolution_backward_data(Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
const Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& /* in_right_pads */,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
|
||||
std::size_t K = wei.mDesc.GetLengths()[I0];
|
||||
std::size_t Y = wei.mDesc.GetLengths()[I2];
|
||||
std::size_t X = wei.mDesc.GetLengths()[I3];
|
||||
|
||||
std::size_t Ho = out.mDesc.GetLengths()[I2];
|
||||
std::size_t Wo = out.mDesc.GetLengths()[I3];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0];
|
||||
|
||||
if(h_tmp % conv_strides[I0] == 0)
|
||||
{
|
||||
int ho = h_tmp / conv_strides[I0];
|
||||
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1];
|
||||
|
||||
if(w_tmp % conv_strides[I1] == 0)
|
||||
{
|
||||
int wo = w_tmp / conv_strides[I1];
|
||||
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += out(n, k, ho, wo) * wei(k, c, y, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
in(n, c, hi, wi) = v;
|
||||
};
|
||||
|
||||
auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) {
|
||||
std::size_t K = wei.mDesc.GetLengths()[I0];
|
||||
std::size_t Y = wei.mDesc.GetLengths()[I1];
|
||||
std::size_t X = wei.mDesc.GetLengths()[I2];
|
||||
|
||||
std::size_t Ho = out.mDesc.GetLengths()[I1];
|
||||
std::size_t Wo = out.mDesc.GetLengths()[I2];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0];
|
||||
|
||||
if(h_tmp % conv_strides[I0] == 0)
|
||||
{
|
||||
int ho = h_tmp / conv_strides[I0];
|
||||
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1];
|
||||
|
||||
if(w_tmp % conv_strides[I1] == 0)
|
||||
{
|
||||
int wo = w_tmp / conv_strides[I1];
|
||||
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += out(n, ho, wo, k) * wei(k, y, x, c);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
in(n, hi, wi, c) = v;
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
in.mDesc.GetLengths()[0],
|
||||
in.mDesc.GetLengths()[1],
|
||||
in.mDesc.GetLengths()[2],
|
||||
in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nhwc,
|
||||
in.mDesc.GetLengths()[0],
|
||||
in.mDesc.GetLengths()[1],
|
||||
in.mDesc.GetLengths()[2],
|
||||
in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
template <typename TOut,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_direct_convolution_backward_weights(
|
||||
const Tensor<TOut>& out,
|
||||
const Tensor<TIn>& in,
|
||||
Tensor<TWei>& wei,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
|
||||
double v = 0;
|
||||
for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
|
||||
{
|
||||
for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
v += static_cast<const double>(in(n, c, hi, wi)) *
|
||||
static_cast<const double>(out(n, k, ho, wo));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
wei(k, c, y, x) = v;
|
||||
};
|
||||
|
||||
auto f_kyxc = [&](auto k, auto y, auto x, auto c) {
|
||||
double v = 0;
|
||||
for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
|
||||
{
|
||||
for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[2])
|
||||
{
|
||||
v += static_cast<const double>(in(n, hi, wi, c)) *
|
||||
static_cast<const double>(out(n, ho, wo, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
wei(k, y, x, c) = v;
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_kcyx,
|
||||
wei.mDesc.GetLengths()[0],
|
||||
wei.mDesc.GetLengths()[1],
|
||||
wei.mDesc.GetLengths()[2],
|
||||
wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_kyxc,
|
||||
wei.mDesc.GetLengths()[0],
|
||||
wei.mDesc.GetLengths()[1],
|
||||
wei.mDesc.GetLengths()[2],
|
||||
wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
@@ -157,3 +157,26 @@ void host_gemm(const Tensor<AType>& a,
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename AType, typename BType, typename CType>
|
||||
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
|
||||
const Tensor<BType>& b_k_n,
|
||||
Tensor<CType>& c_m_n)
|
||||
{
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = a_m_k.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a_m_k(m, k)) * static_cast<const double>(b_k_n(k, n));
|
||||
}
|
||||
|
||||
c_m_n(m, n) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn,
|
||||
c_m_n.mDesc.GetLengths()[0],
|
||||
c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
@@ -120,6 +120,8 @@ struct HostTensorDescriptor
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
|
||||
|
||||
private:
|
||||
std::vector<std::size_t> mLens;
|
||||
std::vector<std::size_t> mStrides;
|
||||
@@ -224,7 +226,7 @@ struct Tensor
|
||||
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}
|
||||
|
||||
template <typename G>
|
||||
void GenerateTensorValue(G g, std::size_t num_thread = 1)
|
||||
void GenerateTensorValue(G g, std::size_t num_thread = std::thread::hardware_concurrency())
|
||||
{
|
||||
switch(mDesc.GetNumOfDimension())
|
||||
{
|
||||
|
||||
@@ -34,6 +34,21 @@ const std::vector<std::size_t>& HostTensorDescriptor::GetLengths() const { retur
|
||||
|
||||
const std::vector<std::size_t>& HostTensorDescriptor::GetStrides() const { return mStrides; }
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
|
||||
{
|
||||
os << "dim " << desc.GetNumOfDimension() << ", ";
|
||||
|
||||
os << "lengths {";
|
||||
LogRange(os, desc.GetLengths(), ", ");
|
||||
os << "}, ";
|
||||
|
||||
os << "strides {";
|
||||
LogRange(os, desc.GetStrides(), ", ");
|
||||
os << "}";
|
||||
|
||||
return os;
|
||||
}
|
||||
|
||||
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os)
|
||||
{
|
||||
os << "dim " << desc.GetNumOfDimension() << ", ";
|
||||
|
||||
50
profiler/CMakeLists.txt
Normal file
50
profiler/CMakeLists.txt
Normal file
@@ -0,0 +1,50 @@
|
||||
include_directories(BEFORE
|
||||
include
|
||||
${PROJECT_SOURCE_DIR}/host/host_tensor/include
|
||||
${PROJECT_SOURCE_DIR}/device/include
|
||||
${PROJECT_SOURCE_DIR}/device_operation/include
|
||||
${PROJECT_SOURCE_DIR}/profiler/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
|
||||
${PROJECT_SOURCE_DIR}/external/rocm/include
|
||||
)
|
||||
|
||||
# device_gemm_instance
|
||||
set(DEVICE_GEMM_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
|
||||
)
|
||||
|
||||
add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
|
||||
target_include_directories(device_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_compile_features(device_gemm_instance PUBLIC)
|
||||
set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
install(TARGETS device_gemm_instance LIBRARY DESTINATION lib)
|
||||
|
||||
# device_conv_instance
|
||||
set(DEVICE_CONV_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp;
|
||||
)
|
||||
|
||||
add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE})
|
||||
target_include_directories(device_conv_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_compile_features(device_conv_instance PUBLIC)
|
||||
set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
install(TARGETS device_conv_instance LIBRARY DESTINATION lib)
|
||||
|
||||
# ck_profiler
|
||||
set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp)
|
||||
add_executable(ckProfiler ${PROFILER_SOURCE})
|
||||
|
||||
target_link_libraries(ckProfiler PRIVATE host_tensor)
|
||||
target_link_libraries(ckProfiler PRIVATE device_gemm_instance device_conv_instance)
|
||||
81
profiler/README.md
Normal file
81
profiler/README.md
Normal file
@@ -0,0 +1,81 @@
|
||||
## Docker script
|
||||
```bash
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--privileged \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
|
||||
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
## Build ```ckProfiler```
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
```
|
||||
|
||||
```bash
|
||||
# Need to Specify target ID, example below is gfx908
|
||||
cmake \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
..
|
||||
```
|
||||
|
||||
```bash
|
||||
make -j ckProfiler
|
||||
```
|
||||
|
||||
## Profile GEMM kernels
|
||||
```bash
|
||||
#arg1: tensor operation (gemm=GEMM)
|
||||
#arg2: data type (0=fp32, 1=fp16)
|
||||
#arg3: matrix layout (0=NN, 1=NT, 2=TN, 3=TT)
|
||||
#arg4: verification (0=no, 1=yes)
|
||||
#arg5: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
#arg6: print matrix value (0=no, 1=yes)
|
||||
#arg7: run kernel # of times (>1)
|
||||
#arg8 to 13: M, N, K, StrideA, StrideB, StrideC
|
||||
|
||||
##################### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC
|
||||
./profiler/ckProfiler gemm 1 1 1 1 0 5 3840 4096 4096 4096 4096 4096
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
|
||||
```bash
|
||||
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
|
||||
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
....
|
||||
Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s
|
||||
```
|
||||
|
||||
## Profile forward convolution kernels
|
||||
```bash
|
||||
#arg1: tensor operation (conv=Convolution)
|
||||
#arg2: data type (0=fp32, 1=fp16)
|
||||
#arg3: input tensor layout (0=NCHW, 1=NHWC)
|
||||
#arg4: weight tensor layout (0=KCYX, 1=KYXC)
|
||||
#arg5: output tensor layout (0=NKHW, 1=NHWK)
|
||||
#arg6: verification (0=no, 1=yes)
|
||||
#arg7: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
#arg8: print matrix value (0=no, 1=yes)
|
||||
#arg9: run kernel # of times (>1)
|
||||
#arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
|
||||
##################### op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads
|
||||
./profiler/ckProfiler conv 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
|
||||
```
|
||||
in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
|
||||
wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192}
|
||||
out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
|
||||
....
|
||||
Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s
|
||||
```
|
||||
139
profiler/conv_profiler.cpp
Normal file
139
profiler/conv_profiler.cpp
Normal file
@@ -0,0 +1,139 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "profile_conv.hpp"
|
||||
|
||||
enum ConvDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
};
|
||||
|
||||
enum ConvInputLayout
|
||||
{
|
||||
NCHW, // 0
|
||||
NHWC, // 1
|
||||
};
|
||||
|
||||
enum ConvWeightLayout
|
||||
{
|
||||
KCYX, // 0
|
||||
KYXC, // 1
|
||||
};
|
||||
|
||||
enum ConvOutputLayout
|
||||
{
|
||||
NKHW, // 0
|
||||
NHWK, // 1
|
||||
};
|
||||
|
||||
int conv_profiler(int argc, char* argv[])
|
||||
{
|
||||
if(argc != 25)
|
||||
{
|
||||
printf("arg1: tensor operation (conv=Convolution)\n");
|
||||
printf("arg2: data type (0=fp32, 1=fp16)\n");
|
||||
printf("arg3: input tensor layout (0=NCHW, 1=NHWC)\n");
|
||||
printf("arg4: weight tensor layout (0=KCYX, 1=KYXC)\n");
|
||||
printf("arg5: output tensor layout (0=NKHW, 1=NHWK)\n");
|
||||
printf("arg6: verification (0=no, 1=yes)\n");
|
||||
printf("arg7: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg8: print matrix value (0=no, 1=yes)\n");
|
||||
printf("arg9: run kernel # of times (>1)\n");
|
||||
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
|
||||
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
|
||||
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
|
||||
const bool do_verification = std::stoi(argv[6]);
|
||||
const int init_method = std::stoi(argv[7]);
|
||||
const bool do_log = std::stoi(argv[8]);
|
||||
const int nrepeat = std::stoi(argv[9]);
|
||||
|
||||
const ck::index_t N = std::stoi(argv[10]);
|
||||
const ck::index_t K = std::stoi(argv[11]);
|
||||
const ck::index_t C = std::stoi(argv[12]);
|
||||
const ck::index_t Y = std::stoi(argv[13]);
|
||||
const ck::index_t X = std::stoi(argv[14]);
|
||||
const ck::index_t Hi = std::stoi(argv[15]);
|
||||
const ck::index_t Wi = std::stoi(argv[16]);
|
||||
|
||||
const ck::index_t conv_stride_h = std::stoi(argv[17]);
|
||||
const ck::index_t conv_stride_w = std::stoi(argv[18]);
|
||||
const ck::index_t conv_dilation_h = std::stoi(argv[19]);
|
||||
const ck::index_t conv_dilation_w = std::stoi(argv[20]);
|
||||
const ck::index_t in_left_pad_h = std::stoi(argv[21]);
|
||||
const ck::index_t in_left_pad_w = std::stoi(argv[22]);
|
||||
const ck::index_t in_right_pad_h = std::stoi(argv[23]);
|
||||
const ck::index_t in_right_pad_w = std::stoi(argv[24]);
|
||||
|
||||
const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
|
||||
if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC &&
|
||||
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
|
||||
{
|
||||
ck::profiler::profile_conv<2,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
std::vector<ck::index_t>{Hi, Wi},
|
||||
std::vector<ck::index_t>{Y, X},
|
||||
std::vector<ck::index_t>{Ho, Wo},
|
||||
std::vector<ck::index_t>{conv_stride_h, conv_stride_w},
|
||||
std::vector<ck::index_t>{conv_dilation_h, conv_dilation_w},
|
||||
std::vector<ck::index_t>{in_left_pad_h, in_left_pad_w},
|
||||
std::vector<ck::index_t>{in_right_pad_h, in_right_pad_w});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
|
||||
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
|
||||
{
|
||||
ck::profiler::profile_conv<2,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
std::vector<ck::index_t>{Hi, Wi},
|
||||
std::vector<ck::index_t>{Y, X},
|
||||
std::vector<ck::index_t>{Ho, Wo},
|
||||
std::vector<ck::index_t>{conv_stride_h, conv_stride_w},
|
||||
std::vector<ck::index_t>{conv_dilation_h, conv_dilation_w},
|
||||
std::vector<ck::index_t>{in_left_pad_h, in_left_pad_w},
|
||||
std::vector<ck::index_t>{in_right_pad_h, in_right_pad_w});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! this Conv data_type & layout is not implemented");
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
135
profiler/gemm_profiler.cpp
Normal file
135
profiler/gemm_profiler.cpp
Normal file
@@ -0,0 +1,135 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
#include "host_gemm.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "profile_gemm.hpp"
|
||||
|
||||
int gemm_profiler(int argc, char* argv[])
|
||||
{
|
||||
if(argc != 14)
|
||||
{
|
||||
printf("arg1: tensor operation (gemm=GEMM)\n");
|
||||
printf("arg2: data type (0=fp32, 1=fp16)\n");
|
||||
printf("arg3: matrix layout (0=NN, 1=NT, 2=TN, 3=TT)\n");
|
||||
printf("arg4: verification (0=no, 1=yes)\n");
|
||||
printf("arg5: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg6: print matrix value (0=no, 1=yes)\n");
|
||||
printf("arg7: run kernel # of times (>1)\n");
|
||||
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const int data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
|
||||
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
|
||||
const bool do_verification = std::stoi(argv[4]);
|
||||
const int init_method = std::stoi(argv[5]);
|
||||
const bool do_log = std::stoi(argv[6]);
|
||||
const int nrepeat = std::stoi(argv[7]);
|
||||
|
||||
const int M = std::stoi(argv[8]);
|
||||
const int N = std::stoi(argv[9]);
|
||||
const int K = std::stoi(argv[10]);
|
||||
|
||||
const int StrideA = std::stoi(argv[11]);
|
||||
const int StrideB = std::stoi(argv[12]);
|
||||
const int StrideC = std::stoi(argv[13]);
|
||||
|
||||
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_gemm<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_gemm<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_gemm<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_gemm<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_gemm<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_gemm<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_gemm<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_gemm<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification, init_method, do_log, nrepeat, M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
229
profiler/include/profile_conv.hpp
Normal file
229
profiler/include/profile_conv.hpp
Normal file
@@ -0,0 +1,229 @@
|
||||
#pragma once
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "host_conv.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_conv.hpp"
|
||||
#include "device_conv_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv_instance {
|
||||
|
||||
template <>
|
||||
void add_device_conv_fwd_instance<2,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_conv_fwd_instance<2,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK>(
|
||||
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr>&);
|
||||
|
||||
} // namespace device_conv_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <int NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
void profile_conv(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
int nrepeat,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
const ck::index_t Y = filter_spatial_lengths[0];
|
||||
const ck::index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const ck::index_t Hi = input_spatial_lengths[0];
|
||||
const ck::index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const ck::index_t Ho = output_spatial_lengths[0];
|
||||
const ck::index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) {
|
||||
if constexpr(is_same<decltype(layout), ck::tensor_layout::convolution::NCHW>::value ||
|
||||
is_same<decltype(layout), ck::tensor_layout::convolution::KCYX>::value ||
|
||||
is_same<decltype(layout), ck::tensor_layout::convolution::NKHW>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
|
||||
std::vector<std::size_t>({C_ * H * W, H * W, W, 1}));
|
||||
}
|
||||
else if constexpr(is_same<decltype(layout), tensor_layout::convolution::NHWC>::value ||
|
||||
is_same<decltype(layout), tensor_layout::convolution::KYXC>::value ||
|
||||
is_same<decltype(layout), tensor_layout::convolution::NHWK>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
|
||||
std::vector<std::size_t>({C_ * H * W, 1, W * C_, C_}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
|
||||
Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}));
|
||||
Tensor<OutDataType> out_n_k_ho_wo_host_result(
|
||||
f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
|
||||
Tensor<OutDataType> out_n_k_ho_wo_device_result(
|
||||
f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
|
||||
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
|
||||
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5});
|
||||
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5});
|
||||
break;
|
||||
default:
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0});
|
||||
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_conv_nchw_kcyx_nkhw(in_n_c_hi_wi,
|
||||
wei_k_c_y_x,
|
||||
out_n_k_ho_wo_host_result,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) *
|
||||
out_n_k_ho_wo_device_result.mDesc.GetElementSpace());
|
||||
|
||||
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
|
||||
// add device Conv instances
|
||||
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr> conv_ptrs;
|
||||
|
||||
ck::tensor_operation::device::device_conv_instance::add_device_conv_fwd_instance<2,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
conv_ptrs);
|
||||
|
||||
if(conv_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device Conv instance found");
|
||||
}
|
||||
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device Conv instances
|
||||
for(auto& conv_ptr : conv_ptrs)
|
||||
{
|
||||
auto argument_ptr = conv_ptr->MakeArgumentPointer(
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
auto invoker_ptr = conv_ptr->MakeInvokerPointer();
|
||||
|
||||
if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
|
||||
|
||||
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
|
||||
|
||||
std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
|
||||
sizeof(WeiDataType) * (K * C * Y * X) +
|
||||
sizeof(OutDataType) * (N * K * Ho * Wo);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s" << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
|
||||
|
||||
check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "wei: ", wei_k_c_y_x.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s" << std::endl;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
229
profiler/include/profile_gemm.hpp
Normal file
229
profiler/include/profile_gemm.hpp
Normal file
@@ -0,0 +1,229 @@
|
||||
#pragma once
|
||||
#include "device_gemm_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
|
||||
|
||||
template <>
|
||||
void add_device_gemm_instance<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmPtr>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
void profile_gemm(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
int nrepeat,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideC)
|
||||
{
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result);
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
// add device GEMM instances
|
||||
std::vector<ck::tensor_operation::device::DeviceGemmPtr> gemm_ptrs;
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
add_device_gemm_instance<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
gemm_ptrs);
|
||||
|
||||
if(gemm_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device GEMM instance found");
|
||||
}
|
||||
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& gemm_ptr : gemm_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC);
|
||||
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s" << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
check_error(c_m_n_host_result, c_m_n_device_result);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "this device GEMM instance does not support this GEMM problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s" << std::endl;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
26
profiler/profiler.cpp
Normal file
26
profiler/profiler.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
int gemm_profiler(int, char*[]);
|
||||
int conv_profiler(int, char*[]);
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(strcmp(argv[1], "gemm") == 0)
|
||||
{
|
||||
return gemm_profiler(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv") == 0)
|
||||
{
|
||||
return conv_profiler(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: tensor operation (gemm=GEMM, conv=Convolution)\n");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
@@ -8,11 +8,11 @@ MY_PROJECT_INSTALL=../install.dir
|
||||
|
||||
cmake \
|
||||
-D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \
|
||||
-D HALF_INCLUDE_DIR="/root/workspace/external/half/include" \
|
||||
-D BUILD_DEV=ON \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
|
||||
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 -ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
${MY_PROJECT_SOURCE}
|
||||
|
||||
|
||||
71
script/conv_driver.sh
Executable file
71
script/conv_driver.sh
Executable file
@@ -0,0 +1,71 @@
|
||||
#!/bin/bash
|
||||
|
||||
## GPU visibility
|
||||
export HIP_VISIBLE_DEVICES=0
|
||||
|
||||
make -j conv_fwd_driver_offline
|
||||
#make -j conv_bwd_driver_offline
|
||||
#make -j conv_wrw_driver_offline
|
||||
|
||||
DRIVER="./host/driver_offline/conv_fwd_driver_offline"
|
||||
#DRIVER="./host/driver_offline/conv_bwd_driver_offline"
|
||||
#DRIVER="./host/driver_offline/conv_wrw_driver_offline"
|
||||
|
||||
LAYOUT=$1
|
||||
ALGO=$2
|
||||
VERIFY=$3
|
||||
INIT=$4
|
||||
LOG=$5
|
||||
REPEAT=$6
|
||||
|
||||
DESIRED_GRID_SIZE=$7
|
||||
|
||||
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 $DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
|
||||
$DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
$DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
$DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
$DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 32 256 3 3 1 1 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 32 256 1 1 1 1 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
$DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 2 2 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 1 1 2 2 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
|
||||
# Resnet50
|
||||
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE
|
||||
##DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
|
||||
20
script/example_gemm_xdl.sh
Executable file
20
script/example_gemm_xdl.sh
Executable file
@@ -0,0 +1,20 @@
|
||||
#!/bin/bash
|
||||
|
||||
## GPU visibility
|
||||
export HIP_VISIBLE_DEVICES=1
|
||||
|
||||
make -j gemm_xdl
|
||||
|
||||
DRIVER="./example/gemm_xdl"
|
||||
|
||||
VERIFY=$1
|
||||
INIT=$2
|
||||
LOG=$3
|
||||
REPEAT=$4
|
||||
|
||||
######### verify init log repeat M___ N___ K___ StrideA StrideB StrideC
|
||||
#$DRIVER $VERIFY $INIT $LOG $REPEAT 960 1024 1024 1024 1024 1024
|
||||
#$DRIVER $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
|
||||
#$DRIVER $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 2048 2048 2048
|
||||
$DRIVER $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 4096 4096 4096
|
||||
#$DRIVER $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 8192 8192 8192
|
||||
25
script/gemm_driver.sh
Executable file
25
script/gemm_driver.sh
Executable file
@@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
|
||||
## GPU visibility
|
||||
export HIP_VISIBLE_DEVICES=0
|
||||
|
||||
make -j gemm_driver_offline
|
||||
|
||||
DRIVER="./host/driver_offline/gemm_driver_offline"
|
||||
|
||||
LAYOUT=$1
|
||||
ALGO=$2
|
||||
VERIFY=$3
|
||||
INIT=$4
|
||||
LOG=$5
|
||||
REPEAT=$6
|
||||
|
||||
M01=$7
|
||||
N01=$8
|
||||
|
||||
######### layout algo verify init log repeat M___ N___ K___ M01_ N01_
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 $M01 $N01
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01
|
||||
137
script/run.sh
137
script/run.sh
@@ -1,137 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
## GPU visibility
|
||||
export ROCR_VISIBLE_DEVICE=0
|
||||
export GPU_DEVICE_ORDINAL=0
|
||||
|
||||
make -j conv_fwd_driver_offline
|
||||
#make -j conv_bwd_driver_offline
|
||||
#make -j conv_wrw_driver_offline
|
||||
#make -j gemm_driver_offline
|
||||
|
||||
DRIVER="./host/driver_offline/conv_fwd_driver_offline"
|
||||
LAYOUT=$1
|
||||
ALGO=$2
|
||||
VERIFY=$3
|
||||
INIT=$4
|
||||
LOG=$5
|
||||
REPEAT=$6
|
||||
|
||||
#M01=$7
|
||||
#N01=$8
|
||||
|
||||
KBATCH=$7
|
||||
|
||||
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
|
||||
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
|
||||
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
|
||||
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
|
||||
######### layout algo verify init log repeat M___ N___ K___
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01
|
||||
|
||||
# Resnet50
|
||||
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1
|
||||
|
||||
# 256x128x32 c64
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH
|
||||
|
||||
|
||||
|
||||
# 128x128x32 c64
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 448
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 28
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 224
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH
|
||||
|
||||
|
||||
# 128x64x32 c64
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112
|
||||
|
||||
# 64x128x32 c64
|
||||
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH
|
||||
|
||||
# 64x64x32 c32
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 448
|
||||
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 448
|
||||
Reference in New Issue
Block a user