From 033a4d6cf3d7b4a62c7e215260af831b3a5e19a3 Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Sat, 19 Jun 2021 13:43:45 -0500
Subject: [PATCH] pass-by-void-pointer for gridwise_dynamic_gemm_v1r2 (#38)

* pass-by-void-pointer for gridwise_dynamic_gemm_v1r2

* use pass-by-value by default

[ROCm/composable_kernel commit: d2315b0dfcd6f31cca4328819eaf60d77e952dd6]
---
 .../driver/driver_dynamic_gemm_v1r2.hpp      | 131 ++++++++++++++++++
 .../gridwise_dynamic_gemm_v1r2.hpp           |  58 ++++++++
 .../include/utility/config.amd.hpp.in        |   4 +-
 driver/src/conv_driver_v2.cpp                |   4 +-
 4 files changed, 193 insertions(+), 4 deletions(-)

diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
index 527360d6b2..9c63e44961 100644
--- a/composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
+++ b/composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
@@ -167,6 +167,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                   << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
     }
 
+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
     float ave_time = 0;
 
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
@@ -279,6 +280,136 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
     }
 
     return ave_time;
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+    DeviceMem a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc));
+    DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc));
+    DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
+    DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
+        sizeof(CBlockIdToM0N0BlockClusterAdaptor));
+
+    a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc);
+    b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc);
+    c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
+    c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
+        &c_blockid_to_m0_n0_block_cluster_adaptor);
+
+    float ave_time = 0;
+
+    if(has_main_k_block_loop && has_double_tail_k_block_loop)
+    {
+        const auto kernel =
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
+                                      FloatAB,
+                                      FloatC,
+                                      remove_reference_t<AKM0M1GridDesc>,
+                                      remove_reference_t<BKN0N1GridDesc>,
+                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
+                                      remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
+                                      true,
+                                      true>;
+
+        ave_time = launch_and_time_kernel(
+            kernel,
+            nrepeat,
+            dim3(grid_size),
+            dim3(BlockSize),
+            0,
+            0,
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            (void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
+    }
+    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
+    {
+        const auto kernel =
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
+                                      FloatAB,
+                                      FloatC,
+                                      remove_reference_t<AKM0M1GridDesc>,
+                                      remove_reference_t<BKN0N1GridDesc>,
+                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
+                                      remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
+                                      true,
+                                      false>;
+
+        ave_time = launch_and_time_kernel(
+            kernel,
+            nrepeat,
+            dim3(grid_size),
+            dim3(BlockSize),
+            0,
+            0,
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            (void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
+    }
+    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
+    {
+        const auto kernel =
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
+                                      FloatAB,
+                                      FloatC,
+                                      remove_reference_t<AKM0M1GridDesc>,
+                                      remove_reference_t<BKN0N1GridDesc>,
+                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
+                                      remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
+                                      false,
+                                      true>;
+
+        ave_time = launch_and_time_kernel(
+            kernel,
+            nrepeat,
+            dim3(grid_size),
+            dim3(BlockSize),
+            0,
+            0,
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            (void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
+    }
+    else
+    {
+        const auto kernel =
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
+                                      FloatAB,
+                                      FloatC,
+                                      remove_reference_t<AKM0M1GridDesc>,
+                                      remove_reference_t<BKN0N1GridDesc>,
+                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
+                                      remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
+                                      false,
+                                      false>;
+
+        ave_time = launch_and_time_kernel(
+            kernel,
+            nrepeat,
+            dim3(grid_size),
+            dim3(BlockSize),
+            0,
+            0,
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            (void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
+    }
+
+    return ave_time;
+#endif
 }
 
 } // namespace ck
diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
index 525f1bcf25..697d5db972 100644
--- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
+++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
@@ -12,6 +12,7 @@
 
 namespace ck {
 
+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
@@ ... @@
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+// pass tensor descriptors by __CONSTANT__ void pointer
+// __CONSTANT__ is needed to tell the compiler that the void pointers in the kernel signature
+// point to the non-modifiable parameter address space, so it can enable the corresponding optimization
+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatC,
+          typename AKM0M1GridDesc,
+          typename BKN0N1GridDesc,
+          typename CM0M10M11N0N10N11GridDesc,
+          typename CBlockIdToM0N0BlockClusterAdaptor,
+          bool HasMainKBlockLoop,
+          bool HasDoubleTailKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_v1r2(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatC* __restrict__ p_c_grid,
+        const void __CONSTANT__* p_a_k_m0_m1_grid_desc,
+        const void __CONSTANT__* p_b_k_n0_n1_grid_desc,
+        const void __CONSTANT__* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
+        const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor)
+{
+    // first cast the __CONSTANT__ void pointer to a generic void pointer,
+    // then cast the generic void pointer to the descriptor type;
+    // the tensor descriptor's copy constructor doesn't accept address_space(4) pointers
+    const auto a_k_m0_m1_grid_desc =
+        *reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
+    const auto b_k_n0_n1_grid_desc =
+        *reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
+    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
+        *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
+            (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
+    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
+        *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
+            (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
+
+    constexpr index_t shared_block_size =
+        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
+
+    __shared__ FloatAB p_shared_block[shared_block_size];
+
+    GridwiseGemm::Run(p_a_grid,
+                      p_b_grid,
+                      p_c_grid,
+                      p_shared_block,
+                      a_k_m0_m1_grid_desc,
+                      b_k_n0_n1_grid_desc,
+                      c_m0_m10_m11_n0_n10_n11_grid_desc,
+                      c_blockid_to_m0_n0_block_cluster_adaptor,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
+}
+#endif
 
 template
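
For readers unfamiliar with the pattern, the following is a minimal, self-contained sketch of the same mechanism, assuming a HIP toolchain: the host copies a trivially copyable descriptor struct into a device buffer, passes that buffer as an opaque pointer, and the kernel reinterpret_casts it back. SimpleDesc, fill_kernel, and the plain hipMalloc/hipMemcpy calls below are illustrative stand-ins for CK's tensor descriptors, DeviceMem, and launch_and_time_kernel, and the __CONSTANT__ address-space qualifier used by the patch is omitted for brevity; this is not the CK implementation itself.

#include <hip/hip_runtime.h>
#include <cstdio>

// Stand-in for a tensor descriptor: any trivially copyable struct works,
// since only its raw bytes are copied into device memory.
struct SimpleDesc
{
    int lengths[2];
    int strides[2];
};

__global__ void fill_kernel(float* p_out, const void* p_desc_buf)
{
    // Reconstruct the descriptor from the opaque pointer. The patch does the
    // same thing, except its pointer is additionally qualified __CONSTANT__
    // and is first cast back to a generic void* before the reinterpret_cast.
    const auto desc = *reinterpret_cast<const SimpleDesc*>(p_desc_buf);

    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const int n = desc.lengths[0] * desc.lengths[1];

    if(i < n)
        p_out[i] = 1.0f;
}

int main()
{
    SimpleDesc desc{{4, 8}, {8, 1}};

    float* p_out     = nullptr;
    void* p_desc_dev = nullptr;

    hipMalloc(&p_out, sizeof(float) * 32);
    hipMalloc(&p_desc_dev, sizeof(SimpleDesc));

    // Same role as DeviceMem::ToDevice() in the patch: move the descriptor
    // bytes into device memory before the launch.
    hipMemcpy(p_desc_dev, &desc, sizeof(SimpleDesc), hipMemcpyHostToDevice);

    fill_kernel<<<dim3(1), dim3(32), 0, 0>>>(p_out, p_desc_dev);
    hipDeviceSynchronize();

    float out[32];
    hipMemcpy(out, p_out, sizeof(out), hipMemcpyDeviceToHost);
    printf("out[0] = %f\n", out[0]);

    hipFree(p_desc_dev);
    hipFree(p_out);

    return 0;
}

The design trade-off, as the commit message suggests, is that the void-pointer path keeps the kernel argument a fixed-size pointer no matter how large the descriptor type grows, at the cost of an extra host-to-device copy and a dereference in the kernel, while the default path passes the descriptors by value.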