diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp deleted file mode 100644 index 9c63e44961..0000000000 --- a/composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp +++ /dev/null @@ -1,416 +0,0 @@ -#ifndef CK_DRIVER_DYNAMIC_GEMM_V1R2 -#define CK_DRIVER_DYNAMIC_GEMM_V1R2 - -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_v1r2.hpp" - -namespace ck { - -template -__host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AKMGridDesc& a_k_m_grid_desc, - const BKNGridDesc& b_k_n_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, - index_t nrepeat) - -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - // GEMM - using GridwiseGemm = - GridwiseDynamicGemm_km_kn_mn_v1r2; - - const auto M = a_k_m_grid_desc.GetLength(I1); - const auto N = b_k_n_grid_desc.GetLength(I1); - const auto K = a_k_m_grid_desc.GetLength(I0); - - if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc)) - { - throw std::runtime_error("wrong! 
GridwiseDynamicGemm_km_kn_mn_v1r2 has invalid setting"); - } - - const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); - const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); - - using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc); - using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc); - - // c_m0_m10_m11_n0_n10_n11_grid_desc - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - - using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); - - // c_blockid_to_m0_n0_block_cluster_adaptor - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); - - const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K); - - { - std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", " - << a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", " - << b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; - } - -#if 
CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else - { - const auto kernel = - kernel_dynamic_gemm_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - - return ave_time; -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - 
DeviceMem a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc)); - DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc)); - DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); - DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( - sizeof(CBlockIdToM0N0BlockClusterAdaptor)); - - a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc); - b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc); - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( - &c_blockid_to_m0_n0_block_cluster_adaptor); - - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - 
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - else - { - const auto kernel = - kernel_dynamic_gemm_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - - return ave_time; -#endif -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_v1r3.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_v1r3.hpp deleted file mode 100644 index be65fd34c8..0000000000 --- a/composable_kernel/include/driver/driver_dynamic_gemm_v1r3.hpp +++ /dev/null @@ -1,416 +0,0 @@ -#ifndef CK_DRIVER_DYNAMIC_GEMM_v1r3 -#define CK_DRIVER_DYNAMIC_GEMM_v1r3 - -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_v1r3.hpp" - -namespace ck { - -template -__host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* 
p_c_grid, - const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, - index_t nrepeat) - -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - // GEMM - using GridwiseGemm = - GridwiseDynamicGemm_km_kn_mn_v1r3; - - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) - { - throw std::runtime_error("wrong! GridwiseDynamicGemm_km_kn_mn_v1r3 has invalid setting"); - } - - const auto a_k0_m0_m1_k1_grid_desc = - GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc); - const auto b_k0_n0_n1_k1_grid_desc = - GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc); - - using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc); - using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc); - - // c_m0_m10_m11_n0_n10_n11_grid_desc - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - - using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); - - // c_blockid_to_m0_n0_block_cluster_adaptor - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); - - const bool has_double_tail_k_block_loop = 
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); - - { - std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", " - << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", " - << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", " - << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", " - << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", " - << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", " - << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; - } - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - 
c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else - { - const auto kernel = - kernel_dynamic_gemm_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - - return ave_time; -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc)); - DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc)); - DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); - DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( - sizeof(CBlockIdToM0N0BlockClusterAdaptor)); - - a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc); - b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc); - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( - &c_blockid_to_m0_n0_block_cluster_adaptor); - - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = 
launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - else - { - const auto kernel = - kernel_dynamic_gemm_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel( - kernel, - 
nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - - return ave_time; -#endif -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp index 957ca02723..e709f768cb 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp @@ -7,9 +7,12 @@ namespace ck { -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X +// GemmM0 = 1 +// GemmM1 = K +// GemmN0 = N0 +// GemmN1 = (N / N0) * Ho * Wo +// GemmK0 = (C / C0) * Y * X +// GemmK1 = C0 template ) { static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && @@ -136,7 +136,7 @@ struct DynamicPad __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, const UpIdxDiff& idx_diff_up, LowIdx& idx_low, - const UpIdx& idx_up_new, + const UpIdx&, Number) { static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && @@ -227,7 +227,7 @@ struct DynamicLeftPad __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, const UpIdxDiff& idx_diff_up, LowIdx& idx_low, - const UpIdx& idx_up_new, + const UpIdx&, Number) { static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && @@ -318,7 +318,7 @@ struct DynamicRightPad __host__ __device__ static void 
UpdateLowerIndex(LowIdxDiff& idx_diff_low, const UpIdxDiff& idx_diff_up, LowIdx& idx_low, - const UpIdx& idx_up_new, + const UpIdx&, Number) { static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && @@ -420,7 +420,7 @@ struct DynamicEmbed __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, const UpIdxDiff& idx_diff_up, LowIdx& idx_low, - const UpIdx& idx_up_new, + const UpIdx&, Number) const { static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp && @@ -1096,7 +1096,7 @@ struct DynamicMerge_v2_magic_division typename UpIdx, index_t Hack> __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, - const UpIdxDiff& idx_diff_up, + const UpIdxDiff&, LowIdx& idx_low, const UpIdx& idx_up_new, Number) const @@ -1254,7 +1254,7 @@ struct DynamicMerge_v2r2_magic_division typename UpIdx, index_t Hack> __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, - const UpIdxDiff& idx_diff_up, + const UpIdxDiff&, LowIdx& idx_low, const UpIdx& idx_up_new, Number) const @@ -1383,7 +1383,7 @@ struct DynamicUnMerge __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, const UpIdxDiff& idx_diff_up, LowIdx& idx_low, - const UpIdx& idx_up_new, + const UpIdx&, Number) const { CalculateLowerIndex(idx_diff_low, idx_diff_up); @@ -1597,7 +1597,7 @@ struct DynamicVectorize __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, const UpIdxDiff& idx_diff_up, LowIdx& idx_low, - const UpIdx& idx_up_new, + const UpIdx&, Number) const { static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && @@ -1654,7 +1654,7 @@ struct DynamicSlice __host__ __device__ constexpr DynamicSlice() = default; - __host__ __device__ constexpr DynamicSlice(const LowLength& low_length, + __host__ __device__ constexpr DynamicSlice(const LowLength&, const SliceBegin& slice_begin, const SliceEnd& slice_end) : up_lengths_{make_tuple(slice_end - slice_begin)}, @@ -1687,7 +1687,7 @@ 
struct DynamicSlice __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, const UpIdxDiff& idx_diff_up, LowIdx& idx_low, - const UpIdx& idx_up_new, + const UpIdx&, Number) { static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && @@ -1709,8 +1709,7 @@ struct DynamicSlice } template - __host__ __device__ constexpr bool - IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const { return true; } diff --git a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp index 9d809576b8..ebb970f481 100644 --- a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp @@ -317,7 +317,7 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, NewUpperDimensionNewVisibleIdss{}); static_assert(is_valid_sequence_map::value && - is_valid_sequence_map::value, + is_valid_sequence_map::value, "wrong!"); } @@ -395,7 +395,6 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); - constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension(); constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); MultiIndex idx_hidden; @@ -491,11 +490,8 @@ template ; + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); // this is what needs to be calculated auto idx_diff_hidden = make_zero_multi_index(); diff --git a/composable_kernel/include/tensor_description/tensor_adaptor.hpp b/composable_kernel/include/tensor_description/tensor_adaptor.hpp index 
8336fea2ae..5e8f898f26 100644 --- a/composable_kernel/include/tensor_description/tensor_adaptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_adaptor.hpp @@ -236,15 +236,15 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a // shift constexpr index_t adaptor0_max_hidden_id = [&]() { - index_t adaptor0_max_hidden_id = NumericLimits::Min(); + index_t adaptor0_max_hidden_id_ = NumericLimits::Min(); static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) { constexpr index_t ndim_low = TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension(); static_for<0, ndim_low, 1>{}([&](auto idim_low) { - adaptor0_max_hidden_id = - math::max(adaptor0_max_hidden_id, + adaptor0_max_hidden_id_ = + math::max(adaptor0_max_hidden_id_, TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value); }); @@ -252,17 +252,17 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension(); static_for<0, ndim_up, 1>{}([&](auto idim_up) { - adaptor0_max_hidden_id = - math::max(adaptor0_max_hidden_id, + adaptor0_max_hidden_id_ = + math::max(adaptor0_max_hidden_id_, TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value); }); }); - return adaptor0_max_hidden_id; + return adaptor0_max_hidden_id_; }(); constexpr index_t adaptor1_min_hidden_id = [&]() { - index_t adaptor1_min_hidden_id = NumericLimits::Max(); + index_t adaptor1_min_hidden_id_ = NumericLimits::Max(); static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) { constexpr index_t ndim_low = @@ -285,7 +285,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a if(!is_bottom_dim) { - adaptor1_min_hidden_id = math::min(adaptor1_min_hidden_id, low_dim_hidden_id); + adaptor1_min_hidden_id_ = math::min(adaptor1_min_hidden_id_, low_dim_hidden_id); } }); @@ -294,13 +294,13 @@ __host__ __device__ constexpr auto 
chain_tensor_adaptors(const TensorAdaptor0& a // get the min of all upper dimensions static_for<0, ndim_up, 1>{}([&](auto idim_up) { - adaptor1_min_hidden_id = - math::min(adaptor1_min_hidden_id, + adaptor1_min_hidden_id_ = + math::min(adaptor1_min_hidden_id_, TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value); }); }); - return adaptor1_min_hidden_id; + return adaptor1_min_hidden_id_; }(); constexpr index_t adaptor1_hidden_id_shift = @@ -321,11 +321,11 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a // sequence in, sequence out constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr { - auto low_dim_hidden_ids_1_mod = to_multi_index(low_dim_hidden_ids_1); + auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1); // shift hidden id so every dim id is unique static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { - low_dim_hidden_ids_1_mod(idim_low_1) += adaptor1_hidden_id_shift; + low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift; }); // match hidden id @@ -335,13 +335,13 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a if constexpr(low_dim_hidden_ids_1[idim_low_1] == TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1]) { - low_dim_hidden_ids_1_mod(idim_low_1) = + low_dim_hidden_ids_1_mod_(idim_low_1) = TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1]; } }); }); - return low_dim_hidden_ids_1_mod; + return low_dim_hidden_ids_1_mod_; } (); @@ -367,14 +367,14 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a // sequence in, constexpr tuple out constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr { - auto up_dim_hidden_ids_1_mod = to_multi_index(up_dim_hidden_ids_1); + auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1); // shift hidden id static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) { - up_dim_hidden_ids_1_mod(idim_up_1) += adaptor1_hidden_id_shift; + 
up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift; }); - return up_dim_hidden_ids_1_mod; + return up_dim_hidden_ids_1_mod_; } (); diff --git a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp index 36de380719..694b2fd2cc 100644 --- a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp @@ -14,7 +14,7 @@ namespace ck { // 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate template ::type = false> -struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2 +struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 { using AIndex = MultiIndex<3>; using BIndex = MultiIndex<3>; @@ -140,7 +140,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2 static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{}); public: - __device__ BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2() + __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2() : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock( get_thread_local_1d_id())}, a_thread_copy_{ @@ -183,7 +183,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2 constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); - return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0)); + return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); } __host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; } @@ -207,21 +207,21 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2 CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0, "wrong"); - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = 
make_static_buffer( a_k_m0_m1_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_k_n0_n1_thread_desc_.GetElementSpaceSize()); constexpr auto threadwise_gemm = - ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1, - Sequence<1, M1PerThreadM11>, - Sequence<1, N1PerThreadN11>>{}; + ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1, + Sequence<1, M1PerThreadM11>, + Sequence<1, N1PerThreadN11>>{}; // read A_sub_0 a_thread_copy_.Run(a_k_m0_m1_block_desc_, diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v2r3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp similarity index 90% rename from composable_kernel/include/tensor_operation/blockwise_gemm_v2r3.hpp rename to composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp index e3ba21494a..6a3885936e 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp @@ -1,10 +1,10 @@ -#ifndef CK_BLOCKWISE_GEMM_V2R3_HPP -#define CK_BLOCKWISE_GEMM_V2R3_HPP +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP #include "common_header.hpp" #include "tensor_adaptor.hpp" #include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" -#include "threadwise_contraction.hpp" +#include "threadwise_contraction_dlops.hpp" namespace ck { @@ -21,6 +21,7 @@ namespace ck { // 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time // 2. CThreadBuffer is StaticBuffer // Also assume: +// BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2 // BM0 = BN0 = 2. 
It will do 2x2 pipelined read and fma (ABBA optimization) template + typename BM10BN10ThreadClusterBN10Xs, // Sequence index_t AThreadCopyScalarPerVector_BM11, index_t BThreadCopyScalarPerVector_BN11, typename std::enable_if::type = false> -struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 +struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 { using AIndex = MultiIndex<3>; using BIndex = MultiIndex<3>; @@ -56,19 +57,17 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_ static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1); static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1); - static constexpr index_t BM100 = BM10BN10ThreadClusterBM100; - static constexpr index_t BN100 = BM10BN10ThreadClusterBN100; + static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0]; + static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0]; - static constexpr index_t BM101 = BM10BN10ThreadClusterBM101; - static constexpr index_t BN101 = BM10BN10ThreadClusterBN101; + static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1]; + static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1]; static constexpr index_t BM11 = BM1PerThreadBM11; static constexpr index_t BN11 = BN1PerThreadBN11; - static constexpr index_t BM1 = - BM10BN10ThreadClusterBM100 * BM10BN10ThreadClusterBM101 * BM1PerThreadBM11; - static constexpr index_t BN1 = - BM10BN10ThreadClusterBN100 * BM10BN10ThreadClusterBN101 * BN1PerThreadBN11; + static constexpr index_t BM1 = BM100 * BM101 * BM11; + static constexpr index_t BN1 = BN100 * BN101 * BN11; static constexpr index_t BM0 = BM / BM1; static constexpr index_t BN0 = BN / BN1; @@ -149,7 +148,7 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_ MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); public: - __device__ 
BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() + __device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( get_thread_local_1d_id())}, a_thread_copy_{ @@ -170,6 +169,11 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_ BBlockDesc_BK0_BN_BK1{}.GetLength(I0), "wrong! K dimension not consistent"); + // TODO remove this restriction + static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 && + BM10BN10ThreadClusterBN10Xs::Size() == 2, + "wrong!"); + // TODO: remove this restriction static_assert(BM0 == 2 && BN0 == 2, "wrong"); } @@ -195,14 +199,14 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_ constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); - return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0)); + return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); } template - __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11& c_m0_m1_n0_n1_thread_desc, + __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&, const ABlockBuffer& a_block_buf, const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const @@ -216,13 +220,13 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_ CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0, "wrong"); - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); constexpr auto threadwise_contraction = - ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< + ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< FloatA, FloatB, FloatC, diff --git 
a/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp similarity index 92% rename from composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp rename to composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp index a4aa355d5a..e624ad0b4d 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp @@ -1,8 +1,8 @@ -#ifndef CK_BLOCKWISE_GEMM_V3_HPP -#define CK_BLOCKWISE_GEMM_V3_HPP +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP #include "common_header.hpp" -#include "threadwise_gemm_v3.hpp" +#include "threadwise_gemm_dlops_v3.hpp" namespace ck { @@ -19,7 +19,7 @@ template -struct BlockwiseGemm_km_kn_m0m1n0n1_v3 +struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 { struct MatrixIndex { @@ -51,7 +51,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ThreadGemmADataPerRead_K, 1>; - __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3() + __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)} { @@ -138,16 +138,17 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 static_assert(WPerThread % WoPerThreadSubC == 0, ""); // thread A buffer for GEMM - StaticBuffer a_thread_buf; + StaticBuffer + a_thread_buf; - constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3{}; + constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3{}; static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) { static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) { diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp deleted file mode 100644 index 7e2f924b58..0000000000 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp +++ 
/dev/null @@ -1,514 +0,0 @@ -#ifndef CK_BLOCKWISE_GEMM_V2_HPP -#define CK_BLOCKWISE_GEMM_V2_HPP - -#include "common_header.hpp" -#include "threadwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_gemm_v2.hpp" - -namespace ck { - -// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1] -// A and B are visable to the whole block, C is distributed among each thread -// Assume: -// 1. A: -// 1. ABlockDesc is known at compile-time -// 2. ABlockBuffer is DynamicBuffer -// 2. B: -// 1. ABlockDesc is known at compile-time -// 2. BBlockBuffer is DynamicBuffer -// 3. C: -// 1. CThreadDesc is known at compile-time -// 2. CThreadBuffer is StaticBuffer -template ::type = false> -struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1 -{ - using AIndex = MultiIndex<3>; - using BIndex = MultiIndex<3>; - using CIndex = MultiIndex<4>; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - public: - __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1() - : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())}, - a_thread_copy_{ - make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])}, - b_thread_copy_{ - make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])} - { - static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() && - CThreadDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 * - M1N1ThreadClusterN11 * M1N1ThreadClusterN10, - "wrong! blocksize and cluster size not consistent"); - - static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), - "wrong! 
K dimension not consistent"); - } - - __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id) - { - constexpr index_t M0 = ABlockDesc{}.GetLength(I1); - constexpr index_t N0 = BBlockDesc{}.GetLength(I1); - constexpr index_t M1 = ABlockDesc{}.GetLength(I2); - constexpr index_t N1 = BBlockDesc{}.GetLength(I2); - - // 4-d data space into 4-d thread space - // upper: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1, M1N1ThreadClusterN10 * - // M1N1ThreadClusterN11} lower: {M0, M1, N0, N1} - constexpr auto adaptor0 = make_single_stage_tensor_adaptor( - make_tuple(make_vectorize_transform(M0, 1), - make_vectorize_transform(M1PerThread, M1 / M1PerThread), - make_vectorize_transform(N0, 1), - make_vectorize_transform(N1PerThread, N1 / N1PerThread)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - // thread position 4-d thread space - // upper: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10, - // M1N1ThreadClusterN11} lower: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1, - // M1N1ThreadClusterN10 * M1N1ThreadClusterN11} - constexpr auto adaptor1 = make_single_stage_tensor_adaptor( - make_tuple( - make_freeze_transform(make_multi_index(0)), - make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)), - make_freeze_transform(make_multi_index(0)), - make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{})); - - // 4-d thread space to 1-d thread space - // upper: {BlockSize} - // lower: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10, - // M1N1ThreadClusterN11} - constexpr auto adaptor2 = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10, - M1N1ThreadClusterN10, - 
M1N1ThreadClusterM11, - M1N1ThreadClusterN11))), - make_tuple(Sequence<0, 2, 1, 3>{}), - make_tuple(Sequence<0>{})); - - constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2); - - return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); - } - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - auto a_thread_buf = - make_static_buffer(a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = - make_static_buffer(b_thread_desc_.GetElementSpaceSize()); - - constexpr auto threadwise_gemm = - ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1, - Sequence, - Sequence>{}; - - constexpr index_t K = ABlockDesc{}.GetLength(I0); - - static_for<0, K, KPerThread>{}([&](auto k) { - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0), - a_thread_buf); - - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0), - b_thread_buf); - - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0), - c_thread_buf, - make_tuple(I0, I0, I0, I0)); - }); - } - - private: - static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1); - static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1); - - // A[K, M0, M1] - static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{}, Number{})); - - // B[K, N0, N1] - static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{}, Number{})); - - using AThreadCopy = - ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1, 2>, - 2, - AThreadCopyScalarPerVector_M1, - 1>; - - using BThreadCopy = - ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1, 2>, - 2, - BThreadCopyScalarPerVector_N1, - 1>; - - CIndex c_thread_origin_data_idx_; - - 
AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; -}; - -// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1] -// A and B are visable to the whole block, C is distributed among each thread -// Assume: -// 1. A: -// 1. ABlockDesc is known at compile-time -// 2. ABlockBuffer is DynamicBuffer -// 2. B: -// 1. ABlockDesc is known at compile-time -// 2. BBlockBuffer is DynamicBuffer -// 3. C: -// 1. CThreadDesc is known at compile-time -// 2. CThreadBuffer is StaticBuffer -// Also assume: -// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) -template ::type = false> -struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2 -{ - using AIndex = MultiIndex<3>; - using BIndex = MultiIndex<3>; - using CIndex = MultiIndex<4>; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - public: - __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2() - : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())}, - a_thread_copy_{ - make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])}, - b_thread_copy_{ - make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])} - { - static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() && - CThreadDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 * - M1N1ThreadClusterN11 * M1N1ThreadClusterN10, - "wrong! blocksize and cluster size not consistent"); - - static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), - "wrong! 
K dimension not consistent"); - - // TODO: remove this restriction - static_assert(ABlockDesc{}.GetLength(I1) == 2 && BBlockDesc{}.GetLength(I1) == 2 && - CThreadDesc{}.GetLength(I0) == 2 && CThreadDesc{}.GetLength(I2) == 2, - "wrong"); - } - - __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id) - { - constexpr index_t M0 = ABlockDesc{}.GetLength(I1); - constexpr index_t N0 = BBlockDesc{}.GetLength(I1); - constexpr index_t M1 = ABlockDesc{}.GetLength(I2); - constexpr index_t N1 = BBlockDesc{}.GetLength(I2); - - // 4-d data space into 4-d thread space - constexpr auto adaptor0 = make_single_stage_tensor_adaptor( - make_tuple(make_vectorize_transform(M0, 1), - make_vectorize_transform(M1PerThread, M1 / M1PerThread), - make_vectorize_transform(N0, 1), - make_vectorize_transform(N1PerThread, N1 / N1PerThread)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - // thread position 4-d thread space - constexpr auto adaptor1 = make_single_stage_tensor_adaptor( - make_tuple( - make_freeze_transform(make_multi_index(0)), - make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)), - make_freeze_transform(make_multi_index(0)), - make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{})); - - // 4-d thread space to 1-d thread space - constexpr auto adaptor2 = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10, - M1N1ThreadClusterN10, - M1N1ThreadClusterM11, - M1N1ThreadClusterN11))), - make_tuple(Sequence<0, 2, 1, 3>{}), - make_tuple(Sequence<0>{})); - - constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2); - - return 
cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); - } - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - auto a_thread_buf = - make_static_buffer(a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = - make_static_buffer(b_thread_desc_.GetElementSpaceSize()); - - constexpr auto threadwise_gemm = - ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1, - Sequence<1, M1PerThread>, - Sequence<1, N1PerThread>>{}; - - constexpr index_t K = ABlockDesc{}.GetLength(I0); - - // read A_sub_0 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0), - a_thread_buf); - - // read B_sub_0 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0), - b_thread_buf); - - // read B_sub_1 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(I0, I1, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I1, I0), - b_thread_buf); - - // read A_sub_1 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(I0, I1, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I1, I0), - a_thread_buf); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0), - c_thread_buf, - make_tuple(I0, I0, I0, I0)); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0, I0), - b_thread_buf, - make_tuple(I0, I1, I0), - c_thread_buf, - make_tuple(I0, I0, I1, I0)); - - // loop over rest of k - static_for{}([&](auto k) { - // read A_sub_0 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0), - a_thread_buf); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I1, I0), - b_thread_buf, - make_tuple(I0, I0, I0), - c_thread_buf, - 
make_tuple(I1, I0, I0, I0)); - - // read B_sub_0 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0), - b_thread_buf); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I1, I0), - b_thread_buf, - make_tuple(I0, I1, I0), - c_thread_buf, - make_tuple(I1, I0, I1, I0)); - - // read B_sub_1 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I1, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I1, I0), - b_thread_buf); - - // read A_sub_1 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I1, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I1, I0), - a_thread_buf); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0), - c_thread_buf, - make_tuple(I0, I0, I0, I0)); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0, I0), - b_thread_buf, - make_tuple(I0, I1, I0), - c_thread_buf, - make_tuple(I0, I0, I1, I0)); - }); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I1, I0), - b_thread_buf, - make_tuple(I0, I0, I0), - c_thread_buf, - make_tuple(I1, I0, I0, I0)); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I1, I0), - b_thread_buf, - make_tuple(I0, I1, I0), - c_thread_buf, - make_tuple(I1, I0, I1, I0)); - } - - private: - static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1); - static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1); - - // A[K, M0, M1] - static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{}, Number{})); - - // B[K, N0, N1] - static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{}, Number{})); - - using AThreadCopy = - 
ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1, 2>, - 2, - AThreadCopyScalarPerVector_M1, - 1>; - - using BThreadCopy = - ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1, 2>, - 2, - BThreadCopyScalarPerVector_N1, - 1>; - - CIndex c_thread_origin_data_idx_; - - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index f21983d5b5..98407ab7fc 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -138,10 +138,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = - make_static_buffer(a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = - make_static_buffer(b_thread_desc_.GetElementSpaceSize()); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); @@ -358,10 +358,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = - make_static_buffer(a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = - make_static_buffer(b_thread_desc_.GetElementSpaceSize()); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_dlops_v1r2.hpp similarity index 95% rename from 
composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r2.hpp rename to composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_dlops_v1r2.hpp index f47e85e0bd..6d48a18169 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_dlops_v1r2.hpp @@ -1,11 +1,11 @@ -#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R2_HPP -#define CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R2_HPP +#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP +#define CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP #include "common_header.hpp" #include "dynamic_multi_index_transform_helper.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" -#include "blockwise_gemm_v2r3.hpp" +#include "blockwise_gemm_dlops_v2r3.hpp" #include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_dynamic_tensor_slice_set.hpp" @@ -25,7 +25,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_dynamic_contraction_v1r2( + kernel_dynamic_contraction_dlops_v1r2( const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -55,7 +55,7 @@ template -struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 +struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -252,9 +250,11 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ constexpr auto BN = GN0 * GN11; constexpr auto BM1 = - Number{}; + Number{}; constexpr auto BN1 = - Number{}; + Number{}; constexpr auto BM0 = BM / BM1; constexpr auto BN0 = BN / BN1; @@ -331,11 +331,11 @@ struct 
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ integral_constant, integral_constant) { - const auto a_global_buf = make_dynamic_buffer( + const auto a_global_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( + const auto b_global_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( + auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize()); const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0); @@ -387,7 +387,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ // A matrix blockwise copy auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< BlockSize, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, Sequence, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, @@ -411,7 +411,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ // B matrix blockwise copy auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< BlockSize, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, Sequence, BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, @@ -439,7 +439,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ // c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in // register const auto blockwise_gemm = - BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< BlockSize, FloatAB, FloatAB, @@ -449,10 +449,8 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ 
BM1PerThreadBM11, BN1PerThreadBN11, BK0PerThread, - BM10BN10ThreadClusterBM100, - BM10BN10ThreadClusterBN100, - BM10BN10ThreadClusterBM101, - BM10BN10ThreadClusterBN101, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, BM1PerThreadBM11, BN1PerThreadBN11>{}; @@ -474,7 +472,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_ FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; // register allocation for output - auto c_thread_buf = make_static_buffer( + auto c_thread_buf = make_static_buffer( c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); ThreadwiseDynamicTensorSliceSet_v1( + auto a_block_even_buf = make_dynamic_buffer( p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( + auto b_block_even_buf = make_dynamic_buffer( p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); - auto a_block_odd_buf = make_dynamic_buffer( + auto a_block_odd_buf = make_dynamic_buffer( p_a_block_double + a_block_aligned_space_size, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( + auto b_block_odd_buf = make_dynamic_buffer( p_b_block_double + b_block_aligned_space_size, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r2.hpp similarity index 91% rename from composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp rename to composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r2.hpp index 697d5db972..7a4ef1d7ea 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r2.hpp @@ -1,11 +1,11 @@ -#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R2_HPP -#define 
CK_GRIDWISE_DYNAMIC_GEMM_V1R2_HPP +#ifndef CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP +#define CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP #include "common_header.hpp" #include "dynamic_multi_index_transform_helper.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" -#include "blockwise_gemm_v2r2.hpp" +#include "blockwise_gemm_dlops_v2r2.hpp" #include "blockwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_dynamic_tensor_slice_set.hpp" @@ -26,7 +26,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_dynamic_gemm_v1r2( + kernel_dynamic_gemm_dlops_v1r2( const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -52,8 +52,8 @@ __global__ void integral_constant{}); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -// pass tensor descriptor by __CONSTANT__ void pointer -// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to +// pass tensor descriptor by CONSTANT void pointer +// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to // non-modifiable parameter address space, so compiler can enable corresponding optimization template -struct GridwiseDynamicGemm_km_kn_mn_v1r2 +struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -326,11 +326,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2 integral_constant, integral_constant) { - const auto a_global_buf = make_dynamic_buffer( + const auto a_global_buf = make_dynamic_buffer( p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( + const auto b_global_buf = make_dynamic_buffer( p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( + auto 
c_grid_buf = make_dynamic_buffer( p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); const auto K = a_k_m0_m1_grid_desc.GetLength(I0); @@ -373,7 +373,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2 // A matrix blockwise copy auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4, ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadClusterLengths_K_M0_M1, @@ -399,7 +399,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2 // B matrix blockwise copy auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4, BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadClusterLengths_K_N0_N1, @@ -429,21 +429,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2 // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in // register const auto blockwise_gemm = - BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2{}; + BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2{}; constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); @@ -462,7 +462,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2 FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; // register allocation for output - auto c_thread_buf = make_static_buffer( + auto c_thread_buf = make_static_buffer( c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); ThreadwiseDynamicTensorSliceSet_v1( + auto a_block_even_buf = make_dynamic_buffer( p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( + auto b_block_even_buf = make_dynamic_buffer( p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize()); - auto a_block_odd_buf = - make_dynamic_buffer(p_a_block_double + a_block_aligned_space_size, - a_k_m0_m1_block_desc.GetElementSpaceSize()); - auto b_block_odd_buf = - make_dynamic_buffer(p_b_block_double + b_block_aligned_space_size, - b_k_n0_n1_block_desc.GetElementSpaceSize()); + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + 
a_block_aligned_space_size, + a_k_m0_m1_block_desc.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_k_n0_n1_block_desc.GetElementSpaceSize()); // LDS double buffer: preload data into LDS { diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r3.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r3.hpp similarity index 91% rename from composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r3.hpp rename to composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r3.hpp index 20f91140db..db3cb99121 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r3.hpp @@ -5,7 +5,7 @@ #include "dynamic_multi_index_transform_helper.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" -#include "blockwise_gemm_v2r3.hpp" +#include "blockwise_gemm_dlops_v2r3.hpp" #include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_dynamic_tensor_slice_set.hpp" @@ -26,7 +26,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_dynamic_gemm_v1r3( + kernel_dynamic_gemm_dlops_v1r3( const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -52,8 +52,8 @@ __global__ void integral_constant{}); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -// pass tensor descriptor by __CONSTANT__ void pointer -// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to +// pass tensor descriptor by CONSTANT void pointer +// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to // non-modifiable parameter address space, 
so compiler can enable corresponding optimization template -struct GridwiseDynamicGemm_km_kn_mn_v1r3 +struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -277,9 +275,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 const auto N0 = N / N1; constexpr auto M11 = - Number{}; + Number{}; constexpr auto N11 = - Number{}; + Number{}; constexpr auto M10 = M1 / M11; constexpr auto N10 = N1 / N11; @@ -333,11 +333,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 integral_constant, integral_constant) { - const auto a_global_buf = make_dynamic_buffer( + const auto a_global_buf = make_dynamic_buffer( p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( + const auto b_global_buf = make_dynamic_buffer( p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( + auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); // divide block work by [M, N] @@ -383,7 +383,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 // A matrix blockwise copy auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< BlockSize, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, Sequence, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, @@ -407,7 +407,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 // B matrix blockwise copy auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< BlockSize, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, Sequence, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, @@ -435,7 +435,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in // register const auto blockwise_gemm = - BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + 
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< BlockSize, FloatAB, FloatAB, @@ -445,15 +445,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 M1PerThreadM111, N1PerThreadN111, KPerThread, - M11N11ThreadClusterM1100, - M11N11ThreadClusterN1100, - M11N11ThreadClusterM1101, - M11N11ThreadClusterN1101, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, M1PerThreadM111, N1PerThreadN111>{}; constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = - decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); constexpr auto c_m10_m11_n10_n11_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2( @@ -470,7 +468,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; // register allocation for output - auto c_thread_buf = make_static_buffer( + auto c_thread_buf = make_static_buffer( c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); ThreadwiseDynamicTensorSliceSet_v1( + auto a_block_even_buf = make_dynamic_buffer( p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( + auto b_block_even_buf = make_dynamic_buffer( p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); - auto a_block_odd_buf = - make_dynamic_buffer(p_a_block_double + a_block_aligned_space_size, - a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); - auto b_block_odd_buf = - make_dynamic_buffer(p_b_block_double + b_block_aligned_space_size, - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); // LDS double buffer: preload data into LDS { @@ -610,10 +608,12 @@ struct 
GridwiseDynamicGemm_km_kn_mn_v1r3 // output: register to global memory { - constexpr index_t M11 = - M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101; - constexpr index_t N11 = - N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101; + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; constexpr index_t M10 = MPerBlockM1 / M11; constexpr index_t N10 = NPerBlockN1 / N11; @@ -631,7 +631,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 Number{})); const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = - blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); ThreadwiseDynamicTensorSliceTransfer_v1r3< FloatAcc, diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v2.hpp similarity index 92% rename from composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp rename to composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v2.hpp index 2c30787ef4..34dea34833 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v2.hpp @@ -7,7 +7,7 @@ #include "dynamic_tensor_descriptor_helper.hpp" #include "blockwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_dynamic_tensor_slice_transfer.hpp" -#include "blockwise_gemm_v3.hpp" +#include "blockwise_gemm_dlops_v3.hpp" namespace ck { @@ -15,7 +15,7 @@ template -struct GridwiseDynamicGemm_km_kn_mn_v3 +struct GridwiseDynamicGemmDlops_km_kn_mn_v3 { __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -84,11 +84,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - const auto a_global_buf = make_dynamic_buffer( + const auto a_global_buf = 
make_dynamic_buffer( p_a_global, a_e_k_global_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( + const auto b_global_buf = make_dynamic_buffer( p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( + auto c_global_buf = make_dynamic_buffer( p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize()); constexpr auto E = EPerBlock * 3 * 3; @@ -100,7 +100,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2); const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3); - // divide block work by [M, N] +// divide block work by [M, N] #if 0 const auto k_block_work_num = K / Number{}; const auto ho_block_work_num = Ho / Number{}; @@ -152,19 +152,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( Number{}, Number<1>{}, Number{}, Number{})); - auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3{}; + auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); @@ -184,7 +185,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 // A matrix blockwise copy auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4, ABlockTransferThreadSliceLengths_E_K, ABlockTransferThreadClusterLengths_E_K, @@ -225,11 +226,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 true>(b_e_n_ho_wo_global_desc, make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); - auto a_block_buf = make_dynamic_buffer(p_shared_block, - a_e_k_desc.GetElementSpaceSize()); + auto a_block_buf = make_dynamic_buffer( + p_shared_block, a_e_k_desc.GetElementSpaceSize()); // register allocation for output - StaticBuffer + StaticBuffer c_thread_buf; // initialize output thread tensor @@ -252,7 +255,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 BGlobalMoveSliceWindowIteratorHacks{}; // double regsiter buffer for b - StaticBuffer + StaticBuffer 
b_thread_even_buf, b_thread_odd_buf; // LDS double buffer: preload data diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp index 3b1dc9cea1..a5b1de79a7 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp @@ -61,10 +61,10 @@ __global__ void kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - const void __CONSTANT__* p_a_k0_m_k1_grid_desc, - const void __CONSTANT__* p_b_k0_n_k1_grid_desc, - const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc, - const void __CONSTANT__* p_c_block_cluster_adaptor) + const void CONSTANT* p_a_k0_m_k1_grid_desc, + const void CONSTANT* p_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_block_cluster_adaptor) { constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -95,7 +95,7 @@ template {}; constexpr auto I3 = Number<3>{}; - const auto a_grid_buf = make_dynamic_buffer( + const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( + const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( + auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize()); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); @@ -312,7 +312,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // A matrix blockwise copy auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4, ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1, @@ -339,7 +339,7 @@ struct 
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // B matrix blockwise copy auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4, BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1, @@ -413,7 +413,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto c_mr_nr_blk_desc = make_dynamic_naive_tensor_descriptor_packed_v2( make_tuple(Number{}, Number{})); - StaticBuffer, c_mr_nr_blk_desc.GetElementSpaceSize()> c_thread_buf; @@ -442,9 +442,9 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack = BGridMoveSliceWindowIteratorHacks{}; - auto a_block_buf = make_dynamic_buffer( + auto a_block_buf = make_dynamic_buffer( p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( + auto b_block_buf = make_dynamic_buffer( p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); // preload data into LDS @@ -515,7 +515,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 Number{}, Number<1>{})); - StaticBuffer + StaticBuffer c_blk_buf_; static_for<0, MRepeat, 1>{}([&](auto mr_i) { @@ -585,7 +585,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( I1, I1, I1, I1, Number{}, Number<1>{}, Number{}, Number<1>{})); - StaticBuffer c_blk_buf_; + StaticBuffer c_blk_buf_; // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index diff --git a/composable_kernel/include/tensor_operation/threadwise_contraction.hpp b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp similarity index 95% rename from composable_kernel/include/tensor_operation/threadwise_contraction.hpp rename to composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp index 995c871c5e..0440bc0312 100644 --- a/composable_kernel/include/tensor_operation/threadwise_contraction.hpp +++ 
b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp @@ -1,5 +1,5 @@ -#ifndef CK_THREADWISE_CONTRACTION_HPP -#define CK_THREADWISE_CONTRACTION_HPP +#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP +#define CK_THREADWISE_CONTRACTION_DLOPS_HPP #include "common_header.hpp" #include "math.hpp" @@ -25,9 +25,9 @@ template ::type = false> -struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 +struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 { - __device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1() + __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() { static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && @@ -71,8 +71,6 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; constexpr auto TK = TKLengths{}[I0]; constexpr auto TM0 = TMLengths{}[I0]; @@ -131,9 +129,9 @@ template ::type = false> -struct ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 +struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 { - __device__ constexpr ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() + __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() { static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && @@ -177,8 +175,6 @@ struct ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_T constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; constexpr index_t TK0 = TKLengths{}[I0]; constexpr index_t TK1 = TKLengths{}[I1]; diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp 
b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp index 4253431d5e..9626113686 100644 --- a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp @@ -54,7 +54,7 @@ template ::type = false> @@ -159,9 +159,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 static_ford{}([&](auto ordered_access_idx) { // judge move forward or move backward constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep; + StaticallyIndexedArray forward_sweep_; - forward_sweep(I0) = true; + forward_sweep_(I0) = true; static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_access_idx[I0]; @@ -170,10 +170,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); - forward_sweep(i) = tmp % 2 == 0; + forward_sweep_(i) = tmp % 2 == 0; }); - return forward_sweep; + return forward_sweep_; }(); // calculate dst data index @@ -186,10 +186,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; }); - auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) * - dst_scalar_per_access; - - return dst_data_idx; + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + dst_scalar_per_access; }(); typename vector_type_maker::type dst_vector; @@ -217,17 +215,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim; + StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; static_for{}([&](auto j) { - move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + move_on_dim_(i) &= 
ordered_access_idx[j] == ordered_access_lengths[j] - 1; }); }); - return move_on_dim; + return move_on_dim_; } (); @@ -295,9 +293,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 // judge move forward or move backward during the last iteration constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep; + StaticallyIndexedArray forward_sweep_; - forward_sweep(I0) = true; + forward_sweep_(I0) = true; static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_access_lengths[I0] - 1; @@ -306,10 +304,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; }); - forward_sweep(i) = tmp % 2 == 0; + forward_sweep_(i) = tmp % 2 == 0; }); - return forward_sweep; + return forward_sweep_; }(); // calculate dst data index after last iteration in Run(), if it has not being reset by @@ -321,19 +319,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; }); - auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) * - dst_scalar_per_access; - - return dst_data_idx; + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + dst_scalar_per_access; }(); // constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step; + Index reset_dst_data_step_; - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; }); + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - return reset_dst_data_step; + return reset_dst_data_step_; }(); return reset_dst_data_step; @@ -478,9 +474,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 static_ford{}([&](auto ordered_access_idx) { // judge move forward or move backward constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep; + StaticallyIndexedArray forward_sweep_; - forward_sweep(I0) = true; + forward_sweep_(I0) = true; static_for<1, nDim, 
1>{}([&](auto i) { index_t tmp = ordered_access_idx[I0]; @@ -489,10 +485,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); - forward_sweep(i) = tmp % 2 == 0; + forward_sweep_(i) = tmp % 2 == 0; }); - return forward_sweep; + return forward_sweep_; }(); // calculate src data index @@ -505,10 +501,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; }); - auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) * - src_scalar_per_access; - - return src_data_idx; + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + src_scalar_per_access; }(); typename vector_type_maker::type src_vector; @@ -534,17 +528,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim; + StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; static_for{}([&](auto j) { - move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; }); }); - return move_on_dim; + return move_on_dim_; } (); @@ -612,9 +606,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 // judge move forward or move backward during the last iteration constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep; + StaticallyIndexedArray forward_sweep_; - forward_sweep(I0) = true; + forward_sweep_(I0) = true; static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_access_lengths[I0] - 1; @@ -623,10 +617,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; }); - forward_sweep(i) = tmp % 2 == 0; + forward_sweep_(i) = tmp % 2 == 0; }); - 
return forward_sweep; + return forward_sweep_; }(); // calculate src data index after last iteration in Run(), if it has not being reset by @@ -638,19 +632,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; }); - auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) * - src_scalar_per_access; - - return src_data_idx; + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + src_scalar_per_access; }(); // constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step; + Index reset_src_data_step_; - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; }); + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); - return reset_src_data_step; + return reset_src_data_step_; }(); return reset_src_data_step; @@ -682,7 +674,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 // 3. src_slice_origin and dst_slice_origin are not known at compile-time, // 4. 
Use thread buffer template >, @@ -797,9 +789,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 static_ford{}([&](auto ordered_src_access_idx) { // judge move forward or move backward constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep; + StaticallyIndexedArray forward_sweep_; - forward_sweep(I0) = true; + forward_sweep_(I0) = true; static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_src_access_idx[I0]; @@ -808,10 +800,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; }); - forward_sweep(i) = tmp % 2 == 0; + forward_sweep_(i) = tmp % 2 == 0; }); - return forward_sweep; + return forward_sweep_; }(); // calculate src data index @@ -824,11 +816,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ordered_src_access_idx[i]; }); - auto src_data_idx = - container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - - return src_data_idx; + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; }(); vector_type_maker_t src_tmp_vector; @@ -852,18 +841,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim; + StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; static_for{}([&](auto j) { - move_on_dim(i) &= + move_on_dim_(i) &= ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; }); }); - return move_on_dim; + return move_on_dim_; } (); @@ -900,8 +889,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 DstBuffer& dst_buf, const DstIteratorHacks& dst_iterator_hacks) { - static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or - DstBuffer::GetAddressSpace() == AddressSpace::Lds, + 
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); static_assert(is_same>, @@ -962,9 +951,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 static_ford{}([&](auto ordered_dst_access_idx) { // judge move forward or move backward constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep; + StaticallyIndexedArray forward_sweep_; - forward_sweep(I0) = true; + forward_sweep_(I0) = true; static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_dst_access_idx[I0]; @@ -973,10 +962,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; }); - forward_sweep(i) = tmp % 2 == 0; + forward_sweep_(i) = tmp % 2 == 0; }); - return forward_sweep; + return forward_sweep_; }(); // calculate dst data index @@ -989,11 +978,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ordered_dst_access_idx[i]; }); - auto dst_data_idx = - container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - - return dst_data_idx; + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; }(); vector_type_maker_t dst_tmp_vector; @@ -1019,18 +1005,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim; + StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; static_for{}([&](auto j) { - move_on_dim(i) &= + move_on_dim_(i) &= ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; }); }); - return move_on_dim; + return move_on_dim_; } (); @@ -1108,9 +1094,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 // judge move forward or move backward during the last iteration 
constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep; + StaticallyIndexedArray forward_sweep_; - forward_sweep(I0) = true; + forward_sweep_(I0) = true; static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_src_access_lengths[I0] - 1; @@ -1119,10 +1105,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; }); - forward_sweep(i) = tmp % 2 == 0; + forward_sweep_(i) = tmp % 2 == 0; }); - return forward_sweep; + return forward_sweep_; }(); // calculate src data index after last iteration in RunRead(), if it has not being reset by @@ -1134,19 +1120,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; }); - auto src_data_idx = container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - - return src_data_idx; + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; }(); // constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step; + Index reset_src_data_step_; - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; }); + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); - return reset_src_data_step; + return reset_src_data_step_; }(); return reset_src_data_step; @@ -1170,9 +1154,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 // judge move forward or move backward during the last iteration constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep; + StaticallyIndexedArray forward_sweep_; - forward_sweep(I0) = true; + forward_sweep_(I0) = true; static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_dst_access_lengths[I0] - 1; @@ -1181,10 +1165,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; }); - forward_sweep(i) = tmp % 
2 == 0; + forward_sweep_(i) = tmp % 2 == 0; }); - return forward_sweep; + return forward_sweep_; }(); // calculate dst data index after last iteration in RunWrite(), if it has not being reset by @@ -1196,19 +1180,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; }); - auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - - return dst_data_idx; + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; }(); // constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step; + Index reset_dst_data_step_; - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; }); + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - return reset_dst_data_step; + return reset_dst_data_step_; }(); return reset_dst_data_step; @@ -1270,7 +1252,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); - StaticBuffer buffer_; + StaticBuffer buffer_; SrcCoord src_coord_; DstCoord dst_coord_; @@ -1357,9 +1339,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4 constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - // scalar per access of each dim constexpr auto src_scalar_per_access = generate_sequence_v2( [&](auto i) constexpr { diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp index 331c9fc201..ba60e26c38 100644 --- a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp +++ 
b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp @@ -13,7 +13,7 @@ namespace ck { // 3. src_slice_origin and dst_slice_origin are not known at compile-time, // 4. Use thread buffer template >, @@ -140,9 +140,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 static_ford{}([&](auto ordered_src_access_idx) { // judge move forward or move backward constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep; + StaticallyIndexedArray forward_sweep_; - forward_sweep(I0) = true; + forward_sweep_(I0) = true; static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_src_access_idx[I0]; @@ -151,10 +151,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; }); - forward_sweep(i) = tmp % 2 == 0; + forward_sweep_(i) = tmp % 2 == 0; }); - return forward_sweep; + return forward_sweep_; }(); // calculate src data index @@ -167,11 +167,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ordered_src_access_idx[i]; }); - auto src_data_idx = - container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_vector_tensor_lengths; - - return src_data_idx; + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_vector_tensor_lengths; }(); vector_type_maker_t src_vector; @@ -201,18 +198,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim; + StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; static_for{}([&](auto j) { - move_on_dim(i) &= + move_on_dim_(i) &= ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; }); }); - return move_on_dim; + return move_on_dim_; } (); @@ -249,8 +246,8 @@ struct 
ThreadwiseDynamicTensorSliceTransfer_v3r1 DstBuffer& dst_buf, const DstIteratorHacks& dst_iterator_hacks) { - static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or - DstBuffer::GetAddressSpace() == AddressSpace::Lds, + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); static_assert(is_same>, @@ -316,9 +313,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 static_ford{}([&](auto ordered_dst_access_idx) { // judge move forward or move backward constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep; + StaticallyIndexedArray forward_sweep_; - forward_sweep(I0) = true; + forward_sweep_(I0) = true; static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_dst_access_idx[I0]; @@ -327,10 +324,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; }); - forward_sweep(i) = tmp % 2 == 0; + forward_sweep_(i) = tmp % 2 == 0; }); - return forward_sweep; + return forward_sweep_; }(); // calculate dst data index @@ -343,11 +340,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ordered_dst_access_idx[i]; }); - auto dst_data_idx = - container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_vector_tensor_lengths; - - return dst_data_idx; + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_vector_tensor_lengths; }(); vector_type_maker_t dst_vector; @@ -379,18 +373,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim; + StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; static_for{}([&](auto j) { - move_on_dim(i) &= + move_on_dim_(i) &= 
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; }); }); - return move_on_dim; + return move_on_dim_; } (); @@ -463,9 +457,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 // judge move forward or move backward during the last iteration constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep; + StaticallyIndexedArray forward_sweep_; - forward_sweep(I0) = true; + forward_sweep_(I0) = true; static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_src_access_lengths[I0] - 1; @@ -474,10 +468,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; }); - forward_sweep(i) = tmp % 2 == 0; + forward_sweep_(i) = tmp % 2 == 0; }); - return forward_sweep; + return forward_sweep_; }(); // calculate src data index after last iteration in RunRead(), if it has not being reset by @@ -489,19 +483,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_lengths[i] - 1 : 0; }); - auto src_data_idx = container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_vector_tensor_lengths; - - return src_data_idx; + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_vector_tensor_lengths; }(); // constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step; + Index reset_src_data_step_; - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; }); + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); - return reset_src_data_step; + return reset_src_data_step_; }(); return reset_src_data_step; @@ -520,9 +512,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 // judge move forward or move backward during the last iteration constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep; + StaticallyIndexedArray forward_sweep_; - forward_sweep(I0) = true; + forward_sweep_(I0) = true; static_for<1, nDim, 1>{}([&](auto i) { index_t tmp = ordered_dst_access_lengths[I0] - 1; @@ -531,10 +523,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; }); - forward_sweep(i) = tmp % 2 == 0; + forward_sweep_(i) = tmp % 2 == 0; }); - return forward_sweep; + return forward_sweep_; }(); // calculate dst data index after last iteration in RunWrite(), if it has not being reset by @@ -546,19 +538,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_lengths[i] - 1 : 0; }); - auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_vector_tensor_lengths; - - return dst_data_idx; + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_vector_tensor_lengths; }(); // constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step; + Index reset_dst_data_step_; - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; }); + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - return reset_dst_data_step; + return reset_dst_data_step_; }(); return reset_dst_data_step; @@ -620,7 +610,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); - StaticBuffer buffer_; + StaticBuffer buffer_; SrcCoord src_coord_; DstCoord dst_coord_; diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp similarity index 98% rename from composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp rename to composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp index 8c78448e80..f31150c2cf 100644 --- a/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp @@ -1,5 +1,5 @@ -#ifndef CK_THREADWISE_GEMM_V3_HPP -#define CK_THREADWISE_GEMM_V3_HPP +#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP +#define CK_THREADWISE_GEMM_DLOPS_V3_HPP #include "common_header.hpp" #include "math.hpp" @@ -22,7 +22,7 @@ template ::type = false> -struct ThreadwiseGemm_km_kn_mn_v3 +struct ThreadwiseGemmDlops_km_kn_mn_v3 { template __device__ typename vector_type::type @@ -220,31 +220,31 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, { if constexpr(N == 1) { - return __llvm_amdgcn_raw_buffer_load_fp32( + return 
llvm_amdgcn_raw_buffer_load_fp32( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); } else if constexpr(N == 2) { - return __llvm_amdgcn_raw_buffer_load_fp32x2( + return llvm_amdgcn_raw_buffer_load_fp32x2( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); } else if constexpr(N == 4) { - return __llvm_amdgcn_raw_buffer_load_fp32x4( + return llvm_amdgcn_raw_buffer_load_fp32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); } else if constexpr(N == 8) { vector_type tmp; - tmp.AsType()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); tmp.AsType()(Number<1>{}) = - __llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - 0); + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + 0); return tmp.AsType()(Number<0>{}); } @@ -253,17 +253,17 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, { if constexpr(N == 1) { - return __llvm_amdgcn_raw_buffer_load_fp16( + return llvm_amdgcn_raw_buffer_load_fp16( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); } else if constexpr(N == 2) { - return __llvm_amdgcn_raw_buffer_load_fp16x2( + return llvm_amdgcn_raw_buffer_load_fp16x2( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); } else if constexpr(N == 4) { - return __llvm_amdgcn_raw_buffer_load_fp16x4( + return llvm_amdgcn_raw_buffer_load_fp16x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); } else if constexpr(N == 8) @@ -271,18 +271,18 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, #if 0 vector_type tmp; - tmp.AsType()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4( + 
tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); tmp.AsType()(Number<1>{}) = - __llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, + llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(half_t), 0); return tmp.AsType()(Number<0>{}); #else - float4_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x4( + float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); return as_type(tmp); @@ -293,31 +293,31 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, { if constexpr(N == 1) { - return __llvm_amdgcn_raw_buffer_load_i32( + return llvm_amdgcn_raw_buffer_load_i32( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); } else if constexpr(N == 2) { - return __llvm_amdgcn_raw_buffer_load_i32x2( + return llvm_amdgcn_raw_buffer_load_i32x2( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); } else if constexpr(N == 4) { - return __llvm_amdgcn_raw_buffer_load_i32x4( + return llvm_amdgcn_raw_buffer_load_i32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); } else if constexpr(N == 8) { vector_type tmp; - tmp.AsType()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4( + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); tmp.AsType()(Number<1>{}) = - __llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int32_t), - 0); + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + 0); return tmp.AsType()(Number<0>{}); } } @@ -325,16 +325,16 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, { if constexpr(N == 1) { - return 
__llvm_amdgcn_raw_buffer_load_i8( + return llvm_amdgcn_raw_buffer_load_i8( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); } else if constexpr(N == 2) { #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - return __llvm_amdgcn_raw_buffer_load_i8x2( + return llvm_amdgcn_raw_buffer_load_i8x2( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); #else - int16_t tmp = __llvm_amdgcn_raw_buffer_load_i16( + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); return as_type(tmp); @@ -343,10 +343,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, else if constexpr(N == 4) { #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - return __llvm_amdgcn_raw_buffer_load_i8x4( + return llvm_amdgcn_raw_buffer_load_i8x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); #else - int32_t tmp = __llvm_amdgcn_raw_buffer_load_i32( + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); return as_type(tmp); @@ -357,18 +357,18 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE vector_type tmp; - tmp.AsType()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4( + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); tmp.AsType()(Number<1>{}) = - __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int8_t), - 0); + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int8_t), + 0); return tmp.AsType()(Number<0>{}); #else - int32x2_t tmp = __llvm_amdgcn_raw_buffer_load_i32x2( + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2( src_wave_buffer_resource, 
src_thread_addr_offset, src_wave_addr_offset, 0); return as_type(tmp); @@ -379,30 +379,30 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE vector_type tmp; - tmp.AsType()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4( + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); tmp.AsType()(Number<1>{}) = - __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int8_t), - 0); + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int8_t), + 0); tmp.AsType()(Number<2>{}) = - __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 8 * sizeof(int8_t), - 0); + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int8_t), + 0); tmp.AsType()(Number<3>{}) = - __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 12 * sizeof(int8_t), - 0); + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(int8_t), + 0); return tmp.AsType()(Number<0>{}); #else - int32x4_t tmp = __llvm_amdgcn_raw_buffer_load_i32x4( + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); return as_type(tmp); @@ -428,61 +428,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type::type { if constexpr(N == 1) { - __llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - __llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data, - dst_wave_buffer_resource, - 
dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 4) - { - __llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - } - else if constexpr(is_same::value) - { - if constexpr(N == 1) - { - __llvm_amdgcn_raw_buffer_store_i32(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - __llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 4) - { - __llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - } - else if constexpr(is_same::value) - { - if constexpr(N == 1) - { - __llvm_amdgcn_raw_buffer_store_i8(src_thread_data, + llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -490,94 +436,148 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type::type } else if constexpr(N == 2) { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - __llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data, + llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); -#else - __llvm_amdgcn_raw_buffer_store_i16(as_type(src_thread_data), + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data, dst_wave_buffer_resource, 
dst_thread_addr_offset, dst_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i8(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#else + llvm_amdgcn_raw_buffer_store_i16(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); #endif } else if constexpr(N == 4) { #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - __llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); + llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); #else - __llvm_amdgcn_raw_buffer_store_i32(as_type(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); + llvm_amdgcn_raw_buffer_store_i32(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); #endif } else if constexpr(N == 8) { - __llvm_amdgcn_raw_buffer_store_i32x2(as_type(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); + llvm_amdgcn_raw_buffer_store_i32x2(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); } else if constexpr(N == 16) { - __llvm_amdgcn_raw_buffer_store_i32x4(as_type(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, 
- 0); + llvm_amdgcn_raw_buffer_store_i32x4(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); } } else if constexpr(is_same::value) { if constexpr(N == 1) { - __llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, + llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); } - else if constexpr(N == 2) - { - __llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } else if constexpr(N == 4) { - __llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); + llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); } else if constexpr(N == 8) { vector_type tmp{src_thread_data}; - __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); - __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(half_t), - 0); + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(half_t), + 0); } } } diff --git a/composable_kernel/include/utility/amd_dlop.hpp b/composable_kernel/include/utility/amd_dlop.hpp index e5b9d901ba..8ce19012e9 100644 --- a/composable_kernel/include/utility/amd_dlop.hpp +++ 
b/composable_kernel/include/utility/amd_dlop.hpp @@ -1,7 +1,7 @@ #ifndef CK_AMD_DLOP_HPP #define CK_AMD_DLOP_HPP -#include "float_type.hpp" +#include "data_type.hpp" namespace ck { diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/composable_kernel/include/utility/amd_inline_asm.hpp index c4b167a128..ce80fc0549 100644 --- a/composable_kernel/include/utility/amd_inline_asm.hpp +++ b/composable_kernel/include/utility/amd_inline_asm.hpp @@ -1,7 +1,7 @@ #ifndef CK_AMD_INLINE_ASM_HPP #define CK_AMD_INLINE_ASM_HPP -#include "float_type.hpp" +#include "data_type.hpp" namespace ck { diff --git a/composable_kernel/include/utility/amd_llvm_intrinsic.hpp b/composable_kernel/include/utility/amd_llvm_intrinsic.hpp index 8981db7a7b..841d48f81c 100644 --- a/composable_kernel/include/utility/amd_llvm_intrinsic.hpp +++ b/composable_kernel/include/utility/amd_llvm_intrinsic.hpp @@ -1,11 +1,11 @@ #ifndef CK_AMD_LLVM_INTRINSIC_HPP #define CK_AMD_LLVM_INTRINSIC_HPP -#include "float_type.hpp" +#include "data_type.hpp" namespace ck { -__device__ int32_t __llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane"); +__device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane"); } // namespace ck #endif diff --git a/composable_kernel/include/utility/amd_xdlops.hpp b/composable_kernel/include/utility/amd_xdlops.hpp index b373e27be3..da74fe1d48 100644 --- a/composable_kernel/include/utility/amd_xdlops.hpp +++ b/composable_kernel/include/utility/amd_xdlops.hpp @@ -1,7 +1,7 @@ #ifndef CK_AMD_XDLOPS_HPP #define CK_AMD_XDLOPS_HPP -#include "float_type.hpp" +#include "data_type.hpp" namespace ck { diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index ad38d0461c..5ff7688a1c 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -7,8 +7,9 @@ #include "statically_indexed_array.hpp" #include 
"container_element_picker.hpp" #include "multi_index.hpp" +#include "data_type_enum.hpp" #include "data_type.hpp" -#include "float_type.hpp" +#include "data_type_helper.hpp" #include "functional.hpp" #include "functional2.hpp" #include "functional3.hpp" diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp index ea13e2b6f8..4908d8d818 100644 --- a/composable_kernel/include/utility/config.hpp +++ b/composable_kernel/include/utility/config.hpp @@ -8,18 +8,13 @@ #include "bfloat16_dev.hpp" // address space for kernel parameter -#define __CONSTANT__ __attribute__((address_space(4))) +#define CONSTANT __attribute__((address_space(4))) -// device backend -#define CK_DEVICE_BACKEND_AMD 1 - -// GPU ID -#if 0 -#define CK_AMD_GPU_GFX906 1 -#elif 1 -#define CK_AMD_GPU_GFX908 1 -#elif 0 -#define CK_AMD_GPU_GFX1030 1 +// GPU target +// should enable one and only one GPU target +#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \ + defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030)) +#error Need to define a single GPU target #endif // HIP version @@ -36,7 +31,8 @@ #endif // buffer resourse -#if defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) +#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \ + defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(CK_AMD_GPU_GFX1030) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 @@ -50,10 +46,6 @@ #define CK_USE_AMD_INLINE_ASM 1 #endif -#ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM -#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1 -#endif - // AMD DLOPS #ifndef CK_USE_AMD_DLOP #define CK_USE_AMD_DLOP 1 @@ -78,14 +70,6 @@ #define CK_USE_AMD_XDLOPS 0 #endif -#ifndef CK_USE_AMD_XDLOPS_INLINE_ASM -#define CK_USE_AMD_XDLOPS_INLINE_ASM 0 -#endif - -#ifndef CK_USE_AMD_XDLOPS_EMULATE -#define 
CK_USE_AMD_XDLOPS_EMULATE 0 // For internal debug purposes -#endif - // block synchronization only s_wait lgkmcnt(0), not vmcnt(0) #ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM #define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 @@ -104,18 +88,6 @@ #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1 #endif -#ifndef CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE -#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 1 -#endif - -#ifndef CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK -#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK 0 -#endif - -#ifndef CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK -#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK 0 -#endif - // pass tensor descriptor by value or void* #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1 @@ -131,17 +103,6 @@ #define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 #endif -// workaround: put all workaround here -// workaround for unnecessary VGPR <--> AGPR data movement when using mfma LLVM intrinsic -#ifndef CK_WORKAROUND_SWDEV_229564 -#define CK_WORKAROUND_SWDEV_229564 1 -#endif - -// workaround for accvgpr over-allocation -#ifndef CK_WORKAROUND_SWDEV_241664 -#define CK_WORKAROUND_SWDEV_241664 1 -#endif - // workaround for compiler crash when compiling recursive lambda #ifndef CK_WORKAROUND_SWDEV_275126 #define CK_WORKAROUND_SWDEV_275126 1 @@ -159,7 +120,7 @@ namespace ck { -enum AddressSpace +enum AddressSpaceEnum_t { Generic, Global, @@ -168,7 +129,7 @@ enum AddressSpace Vgpr }; -enum InMemoryDataOperation +enum InMemoryDataOperationEnum_t { Set, AtomicAdd diff --git a/composable_kernel/include/utility/data_type.hpp b/composable_kernel/include/utility/data_type.hpp index 66d2a88be4..24a2190e84 100644 --- a/composable_kernel/include/utility/data_type.hpp 
+++ b/composable_kernel/include/utility/data_type.hpp @@ -1,8 +1,1001 @@ -#ifndef CK_DATA_TYPE_HPP -#define CK_DATA_TYPE_HPP +#ifndef CK_FLOAT_TYPE_AMD_HPP +#define CK_FLOAT_TYPE_AMD_HPP + +#include "statically_indexed_array.hpp" namespace ck { +using half_t = _Float16; + +// vector_type +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type, N>; + +// vector_type_maker +// This is the right way to handle "vector of vectors": making a bigger vector instead +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker, N0> +{ + using type = vector_type; +}; + +template +using vector_type_maker_t = typename vector_type_maker::type; + +template +__host__ __device__ constexpr auto make_vector_type(Number) +{ + return typename vector_type_maker::type{}; +} + +// scalar_type +template +struct scalar_type; + +template +struct scalar_type +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +template +struct scalar_type> +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +// +template <> +struct scalar_type +{ + using type = float; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = half_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = ushort; + static constexpr index_t vector_size = 1; +}; + 
+template <> +struct scalar_type +{ + using type = int32_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = int8_t; + static constexpr index_t vector_size = 1; +}; + +// +template +struct vector_type +{ + using d1_t = T; + using type = d1_t; + + union + { + T d1_; + StaticallyIndexedArray d1x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value, "wrong!"); + + return data_.d1x1_; + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value, "wrong!"); + + return data_.d1x1_; + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + + using type = d2_t; + + union + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + + using type = d4_t; + + union + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; 
+ + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + + using type = d8_t; + + union + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if 
constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + + using type = d16_t; + + union + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; 
+ typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + + using type = d32_t; + + union + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); 
+ typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + + using type = d64_t; + + union + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) 
+ { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + + using type = d128_t; + + union + { + d128_t d128_; + StaticallyIndexedArray d1x128_; + StaticallyIndexedArray d2x64_; + StaticallyIndexedArray d4x32_; + StaticallyIndexedArray d8x16_; + StaticallyIndexedArray d16x8_; + StaticallyIndexedArray d32x4_; + StaticallyIndexedArray d64x2_; + StaticallyIndexedArray d128x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || 
+ is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + typedef T d256_t __attribute__((ext_vector_type(256))); + + using type = d256_t; + + union + { + d256_t d256_; + StaticallyIndexedArray d1x256_; + StaticallyIndexedArray d2x128_; + StaticallyIndexedArray d4x64_; + StaticallyIndexedArray d8x32_; + StaticallyIndexedArray d16x16_; + StaticallyIndexedArray d32x8_; + StaticallyIndexedArray d64x4_; + StaticallyIndexedArray d128x2_; + StaticallyIndexedArray d256x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } 
+ else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + } +}; + +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; +using float16_t = typename vector_type::type; +using float32_t = typename vector_type::type; +using float64_t = typename vector_type::type; + +// fp16 +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using half8_t = typename vector_type::type; +using half16_t = typename vector_type::type; +using half32_t = typename vector_type::type; +using half64_t = typename vector_type::type; + +// bfp16 +using ushort2_t = typename vector_type::type; +using ushort4_t 
= typename vector_type::type; +using ushort8_t = typename vector_type::type; +using ushort16_t = typename vector_type::type; +using ushort32_t = typename vector_type::type; +using ushort64_t = typename vector_type::type; + +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; +using int32x16_t = typename vector_type::type; +using int32x32_t = typename vector_type::type; +using int32x64_t = typename vector_type::type; + +// i8 +using int8x2_t = typename vector_type::type; +using int8x4_t = typename vector_type::type; +using int8x8_t = typename vector_type::type; +using int8x16_t = typename vector_type::type; +using int8x32_t = typename vector_type::type; +using int8x64_t = typename vector_type::type; + +// data type conversion +template +struct type_convert +{ + template + __device__ T operator()(X x) const + { + return static_cast(x); + } +}; + +template <> +template <> +__device__ float type_convert::operator()(ushort x) const +{ + return bfloat16_to_float(x); +} + +template <> +template <> +__device__ ushort type_convert::operator()(float x) const +{ + return float_to_bfloat16(x); +} + +// TODO: deprecate this +template +struct inner_product_with_conversion +{ + static constexpr auto convert = type_convert(); + + template + __device__ T operator()(typename vector_type::type a, + typename vector_type::type b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, N, 1>{}([&](auto i) { + acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); + }); + + return acc; + } + + __device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); } + + __device__ T operator()(int8x4_t a, int8x4_t b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, 4, 1>{}([&](auto i) { + acc += convert(a_vector.AsType()[i]) * 
convert(b_vector.AsType()[i]); + }); + + return acc; + } + + __device__ T operator()(int8x8_t a, int8x8_t b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, 8, 1>{}([&](auto i) { + acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + }); + + return acc; + } + + __device__ T operator()(int8x16_t a, int8x16_t b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, 16, 1>{}([&](auto i) { + acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + }); + + return acc; + } +}; + template struct NumericLimits; diff --git a/composable_kernel/include/utility/data_type_enum.hpp b/composable_kernel/include/utility/data_type_enum.hpp new file mode 100644 index 0000000000..fba380a5fc --- /dev/null +++ b/composable_kernel/include/utility/data_type_enum.hpp @@ -0,0 +1,19 @@ +#ifndef CK_DATA_TYPE_ENUM_HPP +#define CK_DATA_TYPE_ENUM_HPP + +namespace ck { + +// this enumerate should be synchronized with include/miopen.h +typedef enum { + Half = 0, + Float = 1, + Int32 = 2, + Int8 = 3, + Int8x4 = 4, + BFloat16 = 5, + Double = 6, + Unknown = 100, +} DataTypeEnum_t; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/data_type_helper.hpp b/composable_kernel/include/utility/data_type_helper.hpp new file mode 100644 index 0000000000..6a234cd10b --- /dev/null +++ b/composable_kernel/include/utility/data_type_helper.hpp @@ -0,0 +1,76 @@ +#ifndef CK_DATA_TYPE_HELPER_HPP +#define CK_DATA_TYPE_HELPER_HPP + +#include "data_type.hpp" +#include "data_type_enum.hpp" + +namespace ck { + +template +struct get_datatype_from_enum; + +template <> +struct get_datatype_from_enum +{ + using type = int8_t; +}; + +template <> +struct get_datatype_from_enum +{ + using type = int32_t; +}; + +template <> +struct get_datatype_from_enum +{ + using type = half_t; +}; + +template <> +struct get_datatype_from_enum +{ + using type = float; +}; 
+ +template <> +struct get_datatype_from_enum +{ + using type = double; +}; + +template +struct get_datatype_enum_from_type; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int8; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int32; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Half; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Float; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Double; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp index 073577b93f..5f5f386306 100644 --- a/composable_kernel/include/utility/dynamic_buffer.hpp +++ b/composable_kernel/include/utility/dynamic_buffer.hpp @@ -5,7 +5,7 @@ namespace ck { #include "amd_buffer_addressing_v2.hpp" -template +template struct DynamicBuffer { using type = T; @@ -18,7 +18,7 @@ struct DynamicBuffer { } - __host__ __device__ static constexpr AddressSpace GetAddressSpace() + __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() { return BufferAddressSpace; } @@ -32,7 +32,7 @@ struct DynamicBuffer is_same>>::type, typename scalar_type>>::type>::value, bool>::type = false> - __host__ __device__ constexpr const auto Get(index_t i, bool is_valid_offset) const + __host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const { // X contains multiple T constexpr index_t scalar_per_t_vector = @@ -46,7 +46,7 @@ struct DynamicBuffer constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - if constexpr(GetAddressSpace() == AddressSpace::Global) + if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global) { #if 
CK_USE_AMD_BUFFER_ADDRESSING return amd_buffer_load_v2>, t_per_x>( @@ -80,7 +80,7 @@ struct DynamicBuffer constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - if constexpr(GetAddressSpace() == AddressSpace::Global) + if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global) { #if CK_USE_AMD_BUFFER_ADDRESSING amd_buffer_store_v2>, t_per_x>( @@ -92,14 +92,15 @@ struct DynamicBuffer } #endif } - else if constexpr(GetAddressSpace() == AddressSpace::Lds) + else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds) { if(is_valid_offset) { #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE *reinterpret_cast(&p_data_[i]) = x; #else - // HACK: compiler would lower IR "store address_space(3)" into inefficient + // HACK: compiler would lower IR "store address_space(3)" into + // inefficient // ISA, so I try to let compiler emit IR "store" which would be lower to // ds_write_b128 // TODO: remove this after compiler fix @@ -119,7 +120,8 @@ struct DynamicBuffer is_same>, int8x8_t>::value) || (is_same>, int8x16_t>::value && is_same>, int8x16_t>::value), - "wrong! not implemented for this combination, please add implementation"); + "wrong! 
not implemented for this combination, please add " + "implementation"); if constexpr(is_same>, int8_t>::value && is_same>, int8_t>::value) @@ -194,7 +196,7 @@ struct DynamicBuffer __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } }; -template __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) diff --git a/composable_kernel/include/utility/float_type.hpp b/composable_kernel/include/utility/float_type.hpp deleted file mode 100644 index f41bd6db23..0000000000 --- a/composable_kernel/include/utility/float_type.hpp +++ /dev/null @@ -1,999 +0,0 @@ -#ifndef CK_FLOAT_TYPE_AMD_HPP -#define CK_FLOAT_TYPE_AMD_HPP - -#include "statically_indexed_array.hpp" - -namespace ck { - -using half_t = _Float16; - -// vector_type -template -struct vector_type; - -// Caution: DO NOT REMOVE -// intentionally have only declaration but no definition to cause compilation failure when trying to -// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of -// vectors" -template -struct vector_type; - -// Caution: DO NOT REMOVE -// intentionally have only declaration but no definition to cause compilation failure when trying to -// instantiate this template. 
The purpose is to catch user's mistake when trying to make "vector of -// vectors" -template -struct vector_type, N>; - -// vector_type_maker -// This is the right way to handle "vector of vectors": making a bigger vector instead -template -struct vector_type_maker -{ - using type = vector_type; -}; - -template -struct vector_type_maker -{ - using type = vector_type; -}; - -template -struct vector_type_maker, N0> -{ - using type = vector_type; -}; - -template -using vector_type_maker_t = typename vector_type_maker::type; - -template -__host__ __device__ constexpr auto make_vector_type(Number) -{ - return typename vector_type_maker::type{}; -} - -// scalar_type -template -struct scalar_type; - -template -struct scalar_type -{ - using type = T; - static constexpr index_t vector_size = N; -}; - -template -struct scalar_type> -{ - using type = T; - static constexpr index_t vector_size = N; -}; - -// -template <> -struct scalar_type -{ - using type = float; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = half_t; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = ushort; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = int32_t; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = int8_t; - static constexpr index_t vector_size = 1; -}; - -// -template -struct vector_type -{ - using d1_t = T; - using type = d1_t; - - union - { - T d1_; - StaticallyIndexedArray d1x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value, "wrong!"); - - return data_.d1x1_; - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value, "wrong!"); - - return 
data_.d1x1_; - } -}; - -template -struct vector_type -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - - using type = d2_t; - - union - { - d2_t d2_; - StaticallyIndexedArray d1x2_; - StaticallyIndexedArray d2x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value, "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value, "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - } -}; - -template -struct vector_type -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - - using type = d4_t; - - union - { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; 
- } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - } -}; - -template -struct vector_type -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - - using type = d8_t; - - union - { - d8_t d8_; - StaticallyIndexedArray d1x8_; - StaticallyIndexedArray d2x4_; - StaticallyIndexedArray d4x2_; - StaticallyIndexedArray d8x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - } -}; - -template -struct vector_type -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - - using type = d16_t; - - union - { - d16_t d16_; - StaticallyIndexedArray d1x16_; - StaticallyIndexedArray d2x8_; - StaticallyIndexedArray d4x4_; - StaticallyIndexedArray d8x2_; - StaticallyIndexedArray d16x1_; - } data_; - - 
__host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - } -}; - -template -struct vector_type -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - - using type = d32_t; - - union - { - d32_t d32_; - StaticallyIndexedArray d1x32_; - StaticallyIndexedArray d2x16_; - StaticallyIndexedArray d4x8_; - StaticallyIndexedArray d8x4_; - StaticallyIndexedArray d16x2_; - StaticallyIndexedArray d32x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - 
static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - } -}; - -template -struct vector_type -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - - using type = d64_t; - - union - { - d64_t d64_; - StaticallyIndexedArray d1x64_; - StaticallyIndexedArray d2x32_; - StaticallyIndexedArray d4x16_; - StaticallyIndexedArray d8x8_; - StaticallyIndexedArray d16x4_; - StaticallyIndexedArray d32x2_; - StaticallyIndexedArray d64x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ 
constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - } -}; - -template -struct vector_type -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - typedef T d128_t __attribute__((ext_vector_type(128))); - - using type = d128_t; - - union - { - d128_t d128_; - StaticallyIndexedArray d1x128_; - StaticallyIndexedArray d2x64_; - StaticallyIndexedArray d4x32_; - StaticallyIndexedArray d8x16_; - 
StaticallyIndexedArray d16x8_; - StaticallyIndexedArray d32x4_; - StaticallyIndexedArray d64x2_; - StaticallyIndexedArray d128x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x128_; - } - else if constexpr(is_same::value) - { - return data_.d2x64_; - } - else if constexpr(is_same::value) - { - return data_.d4x32_; - } - else if constexpr(is_same::value) - { - return data_.d8x16_; - } - else if constexpr(is_same::value) - { - return data_.d16x8_; - } - else if constexpr(is_same::value) - { - return data_.d32x4_; - } - else if constexpr(is_same::value) - { - return data_.d64x2_; - } - else if constexpr(is_same::value) - { - return data_.d128x1_; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x128_; - } - else if constexpr(is_same::value) - { - return data_.d2x64_; - } - else if constexpr(is_same::value) - { - return data_.d4x32_; - } - else if constexpr(is_same::value) - { - return data_.d8x16_; - } - else if constexpr(is_same::value) - { - return data_.d16x8_; - } - else if constexpr(is_same::value) - { - return data_.d32x4_; - } - else if constexpr(is_same::value) - { - return data_.d64x2_; - } - else if constexpr(is_same::value) - { - return data_.d128x1_; - } - } -}; - -template -struct vector_type -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t 
__attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - typedef T d128_t __attribute__((ext_vector_type(128))); - typedef T d256_t __attribute__((ext_vector_type(256))); - - using type = d256_t; - - union - { - d256_t d256_; - StaticallyIndexedArray d1x256_; - StaticallyIndexedArray d2x128_; - StaticallyIndexedArray d4x64_; - StaticallyIndexedArray d8x32_; - StaticallyIndexedArray d16x16_; - StaticallyIndexedArray d32x8_; - StaticallyIndexedArray d64x4_; - StaticallyIndexedArray d128x2_; - StaticallyIndexedArray d256x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert( - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x256_; - } - else if constexpr(is_same::value) - { - return data_.d2x128_; - } - else if constexpr(is_same::value) - { - return data_.d4x64_; - } - else if constexpr(is_same::value) - { - return data_.d8x32_; - } - else if constexpr(is_same::value) - { - return data_.d16x16_; - } - else if constexpr(is_same::value) - { - return data_.d32x8_; - } - else if constexpr(is_same::value) - { - return data_.d64x4_; - } - else if constexpr(is_same::value) - { - return data_.d128x2_; - } - else if constexpr(is_same::value) - { - return data_.d256x1_; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert( - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || 
is_same::value, - "wrong!"); - - if constexpr(is_same::value) - { - return data_.d1x256_; - } - else if constexpr(is_same::value) - { - return data_.d2x128_; - } - else if constexpr(is_same::value) - { - return data_.d4x64_; - } - else if constexpr(is_same::value) - { - return data_.d8x32_; - } - else if constexpr(is_same::value) - { - return data_.d16x16_; - } - else if constexpr(is_same::value) - { - return data_.d32x8_; - } - else if constexpr(is_same::value) - { - return data_.d64x4_; - } - else if constexpr(is_same::value) - { - return data_.d128x2_; - } - else if constexpr(is_same::value) - { - return data_.d256x1_; - } - } -}; - -// fp32 -using float2_t = typename vector_type::type; -using float4_t = typename vector_type::type; -using float8_t = typename vector_type::type; -using float16_t = typename vector_type::type; -using float32_t = typename vector_type::type; -using float64_t = typename vector_type::type; - -// fp16 -using half2_t = typename vector_type::type; -using half4_t = typename vector_type::type; -using half8_t = typename vector_type::type; -using half16_t = typename vector_type::type; -using half32_t = typename vector_type::type; -using half64_t = typename vector_type::type; - -// bfp16 -using ushort2_t = typename vector_type::type; -using ushort4_t = typename vector_type::type; -using ushort8_t = typename vector_type::type; -using ushort16_t = typename vector_type::type; -using ushort32_t = typename vector_type::type; -using ushort64_t = typename vector_type::type; - -// i32 -using int32x2_t = typename vector_type::type; -using int32x4_t = typename vector_type::type; -using int32x8_t = typename vector_type::type; -using int32x16_t = typename vector_type::type; -using int32x32_t = typename vector_type::type; -using int32x64_t = typename vector_type::type; - -// i8 -using int8x2_t = typename vector_type::type; -using int8x4_t = typename vector_type::type; -using int8x8_t = typename vector_type::type; -using int8x16_t = typename 
vector_type::type; -using int8x32_t = typename vector_type::type; -using int8x64_t = typename vector_type::type; - -// data type conversion -template -struct type_convert -{ - template - __device__ T operator()(X x) const - { - return static_cast(x); - } -}; - -template <> -template <> -__device__ float type_convert::operator()(ushort x) const -{ - return bfloat16_to_float(x); -} - -template <> -template <> -__device__ ushort type_convert::operator()(float x) const -{ - return float_to_bfloat16(x); -} - -template -struct inner_product_with_conversion -{ - static constexpr auto convert = type_convert(); - - template - __device__ T operator()(typename vector_type::type a, - typename vector_type::type b) const - { - const vector_type a_vector{a}; - const vector_type b_vector{b}; - - T acc = 0; - - static_for<0, N, 1>{}([&](auto i) { - acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); - }); - - return acc; - } - - __device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); } - - __device__ T operator()(int8x4_t a, int8x4_t b) const - { - const vector_type a_vector{a}; - const vector_type b_vector{b}; - - T acc = 0; - - static_for<0, 4, 1>{}([&](auto i) { - acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); - }); - - return acc; - } - - __device__ T operator()(int8x8_t a, int8x8_t b) const - { - const vector_type a_vector{a}; - const vector_type b_vector{b}; - - T acc = 0; - - static_for<0, 8, 1>{}([&](auto i) { - acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); - }); - - return acc; - } - - __device__ T operator()(int8x16_t a, int8x16_t b) const - { - const vector_type a_vector{a}; - const vector_type b_vector{b}; - - T acc = 0; - - static_for<0, 16, 1>{}([&](auto i) { - acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); - }); - - return acc; - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/magic_division.hpp 
b/composable_kernel/include/utility/magic_division.hpp index 3239205fda..b7489016e9 100644 --- a/composable_kernel/include/utility/magic_division.hpp +++ b/composable_kernel/include/utility/magic_division.hpp @@ -127,7 +127,8 @@ struct MagicDivision DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) { uint32_t dividend_u32 = as_type(dividend_i32); - uint32_t tmp = ((uint64_t)dividend_u32 * (uint64_t)multiplier) >> 32; + uint32_t tmp = + (static_cast(dividend_u32) * static_cast(multiplier)) >> 32; return (tmp + dividend_u32) >> shift; } #else diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp index 368a955ab3..e451059647 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/composable_kernel/include/utility/math.hpp @@ -150,7 +150,15 @@ __host__ __device__ constexpr auto min(X x, Ys... ys) // greatest common divisor, aka highest common factor __host__ __device__ constexpr index_t gcd(index_t x, index_t y) { - if(x == y || x == 0) + if(x < 0) + { + return gcd(-x, y); + } + else if(y < 0) + { + return gcd(x, -y); + } + else if(x == y || x == 0) { return y; } @@ -160,11 +168,11 @@ __host__ __device__ constexpr index_t gcd(index_t x, index_t y) } else if(x > y) { - return gcd(x - y, y); + return gcd(x % y, y); } else { - return gcd(x, y - x); + return gcd(x, y % x); } } @@ -181,7 +189,7 @@ template = 2, bool>::type = false> __host__ __device__ constexpr auto gcd(X x, Ys... 
ys) { - return gcd(x, ys...); + return gcd(x, gcd(ys...)); } // least common multiple diff --git a/composable_kernel/include/utility/static_buffer.hpp b/composable_kernel/include/utility/static_buffer.hpp index 95fd08e880..a23cf4f80d 100644 --- a/composable_kernel/include/utility/static_buffer.hpp +++ b/composable_kernel/include/utility/static_buffer.hpp @@ -5,7 +5,7 @@ namespace ck { -template +template struct StaticBuffer : public StaticallyIndexedArray { using type = T; @@ -13,7 +13,7 @@ struct StaticBuffer : public StaticallyIndexedArray __host__ __device__ constexpr StaticBuffer() : base{} {} - __host__ __device__ static constexpr AddressSpace GetAddressSpace() + __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() { return BufferAddressSpace; } @@ -23,7 +23,9 @@ struct StaticBuffer : public StaticallyIndexedArray __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } }; -template +template __host__ __device__ constexpr auto make_static_buffer(Number) { return StaticBuffer{}; diff --git a/composable_kernel/include/utility/synchronization.hpp b/composable_kernel/include/utility/synchronization.hpp index 4e899baa95..da74f2074d 100644 --- a/composable_kernel/include/utility/synchronization.hpp +++ b/composable_kernel/include/utility/synchronization.hpp @@ -5,8 +5,6 @@ namespace ck { -__device__ void __llvm_amdgcn_s_barrier() __asm("llvm.amdgcn.s.barrier"); - __device__ void block_sync_lds() { #if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM @@ -15,11 +13,9 @@ __device__ void block_sync_lds() s_barrier \ " ::); #else - __llvm_amdgcn_s_barrier(); + __syncthreads(); #endif } -__device__ void block_sync_lds_vmem() { __llvm_amdgcn_s_barrier(); } - } // namespace ck #endif diff --git a/composable_kernel/include/utility/type_helper.hpp b/composable_kernel/include/utility/type_helper.hpp deleted file mode 100644 index 987f07e3f4..0000000000 --- a/composable_kernel/include/utility/type_helper.hpp +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef 
CK_TYPE_HELPER_HPP -#define CK_TYPE_HELPER_HPP - -#include "float_type.hpp" - -namespace ck { - -template -struct get_type_from_type_id -{ - using type = float; -}; - -template <> -struct get_type_from_type_id<'H'> -{ - using type = half_t; -}; - -template <> -struct get_type_from_type_id<'F'> -{ - using type = float; -}; - -template <> -struct get_type_from_type_id<'D'> -{ - using type = double; -}; - -} // namespace ck - -#endif diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp similarity index 67% rename from composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp rename to composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp index 8dc473ec3f..652ccdb926 100644 --- a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp +++ b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp @@ -1,15 +1,18 @@ #include "common_header.hpp" -#include "type_helper.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_v1r2.hpp" +#include "gridwise_dynamic_gemm_dlops_v1r2.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" using namespace ck; -using FloatAB = typename get_type_from_type_id(CK_PARAM_IN_WEI_DATATYPE)>::type; -using FloatC = typename get_type_from_type_id(CK_PARAM_OUT_DATATYPE)>::type; -using FloatAcc = typename get_type_from_type_id(CK_PARAM_CONV_COMPTYPE)>::type; +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = 
static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; constexpr index_t BlockSize = CK_PARAM_BlockSize; @@ -61,7 +64,8 @@ constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDs constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); -extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw_prepare( +extern "C" __global__ void +dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare( int n, int c, int hi, @@ -147,48 +151,48 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v4r4_nchw_k using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using GridwiseGemm = - GridwiseDynamicGemm_km_kn_mn_v1r2; + GridwiseDynamicGemmDlops_km_kn_mn_v1r2; auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); @@ -212,14 +216,14 @@ extern "C" __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( + dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - const void __CONSTANT__* p_a_k_m0_m1_grid_desc, - const void __CONSTANT__* p_b_k_n0_n1_grid_desc, - const void __CONSTANT__* p_c_m0_m10_m11_n0_n10_n11_grid_desc, - const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor) + const void CONSTANT* p_a_k_m0_m1_grid_desc, + const void CONSTANT* p_b_k_n0_n1_grid_desc, + const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + const 
void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -283,48 +287,48 @@ extern "C" __global__ void using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; using GridwiseGemm = - GridwiseDynamicGemm_km_kn_mn_v1r2; + GridwiseDynamicGemmDlops_km_kn_mn_v1r2; constexpr auto a_k_m0_m1_grid_desc_tmp = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp index b9a835336b..d33bc74aa6 100644 --- a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp +++ b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp @@ -1,5 +1,4 @@ #include "common_header.hpp" -#include "type_helper.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" #include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" @@ -7,9 +6,13 @@ using namespace ck; -using FloatAB = typename get_type_from_type_id(CK_PARAM_IN_WEI_DATATYPE)>::type; -using FloatC = typename get_type_from_type_id(CK_PARAM_OUT_DATATYPE)>::type; -using FloatAcc = typename get_type_from_type_id(CK_PARAM_CONV_COMPTYPE)>::type; +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; constexpr index_t BlockSize = CK_PARAM_BlockSize; @@ -149,7 +152,7 @@ 
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare( FloatAB, FloatAcc, FloatC, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, AK0MK1GridDesc, BK0NK1GridDesc, CMNGridDesc, @@ -213,10 +216,10 @@ extern "C" __global__ void const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - const void __CONSTANT__* p_a_k0_m_k1_grid_desc, - const void __CONSTANT__* p_b_k0_n_k1_grid_desc, - const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc, - const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor) + const void CONSTANT* p_a_k0_m_k1_grid_desc, + const void CONSTANT* p_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) { constexpr auto I0 = Number<0>{}; @@ -286,7 +289,7 @@ extern "C" __global__ void FloatAB, FloatAcc, FloatC, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, AK0MK1GridDesc, BK0NK1GridDesc, CMNGridDesc, diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp index 9e8de0ac8e..d49693b511 100644 --- a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp +++ b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp @@ -1,5 +1,4 @@ #include "common_header.hpp" -#include "type_helper.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" #include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" @@ -7,9 +6,13 @@ using namespace ck; -using FloatAB = typename get_type_from_type_id(CK_PARAM_IN_WEI_DATATYPE)>::type; -using FloatC = typename get_type_from_type_id(CK_PARAM_OUT_DATATYPE)>::type; -using FloatAcc = typename 
get_type_from_type_id(CK_PARAM_CONV_COMPTYPE)>::type; +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; constexpr index_t BlockSize = CK_PARAM_BlockSize; @@ -149,7 +152,7 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare( FloatAB, FloatAcc, FloatC, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, AK0MK1GridDesc, BK0NK1GridDesc, CMNGridDesc, @@ -213,10 +216,10 @@ extern "C" __global__ void const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - const void __CONSTANT__* p_a_k0_m_k1_grid_desc, - const void __CONSTANT__* p_b_k0_n_k1_grid_desc, - const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc, - const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor) + const void CONSTANT* p_a_k0_m_k1_grid_desc, + const void CONSTANT* p_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) { constexpr auto I0 = Number<0>{}; @@ -287,7 +290,7 @@ extern "C" __global__ void FloatAB, FloatAcc, FloatC, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, AK0MK1GridDesc, BK0NK1GridDesc, CMNGridDesc, diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp similarity index 77% rename from composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.cpp rename to 
composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp index 93a3bb39a0..90c957bb0b 100644 --- a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.cpp +++ b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp @@ -1,31 +1,34 @@ #include "common_header.hpp" -#include "type_helper.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_contraction_v1r2.hpp" +#include "gridwise_dynamic_contraction_dlops_v1r2.hpp" #include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" using namespace ck; -using FloatAB = typename get_type_from_type_id(CK_PARAM_IN_WEI_DATATYPE)>::type; -using FloatAcc = typename get_type_from_type_id(CK_PARAM_ACC_DATATYPE)>::type; -using FloatC = typename get_type_from_type_id(CK_PARAM_OUT_DATATYPE)>::type; +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; constexpr index_t BlockSize = CK_PARAM_BlockSize; constexpr auto GN0 = Number{}; constexpr auto GK1 = Number{}; -constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11; -constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11; -constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock; -constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11; -constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11; -constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread; -constexpr index_t BM10BN10ThreadClusterBM100 = CK_PARAM_BM10BN10ThreadClusterBM100; -constexpr index_t BM10BN10ThreadClusterBN100 = 
CK_PARAM_BM10BN10ThreadClusterBN100; -constexpr index_t BM10BN10ThreadClusterBM101 = CK_PARAM_BM10BN10ThreadClusterBM101; -constexpr index_t BM10BN10ThreadClusterBN101 = CK_PARAM_BM10BN10ThreadClusterBN101; +constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11; +constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11; +constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock; + +constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11; +constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11; +constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread; + +using BM10BN10ThreadClusterBM10Xs = Sequence; +using BM10BN10ThreadClusterBN10Xs = Sequence; using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence; @@ -55,29 +58,26 @@ using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2> constexpr index_t CThreadTransferSrcDstVectorDim = 5; constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; -constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); -constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); +constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HasMainKBlockLoop); +constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HasDoubleTailKBlockLoop); -extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw_prepare( - index_t N, - index_t C, - index_t Hi, - index_t Wi, - index_t K, - index_t Y, - index_t X, - index_t ConvStrideH, - index_t ConvStrideW, - index_t ConvDilationH, - index_t ConvDilationW, - index_t InLeftPadH, - index_t InLeftPadW, - index_t InRightPadH, - index_t InRightPadW, - void* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1, - void* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1, - void* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - void* p_c_grid_block_cluster_blockid_to_gm10_gn10) +extern "C" __global__ void 
+dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, + index_t C, + index_t Hi, + index_t Wi, + index_t K, + index_t Y, + index_t X, + index_t ConvStrideH, + index_t ConvStrideW, + index_t ConvDilationH, + index_t ConvDilationW, + index_t InLeftPadH, + index_t InLeftPadW, + index_t InRightPadH, + index_t InRightPadW, + void* p_desc_tuple) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -160,12 +160,12 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; using GridwiseContraction = - GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< BlockSize, FloatAB, FloatAcc, FloatC, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, AGridDesc_GK0_GM0_GM1_GK1, BGridDesc_GK0_GN0_GN1_GK1, CGridDesc_GM0_GM1_GN0_GN1, @@ -175,10 +175,8 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k BM1PerThreadBM11, BN1PerThreadBN11, BK0PerThread, - BM10BN10ThreadClusterBM100, - BM10BN10ThreadClusterBN100, - BM10BN10ThreadClusterBM101, - BM10BN10ThreadClusterBN101, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterArrangeOrder, @@ -202,47 +200,36 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k AGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks>; - auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = - GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1); - auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = - GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1); - auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = - 
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( - c_grid_desc_gm0_gm1_gn0_gn1); - auto c_grid_block_cluster_blockid_to_gm10_gn10 = - GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( - c_grid_desc_gm0_gm1_gn0_gn1); - - if(hipThreadIdx_x == 0) + if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0) { - *static_cast( - p_a_grid_desc_gk0_gm0_gm10_gm11_gk1) = a_grid_desc_gk0_gm0_gm10_gm11_gk1; - *static_cast( - p_b_grid_desc_gk0_gn0_gn10_gn11_gk1) = b_grid_desc_gk0_gn0_gn10_gn11_gk1; - *static_cast( - p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1) = c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1; - *static_cast( - p_c_grid_block_cluster_blockid_to_gm10_gn10) = - c_grid_block_cluster_blockid_to_gm10_gn10; - }; + auto desc_tuple = + make_tuple(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( + a_grid_desc_gk0_gm0_gm1_gk1), + GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( + b_grid_desc_gk0_gn0_gn1_gk1), + GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + c_grid_desc_gm0_gm1_gn0_gn1), + GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( + c_grid_desc_gm0_gm1_gn0_gn1)); + + *static_cast(p_desc_tuple) = desc_tuple; + } }; extern "C" __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( + dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - const void __CONSTANT__* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1, - const void __CONSTANT__* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1, - const void __CONSTANT__* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - const void __CONSTANT__* p_c_grid_block_cluster_blockid_to_gm10_gn10) + const void CONSTANT* p_desc_tuple) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; + 
constexpr auto I3 = Number<3>{}; constexpr auto in_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); @@ -316,12 +303,12 @@ extern "C" __global__ void Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; using GridwiseContraction = - GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< BlockSize, FloatAB, FloatAcc, FloatC, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, AGridDesc_GK0_GM0_GM1_GK1, BGridDesc_GK0_GN0_GN1_GK1, CGridDesc_GM0_GM1_GN0_GN1, @@ -331,10 +318,8 @@ extern "C" __global__ void BM1PerThreadBM11, BN1PerThreadBN11, BK0PerThread, - BM10BN10ThreadClusterBM100, - BM10BN10ThreadClusterBN100, - BM10BN10ThreadClusterBM101, - BM10BN10ThreadClusterBN101, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterArrangeOrder, @@ -371,18 +356,23 @@ extern "C" __global__ void decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( c_grid_desc_gm0_gm1_gn0_gn1)); - const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = - *reinterpret_cast( - (const void*)p_a_grid_desc_gk0_gm0_gm10_gm11_gk1); - const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = - *reinterpret_cast( - (const void*)p_b_grid_desc_gk0_gn0_gn10_gn11_gk1); - const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = - *reinterpret_cast( - (const void*)p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1); - const auto c_grid_block_cluster_blockid_to_gm10_gn10 = - *reinterpret_cast( - (const void*)p_c_grid_block_cluster_blockid_to_gm10_gn10); + using DescTuple = decltype(make_tuple(AGridDesc_GK0_GM0_GM10_GM11_GK1{}, + BGridDesc_GK0_GN0_GN10_GN11_GK1{}, + CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{}, + CGridBlockCluster_BlockId_To_GM10_GN10{})); + + const auto desc_tuple = *reinterpret_cast( 
+#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + // TODO: how to cast? + (const void*)p_desc_tuple +#pragma clang diagnostic pop + ); + + const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0]; + const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1]; + const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = desc_tuple[I2]; + const auto c_grid_block_cluster_blockid_to_gm10_gn10 = desc_tuple[I3]; constexpr index_t shared_block_size = GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/host/driver_offline/conv_fwd_driver_offline.cpp b/host/driver_offline/conv_fwd_driver_offline.cpp index 405d6e7c40..ef2e16c4fa 100644 --- a/host/driver_offline/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/conv_fwd_driver_offline.cpp @@ -12,17 +12,17 @@ #include "conv_common.hpp" #include "host_conv.hpp" #include "device_tensor.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #define USE_DYNAMIC_MODE 1 #define USE_CONV_FWD_V4R4_NCHW 1 -#define USE_CONV_FWD_V4R4R2_NHWC 0 -#define USE_CONV_FWD_V6R1_NCHW 0 +#define USE_CONV_FWD_V4R4R2_NHWC 1 +#define 
USE_CONV_FWD_V6R1_NCHW 1 #define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 @@ -301,19 +301,20 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nchw(); - device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); + device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); } #endif @@ -327,9 +328,9 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nhwc(); - device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( + device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( tmp[I0], tmp[I1], tmp[I2], @@ -354,19 +355,20 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nchw(); - device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); + device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); } #endif @@ -380,20 +382,21 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nchw(); - device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); + device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); } #endif diff --git a/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp 
b/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index 0ea190611b..49e0223b33 100644 --- a/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -264,7 +264,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx TInWei, TAcc, TOut, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp index 315f201458..ce4dd155f6 100644 --- a/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -236,7 +236,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k TInWei, TAcc, TOut, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp similarity index 97% rename from host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp rename to 
host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp index 845095b947..24ba775309 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp @@ -2,7 +2,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "driver_dynamic_gemm_v1r2.hpp" +#include "driver_dynamic_gemm_dlops_v1r2.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( +void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( const InLengths& in_n_c_hi_wi_lengths, const WeiLengths& wei_k_c_y_x_lengths, const OutLengths& out_n_k_ho_wo_lengths, @@ -142,12 +142,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_dynamic_gemm_v1r2< + float ave_time = driver_dynamic_gemm_dlops_v1r2< BlockSize, TInWei, TAcc, TOut, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, decltype(wei_gemmk_gemmm_grid_desc), decltype(in_gemmk_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp index 5890b12e00..b6b1cc8969 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -220,7 +220,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw TInWei, TAcc, TOut, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, 
decltype(descs[I0]), decltype(descs[I1]), decltype(descs[I2]), diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp similarity index 92% rename from host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp rename to host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp index ea0ada9f88..cdd1084c0d 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp @@ -2,7 +2,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" -#include "driver_dynamic_gemm_v1r3.hpp" +#include "driver_dynamic_gemm_dlops_v1r3.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( +void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( const InLengths& in_n_hi_wi_c_lengths, const WeiLengths& wei_k_y_x_c_lengths, const OutLengths& out_n_ho_wo_k_lengths, @@ -56,7 +56,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( const auto out_n_ho_wo_k_desc = make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); -#if 0 +#if 1 // [M, N, K0, K1] = [128, 128, 8, 1] for fp32 // cdata = 64, BlockSize = 256 constexpr index_t BlockSize = 256; @@ -70,10 +70,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( constexpr index_t GemmN1PerThreadN111 = 4; constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmM11N11ThreadClusterM1100 = 8; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr 
index_t GemmM11N11ThreadClusterN1101 = 2; + using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; + using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>; using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; @@ -102,10 +100,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( constexpr index_t GemmN1PerThreadN111 = 4; constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmM11N11ThreadClusterM1100 = 8; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; + using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>; using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; @@ -134,10 +130,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( constexpr index_t GemmN1PerThreadN111 = 4; constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmM11N11ThreadClusterM1100 = 8; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; + using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>; using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; @@ -211,12 +205,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_dynamic_gemm_v1r3< + float ave_time = driver_dynamic_gemm_dlops_v1r3< BlockSize, TInWei, TAcc, TOut, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, 
decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmm_gemmn_grid_desc), @@ -226,10 +220,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( GemmM1PerThreadM111, GemmN1PerThreadN111, GemmKPerThread, - GemmM11N11ThreadClusterM1100, - GemmM11N11ThreadClusterN1100, - GemmM11N11ThreadClusterM1101, - GemmM11N11ThreadClusterN1101, + GemmM11N11ThreadClusterM110Xs, + GemmM11N11ThreadClusterN110Xs, GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1, GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1, Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index 035546d31a..b56cbc0335 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -145,7 +145,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk TInWei, TAcc, TOut, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp index bb37ac309f..10284b48f3 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -165,7 +165,7 @@ void 
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh TInWei, TAcc, TOut, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp index c1e63664e5..f2a30fb525 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp @@ -229,7 +229,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh TInWei, TAcc, TOut, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmm_gemmn_grid_desc), diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 0455f77718..601878c347 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -288,7 +288,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh TInWei, TAcc, TOut, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmm_gemmn_grid_desc), diff --git 
a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp similarity index 95% rename from host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp index cb2e7e5264..ca0d47c33a 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp @@ -1,8 +1,8 @@ #include #include "device.hpp" #include "host_tensor.hpp" -#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" -#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp" +#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" +#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( +void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( const InLengths& in_n_c_hi_wi_lengths, const WeiLengths& wei_k_c_y_x_lengths, const OutLengths& out_n_k_ho_wo_lengths, @@ -145,9 +145,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( constexpr auto conv_driver = #if 0 - DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad + DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad #else - DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad + DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad #endif ::type, diff --git 
a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp similarity index 93% rename from host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp index 0b45350234..8fb276b464 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -1,8 +1,9 @@ +#pragma once #include #include "device.hpp" #include "host_tensor.hpp" #include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" -#include "driver_dynamic_contraction_v1r2.hpp" +#include "driver_dynamic_contraction_dlops_v1r2.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( +void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( const InLengths& in_n_c_hi_wi_lengths, const WeiLengths& wei_k_c_y_x_lengths, const OutLengths& out_n_k_ho_wo_lengths, @@ -66,10 +67,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( constexpr index_t BN1PerThreadBN11 = 4; constexpr index_t BK0PerThread = 1; - constexpr index_t BM10BN10ThreadClusterBM100 = 8; - constexpr index_t BM10BN10ThreadClusterBN100 = 8; - constexpr index_t BM10BN10ThreadClusterBM101 = 2; - constexpr index_t BM10BN10ThreadClusterBN101 = 2; + using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>; + using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>; using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; @@ -100,10 +99,8 @@ void 
device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( constexpr index_t BN1PerThreadBN11 = 4; constexpr index_t BK0PerThread = 1; - constexpr index_t BM10BN10ThreadClusterBM100 = 8; - constexpr index_t BM10BN10ThreadClusterBN100 = 8; - constexpr index_t BM10BN10ThreadClusterBM101 = 2; - constexpr index_t BM10BN10ThreadClusterBN101 = 2; + using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>; + using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>; using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>; using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; @@ -183,12 +180,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_dynamic_contraction_v1r2< + float ave_time = driver_dynamic_contraction_dlops_v1r2< BlockSize, TInWei, TAcc, TOut, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, decltype(wei_grid_desc_gk0_gm0_gm1_gk1), decltype(in_grid_desc_gk0_gn0_gn1_gk1), decltype(out_grid_desc_gm0_gm1_gn0_gn1), @@ -198,10 +195,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( BM1PerThreadBM11, BN1PerThreadBN11, BK0PerThread, - BM10BN10ThreadClusterBM100, - BM10BN10ThreadClusterBN100, - BM10BN10ThreadClusterBM101, - BM10BN10ThreadClusterBN101, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder diff --git a/composable_kernel/include/driver/driver_dynamic_contraction_v1r2.hpp b/host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp similarity index 85% rename from composable_kernel/include/driver/driver_dynamic_contraction_v1r2.hpp rename to host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp index 2f68fec7e3..2f175962c1 100644 --- 
a/composable_kernel/include/driver/driver_dynamic_contraction_v1r2.hpp +++ b/host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp @@ -1,31 +1,27 @@ -#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP -#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP +#ifndef DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP +#define DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP #include "common_header.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_contraction_v1r2.hpp" +#include "gridwise_dynamic_contraction_dlops_v1r2.hpp" -namespace ck { - -template __host__ float -driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, - const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, - const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, - index_t nrepeat) +driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, + const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + ck::index_t nrepeat) { + using namespace ck; + constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; @@ -72,7 +70,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, // GEMM using GridwiseContraction = - GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< BlockSize, FloatAB, FloatAcc, @@ 
-87,10 +85,8 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, BM1PerThreadBM11, BN1PerThreadBN11, BK0PerThread, - BM10BN10ThreadClusterBM100, - BM10BN10ThreadClusterBN100, - BM10BN10ThreadClusterBM101, - BM10BN10ThreadClusterBN101, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterArrangeOrder, @@ -182,7 +178,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, if(has_main_k_block_loop && has_double_tail_k_block_loop) { - const auto kernel = kernel_dynamic_contraction_v1r2< + const auto kernel = kernel_dynamic_contraction_dlops_v1r2< GridwiseContraction, FloatAB, FloatC, @@ -209,7 +205,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, } else if(has_main_k_block_loop && !has_double_tail_k_block_loop) { - const auto kernel = kernel_dynamic_contraction_v1r2< + const auto kernel = kernel_dynamic_contraction_dlops_v1r2< GridwiseContraction, FloatAB, FloatC, @@ -236,7 +232,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, } else if(!has_main_k_block_loop && has_double_tail_k_block_loop) { - const auto kernel = kernel_dynamic_contraction_v1r2< + const auto kernel = kernel_dynamic_contraction_dlops_v1r2< GridwiseContraction, FloatAB, FloatC, @@ -263,7 +259,7 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, } else { - const auto kernel = kernel_dynamic_contraction_v1r2< + const auto kernel = kernel_dynamic_contraction_dlops_v1r2< GridwiseContraction, FloatAB, FloatC, @@ -291,6 +287,4 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, return ave_time; } - -} // namespace ck #endif diff --git a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp similarity index 92% rename from 
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp index 5ad7c0ca93..7c4b1043f3 100644 --- a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp @@ -1,33 +1,31 @@ -#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP -#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP +#ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP +#define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP #include "common_header.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_v2.hpp" +#include "gridwise_dynamic_gemm_dlops_v2.hpp" #include "gridwise_operation_wrapper.hpp" -namespace ck { - -template -struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad + ck::index_t ABlockTransferSrcScalarPerVector_E, + ck::index_t ABlockTransferDstScalarPerVector_K, + ck::index_t BThreadTransferSrcScalarPerVector_W, + ck::index_t CThreadTransferDstScalarPerVector_W> +struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad { template - __host__ void Run(const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const DynamicTensorDescriptor& out_n_k0_ho_wo_k1_global_desc, + __host__ void Run(const ck::DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const ck::DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const ck::DynamicTensorDescriptor& out_n_k0_ho_wo_k1_global_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ 
-47,6 +45,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad const FloatAB* __restrict__ p_in_global, FloatC* __restrict__ p_out_global) const { + using namespace ck; + constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; @@ -169,12 +169,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad #if 1 // GEMM - using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3< + using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3< BlockSize, FloatAB, FloatAcc, FloatC, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, decltype(wei_e_k_global_desc), decltype(in_e_n_ho_wo_global_desc), decltype(out_k_n_ho_wo_global_desc), @@ -349,5 +349,4 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad #endif } }; -} // namespace ck #endif diff --git a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp b/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp similarity index 92% rename from composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp rename to host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp index f7c24ead4d..b7f8e6039c 100644 --- a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp +++ b/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp @@ -1,33 +1,31 @@ -#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_OUTPAD_HPP -#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_OUTPAD_HPP +#ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP +#define 
DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP #include "common_header.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_v2.hpp" +#include "gridwise_dynamic_gemm_dlops_v2.hpp" #include "gridwise_operation_wrapper.hpp" -namespace ck { - -template -struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad + ck::index_t ABlockTransferSrcScalarPerVector_E, + ck::index_t ABlockTransferDstScalarPerVector_K, + ck::index_t BThreadTransferSrcScalarPerVector_W, + ck::index_t CThreadTransferDstScalarPerVector_W> +struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad { template - __host__ void Run(const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const DynamicTensorDescriptor& out_n_k0_ho_wo_k1_global_desc, + __host__ void Run(const ck::DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const ck::DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const ck::DynamicTensorDescriptor& out_n_k0_ho_wo_k1_global_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -47,6 +45,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad const FloatAB* __restrict__ p_in_global, FloatC* __restrict__ p_out_global) const { + using namespace ck; + constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; @@ -181,12 +181,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad Sequence<0, 0, 0, 0, 0>{})); // GEMM - using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3< + using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3< BlockSize, FloatAB, FloatAcc, FloatC, - InMemoryDataOperation::Set, + InMemoryDataOperationEnum_t::Set, decltype(wei_e_k_global_desc), decltype(in_e_n_ho_wo_global_desc), 
decltype(out_k_n_hop_wop_global_desc), @@ -364,5 +364,4 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad } } }; -} // namespace ck #endif diff --git a/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp new file mode 100644 index 0000000000..0ebc68b48a --- /dev/null +++ b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp @@ -0,0 +1,415 @@ +#ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R2 +#define DRIVER_DYNAMIC_GEMM_DLOPS_V1R2 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_dlops_v1r2.hpp" + +template +__host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AKMGridDesc& a_k_m_grid_desc, + const BKNGridDesc& b_k_n_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseGemm = + GridwiseDynamicGemmDlops_km_kn_mn_v1r2; + + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_k_m_grid_desc.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error( + "wrong! 
GridwiseDynamicGemmDlops_km_kn_mn_v1r2 has invalid setting"); + } + + const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); + const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); + + using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc); + using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc); + + // c_m0_m10_m11_n0_n10_n11_grid_desc + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); + + // c_blockid_to_m0_n0_block_cluster_adaptor + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); + + const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K); + + { + std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", " + << a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", " + << b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; + } + +#if 
CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + + return ave_time; +#elif 
CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc)); + DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc)); + DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); + DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( + sizeof(CBlockIdToM0N0BlockClusterAdaptor)); + + a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc); + b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc); + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( + &c_blockid_to_m0_n0_block_cluster_adaptor); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void 
CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + + return ave_time; +#endif +} +#endif diff --git a/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp new file mode 100644 index 0000000000..d075eac822 --- /dev/null +++ b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp @@ -0,0 +1,411 @@ +#ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R3 +#define DRIVER_DYNAMIC_GEMM_DLOPS_V1R3 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_dlops_v1r3.hpp" + +template +__host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* 
p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseGemm = + GridwiseDynamicGemmDlops_km_kn_mn_v1r3; + + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error( + "wrong! GridwiseDynamicGemmDlops_km_kn_mn_v1r3 has invalid setting"); + } + + const auto a_k0_m0_m1_k1_grid_desc = + GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc); + const auto b_k0_n0_n1_k1_grid_desc = + GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc); + + using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc); + using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc); + + // c_m0_m10_m11_n0_n10_n11_grid_desc + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); + + // c_blockid_to_m0_n0_block_cluster_adaptor + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); + + const bool has_main_k_block_loop = 
GridwiseGemm::CalculateHasMainKBlockLoop(K0); + + const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + { + std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; + } + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + 
p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + + return ave_time; +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc)); + DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc)); + DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); + DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( + sizeof(CBlockIdToM0N0BlockClusterAdaptor)); + + a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc); + b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc); + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( + &c_blockid_to_m0_n0_block_cluster_adaptor); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + 
remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, 
+ false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + + return ave_time; +#endif +} +#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp similarity index 83% rename from composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r3.hpp rename to host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp index f07a51d21d..481d08188d 100644 --- a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp @@ -1,48 +1,46 @@ -#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 -#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 +#ifndef DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 +#define DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 #include "common_header.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" #include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" -namespace ck { - -template {}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; @@ -176,23 +176,21 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); - float ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void __CONSTANT__*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void 
__CONSTANT__*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + float ave_time = + launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); #endif return ave_time; } - -} // namespace ck #endif diff --git a/host/driver_online/conv_fwd_driver_online.cpp b/host/driver_online/conv_fwd_driver_online.cpp index 3b25f5d039..c91f76fa24 100644 --- a/host/driver_online/conv_fwd_driver_online.cpp +++ b/host/driver_online/conv_fwd_driver_online.cpp @@ -14,8 +14,8 @@ #include "device_tensor.hpp" #include "handle.hpp" #include "hipCheck.hpp" -#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp" +#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" +#include "online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" #include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" #include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp" @@ -35,6 +35,7 @@ enum ConvForwardAlgo int main(int argc, char* argv[]) { using namespace ck; + using namespace ck_driver; using size_t = std::size_t; hipStream_t stream; @@ -93,7 +94,7 @@ int main(int argc, char* argv[]) using in_data_t = float; using acc_data_t = float; using out_data_t = float; -#elif 1 +#elif 0 using in_data_t = half_t; using acc_data_t = float; using out_data_t = half_t; @@ -225,25 +226,25 @@ int main(int argc, char* argv[]) const auto tmp = 
f_make_for_device_nchw(); - tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw* tunable = - &default_tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw; + tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* tunable = + &default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw; - online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( - handle, - tmp[I0], - tmp[I1], - tmp[I2], - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - in, - wei, - out_device, - tunable, - nrepeat); + online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw< + in_data_t, + acc_data_t, + out_data_t>(handle, + tmp[I0], + tmp[I1], + tmp[I2], + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + in, + wei, + out_device, + tunable, + nrepeat); } #endif @@ -257,24 +258,105 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nchw(); - const auto tunable = tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw{}; +#if 1 + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = { + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + 256, + 4, + 1, + 128, + 32, + 8, + 4, + 4, + 1, + {8, 2}, + {8, 2}, + {4, 1, 1, 1, 1}, + {2, 1, 1, 128, 1}, + {4, 1, 1, 1, 1}, + {1, 1, 1, 1, 1}, + {1, 4, 1, 1, 1}, + {8, 1, 1, 32, 1}, + {1, 1, 1, 1, 1}, + {1, 1, 1, 1, 1}, + 4, + true, + true}; +#elif 0 + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = { + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + 256, + 4, + 2, + 128, + 32, + 8, + 4, + 4, + 1, + {8, 2}, + {8, 2}, + {4, 1, 1, 1, 2}, + {2, 1, 1, 128, 1}, + {4, 1, 1, 1, 1}, + {1, 1, 1, 1, 1}, + {1, 4, 1, 1, 2}, + {8, 1, 1, 32, 1}, + {1, 1, 1, 1, 1}, + {1, 1, 1, 1, 1}, + 4, + true, + true}; +#elif 1 + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = { + get_datatype_enum_from_type::value, + 
get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + 256, + 4, + 4, + 128, + 32, + 8, + 4, + 4, + 1, + {8, 2}, + {8, 2}, + {4, 1, 1, 1, 4}, + {2, 1, 1, 128, 1}, + {4, 1, 1, 1, 1}, + {1, 1, 1, 1, 1}, + {1, 4, 1, 1, 4}, + {8, 1, 1, 32, 1}, + {1, 1, 1, 1, 1}, + {1, 1, 1, 1, 1}, + 4, + true, + true}; +#endif - online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( - handle, - tmp[I0], - tmp[I1], - tmp[I2], - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - in, - wei, - out_device, - tunable, - nrepeat); + online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw< + in_data_t, + acc_data_t, + out_data_t>(handle, + tmp[I0], + tmp[I1], + tmp[I2], + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + in, + wei, + out_device, + compile_param, + nrepeat); } #endif @@ -355,13 +437,15 @@ int main(int argc, char* argv[]) check_error(out_host, out_device); +#if 0 if(do_log) { - LogRange(std::cout << "in : ", in.mData, ",") << std::endl; - LogRange(std::cout << "wei: ", wei.mData, ",") << std::endl; - LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl; - LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_host : ", out_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << std::endl; } +#endif } delete handle; diff --git a/host/driver_online/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..b0c4921019 --- /dev/null +++ b/host/driver_online/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,673 @@ +#ifndef CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP +#define 
CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP + +#include + +namespace ck_driver { + +struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw +{ + ck::DataTypeEnum_t ABDataTypeEnum; + ck::DataTypeEnum_t AccDataTypeEnum; + ck::DataTypeEnum_t CDataTypeEnum; + + int BlockSize; + + int GN0; + int GK1; + + int GM1PerBlockGM11; + int GN1PerBlockGN11; + int GK0PerBlock; + + int BM1PerThreadBM11; + int BN1PerThreadBN11; + int BK0PerThread; + + std::array BM10BN10ThreadClusterBM10Xs; + std::array BM10BN10ThreadClusterBN10Xs; + + std::array ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + + std::array BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + + int CThreadTransferDstScalarPerVector; + + bool HasMainKBlockLoop; + bool HasDoubleTailKBlockLoop; + + auto GetCompileParameterString() const + { + // clang-format off + return + " -DCK_PARAM_ABDataTypeEnum=" + + std::to_string(ABDataTypeEnum) + + " -DCK_PARAM_AccDataTypeEnum=" + + std::to_string(AccDataTypeEnum) + + " -DCK_PARAM_CDataTypeEnum=" + + std::to_string(CDataTypeEnum) + + " -DCK_PARAM_BlockSize=" + + std::to_string(BlockSize) + + " -DCK_PARAM_GN0=" + + std::to_string(GN0) + + " -DCK_PARAM_GK1=" + + std::to_string(GK1) + + " -DCK_PARAM_GM1PerBlockGM11=" + + std::to_string(GM1PerBlockGM11) + + " -DCK_PARAM_GN1PerBlockGN11=" + + std::to_string(GN1PerBlockGN11) + + " -DCK_PARAM_GK0PerBlock=" + + std::to_string(GK0PerBlock) + + " -DCK_PARAM_BM1PerThreadBM11=" + + std::to_string(BM1PerThreadBM11) + + " -DCK_PARAM_BN1PerThreadBN11=" + + std::to_string(BN1PerThreadBN11) + + " 
-DCK_PARAM_BK0PerThread=" + + std::to_string(BK0PerThread) + + " -DCK_PARAM_BM10BN10ThreadClusterBM10Xs=" + + std::to_string(BM10BN10ThreadClusterBM10Xs[0]) + "," + + std::to_string(BM10BN10ThreadClusterBM10Xs[1]) + + " -DCK_PARAM_BM10BN10ThreadClusterBN10Xs=" + + std::to_string(BM10BN10ThreadClusterBN10Xs[0]) + "," + + std::to_string(BM10BN10ThreadClusterBN10Xs[1]) + + " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=" + + std::to_string(ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," + + std::to_string(ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," + + std::to_string(ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," + + std::to_string(ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," + + std::to_string(ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4]) + + " -DCK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1=" + + std::to_string(ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," + + std::to_string(ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," + + std::to_string(ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," + + std::to_string(ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," + + std::to_string(ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4]) + + " -DCK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" + + std::to_string(ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," + + std::to_string(ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," + + std::to_string(ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," + + std::to_string(ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," + + std::to_string(ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]) + + " -DCK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" + + 
std::to_string(ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," + + std::to_string(ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," + + std::to_string(ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," + + std::to_string(ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," + + std::to_string(ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]) + + " -DCK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1=" + + std::to_string(BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," + + std::to_string(BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," + + std::to_string(BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," + + std::to_string(BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," + + std::to_string(BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4]) + + " -DCK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1=" + + std::to_string(BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," + + std::to_string(BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," + + std::to_string(BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," + + std::to_string(BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," + + std::to_string(BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4]) + + " -DCK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" + + std::to_string(BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," + + std::to_string(BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," + + std::to_string(BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," + + std::to_string(BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," + + std::to_string(BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]) + + " 
-DCK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" + + std::to_string(BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," + + std::to_string(BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," + + std::to_string(BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," + + std::to_string(BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," + + std::to_string(BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]) + + " -DCK_PARAM_CThreadTransferDstScalarPerVector=" + + std::to_string(CThreadTransferDstScalarPerVector) + + " -DCK_PARAM_HasMainKBlockLoop=" + + std::to_string(HasMainKBlockLoop) + + " -DCK_PARAM_HasDoubleTailKBlockLoop=" + + std::to_string(HasDoubleTailKBlockLoop); + // clang-format on + } +}; + +struct TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw +{ + ck::DataTypeEnum_t ABDataTypeEnum; + ck::DataTypeEnum_t CDataTypeEnum; + + int BlockSize; + + int GN0; + int GK1; + + int GM1PerBlockGM11; + int GN1PerBlockGN11; + int GK0PerBlock; + + int BM1PerThreadBM11; + int BN1PerThreadBN11; + int BK0PerThread; + + std::array BM10BN10ThreadClusterBM10Xs; + std::array BM10BN10ThreadClusterBN10Xs; + + std::array ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + + std::array BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; +}; + +inline static auto generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw() +{ + constexpr auto f32 = ck::DataTypeEnum_t::Float; + constexpr auto f16 = ck::DataTypeEnum_t::Half; + constexpr auto i8 = 
ck::DataTypeEnum_t::Int8; + + return std::vector{ + // clang-format off + // fp32 + {f32, f32, 256, 1, 1, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 1}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + + {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, + {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, + + {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 1}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + {f32, f32, 256, 2, 1, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 1}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + {f32, f32, 256, 4, 1, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + {f32, f32, 256, 8, 1, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 1}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + {f32, f32, 128, 1, 1, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 1}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + // fp16 + {f16, f16, 256, 1, 2, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 
2, 2}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + + {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, + {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, + + {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 2}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + {f16, f16, 256, 2, 2, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 2}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + {f16, f16, 256, 4, 2, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + {f16, f16, 256, 8, 2, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 2}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + {f16, f16, 128, 1, 2, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 2}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + // i8 + { i8, i8, 256, 1, 4, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 4}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + + { i8, i8, 256, 1, 4, 
128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, + { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, + + { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 4}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + { i8, i8, 256, 2, 4, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 4}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + { i8, i8, 256, 4, 4, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + { i8, i8, 256, 8, 4, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 4}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + { i8, i8, 128, 1, 4, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 4}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}} + // clang-format on + }; +} + +// TODO make this common interface and write specs for it +struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw +{ + static auto + CalculateCompileParameterBasedOnTunable(const ConvolutionProblemDescriptor& conv_problem_desc, + const TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw& tunable) + { + using namespace ck; + + const int C = conv_problem_desc.C; + const int Y = 
conv_problem_desc.Y; + const int X = conv_problem_desc.X; + const int Ho = conv_problem_desc.Ho; + const int Wo = conv_problem_desc.Wo; + + if(!(conv_problem_desc.InDataTypeEnum == tunable.ABDataTypeEnum && + conv_problem_desc.WeiDataTypeEnum == tunable.ABDataTypeEnum && + conv_problem_desc.OutDataTypeEnum == tunable.CDataTypeEnum)) + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + + const auto ABDataTypeEnum = conv_problem_desc.InDataTypeEnum; + const auto CDataTypeEnum = conv_problem_desc.OutDataTypeEnum; + + DataTypeEnum_t AccDataTypeEnum; + + switch(ABDataTypeEnum) + { + case DataTypeEnum_t::Float: + case DataTypeEnum_t::Half: AccDataTypeEnum = DataTypeEnum_t::Float; break; + case DataTypeEnum_t::Int8: AccDataTypeEnum = DataTypeEnum_t::Int32; break; + default: return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + } + + const int BlockSize = tunable.BlockSize; + + const int GN0 = tunable.GN0; + const int GK1 = tunable.GK1; + + const int GM11 = tunable.GM1PerBlockGM11; + const int GN11 = tunable.GN1PerBlockGN11; + const int GK0PerBlock = tunable.GK0PerBlock; + + const int BM11 = tunable.BM1PerThreadBM11; + const int BN11 = tunable.BN1PerThreadBN11; + const int BK0PerThread = tunable.BK0PerThread; + + const auto BM10BN10ThreadClusterBM10Xs = tunable.BM10BN10ThreadClusterBM10Xs; + const auto BM10BN10ThreadClusterBN10Xs = tunable.BM10BN10ThreadClusterBN10Xs; + + const auto ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = + tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; + const auto ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = + tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; + const auto ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = + tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + const auto ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = + 
tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + + const auto BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = + tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; + const auto BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = + tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; + const auto BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = + tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + const auto BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = + tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + + // C threadwise copy: {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim + const int CThreadTransferDstScalarPerVector = gcd(4, GN11, BN11, Ho * Wo); + + const int C0 = GK1; + + if(!(C % C0 == 0)) + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + + const int C1 = C / C0; + + const int GK0 = C1 * Y * X; + + if(!(GK0 % GK0PerBlock == 0)) + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + + const bool HasMainKBlockLoop = ((GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1); + + const bool HasDoubleTailKBlockLoop = ((GK0 / GK0PerBlock) % 2 == 0); + + return std::make_tuple( + CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{ + ABDataTypeEnum, + AccDataTypeEnum, + CDataTypeEnum, + BlockSize, + GN0, + GK1, + GM11, + GN11, + GK0PerBlock, + BM11, + BN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + 
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + CThreadTransferDstScalarPerVector, + HasMainKBlockLoop, + HasDoubleTailKBlockLoop}, + true); + } + + static auto GetDefaultCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc) + { + for(const auto& tunable : generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw()) + { + CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param; + bool found = false; + + std::tie(compile_param, found) = + CalculateCompileParameterBasedOnTunable(conv_problem_desc, tunable); + + if(found && IsValidCompileParameter(conv_problem_desc, compile_param)) + return std::make_tuple(compile_param, true); + } + + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + } + + static bool IsApplicable(const ConvolutionProblemDescriptor& conv_problem_desc) + { + bool found = false; + + std::tie(std::ignore, found) = GetDefaultCompileParameter(conv_problem_desc); + + return found; + } + + static bool + IsValidCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc, + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) + { + using namespace ck; + + const int N = conv_problem_desc.N; + const int K = conv_problem_desc.K; + const int C = conv_problem_desc.C; + const int Y = conv_problem_desc.Y; + const int X = conv_problem_desc.X; + const int Ho = conv_problem_desc.Ho; + const int Wo = conv_problem_desc.Wo; + + const int GK1 = compile_param.GK1; + const int GN0 = compile_param.GN0; + const int GM11 = compile_param.GM1PerBlockGM11; + const int GN11 = compile_param.GN1PerBlockGN11; + + const int BM11 = compile_param.BM1PerThreadBM11; + const int BN11 = compile_param.BN1PerThreadBN11; + + const int C0 = GK1; + const int N0 = GN0; + + if(!(C % C0 == 0)) + return false; + + const int C1 = C / C0; + + if(!(N % N0 == 0)) + return false; + + const int N1 = N / N0; + + const int GM0 = 1; + const int GM1 = K; + const int GN1 = N1 * Ho * Wo; + const int GK0 
= C1 * Y * X; + + // check data type + { + if(!(conv_problem_desc.InDataTypeEnum == conv_problem_desc.WeiDataTypeEnum && + conv_problem_desc.InDataTypeEnum == compile_param.ABDataTypeEnum)) + return false; + + if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Float || + compile_param.ABDataTypeEnum == DataTypeEnum_t::Half) + { + if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Float)) + return false; + } + else if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Int8) + { + if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Int32)) + return false; + } + } + + // check gridwise contraction + { + if(!(GM1 % GM11 == 0 && GN1 % GN11 == 0 && GK0 % compile_param.GK0PerBlock == 0)) + return false; + + const bool has_main_k_block_loop = + ((GK0 + compile_param.GK0PerBlock) / (2 * compile_param.GK0PerBlock) > 1); + + const bool has_double_tail_k_block_loop = ((GK0 / compile_param.GK0PerBlock) % 2 == 0); + + if(!(has_main_k_block_loop == compile_param.HasMainKBlockLoop && + has_double_tail_k_block_loop == compile_param.HasDoubleTailKBlockLoop)) + return false; + } + + // check A blockwise copy + { + const auto block_slice_lengths = + std::array{compile_param.GK0PerBlock, GM0, 1, GM11, GK1}; + const auto& cluster_lengths = + compile_param.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; + const auto& thread_slice_lengths = + compile_param.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; + const auto& src_vector_lengths = + compile_param.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + const auto& dst_vector_lengths = + compile_param.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + + // check number of working thread + const int num_work_thread = std::accumulate( + cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies{}); + + if(!(compile_param.BlockSize >= num_work_thread)) + return false; + + // check block slice lengths vs thread slice lengths vs cluster lengths + for(int i = 0; i < 5; ++i) + { + 
if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i])) + return false; + } + + // check thread slice lengths vs vector lengths + for(int i = 0; i < 5; ++i) + { + if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0)) + return false; + + if(!(thread_slice_lengths[i] % dst_vector_lengths[i] == 0)) + return false; + } + + // check Src vectorization, GK0 is global mem vector dim + if(!(src_vector_lengths[1] == 1 && src_vector_lengths[2] == 1 && + src_vector_lengths[3] == 1 && src_vector_lengths[4] == 1)) + return false; + + // check Dst vectorization, {GM11, GK1} are LDS vector dims + if(dst_vector_lengths[4] == GK1) + { // vectorize on {GM11, GK1} + if(!(GM11 % dst_vector_lengths[3] == 0)) + return false; + } + else + { // vectorize on {GK1} only + if(!(GK1 % dst_vector_lengths[4] == 0)) + return false; + + if(!(dst_vector_lengths[3] == 1)) + return false; + } + } + + // check B blockwise copy + { + const auto block_slice_lengths = + std::array{compile_param.GK0PerBlock, GN0, 1, GN11, GK1}; + const auto& cluster_lengths = + compile_param.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; + const auto& thread_slice_lengths = + compile_param.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; + const auto& src_vector_lengths = + compile_param.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + const auto& dst_vector_lengths = + compile_param.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + + // check number of working thread + const int num_work_thread = std::accumulate( + cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies{}); + + if(!(compile_param.BlockSize >= num_work_thread)) + return false; + + // check block slice lengths vs thread slice lengths vs cluster lengths + for(int i = 0; i < 5; ++i) + { + if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i])) + return false; + } + + // check thread slice lengths vs vector lengths + for(int i = 0; i < 5; ++i) + { + 
if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0 && + thread_slice_lengths[i] % dst_vector_lengths[i] == 0)) + return false; + } + + // check Src vectorization: {GN11} is global mem vector dim + if(!(src_vector_lengths[0] == 1 && src_vector_lengths[1] == 1 && + src_vector_lengths[2] == 1 && src_vector_lengths[4] == 1)) + return false; + + // check Src tensor layout related vectorization + if(Y == 1 && X == 1 && conv_problem_desc.ConvStrideH == 1 && + conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadH == 0 && + conv_problem_desc.InLeftPadW == 0 && conv_problem_desc.InRightPadH == 0 && + conv_problem_desc.InRightPadW == 0) + { + if(!((Ho * Wo) % src_vector_lengths[3] == 0)) + return false; + } + else if(conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadW == 0 && + conv_problem_desc.InRightPadW == 0) + { + if(!(Wo % src_vector_lengths[3] == 0)) + return false; + } + else + { + if(!(src_vector_lengths[3] == 1)) + return false; + } + + // check Dst vectorization: {GN11, GK1} are LDS vector dims + if(dst_vector_lengths[4] == GK1) + { // vectorize on {GN11, GK1} + if(!(GN11 % dst_vector_lengths[3] == 0)) + return false; + } + else + { // vectorize on {GK1} only + if(!(dst_vector_lengths[3] == 1)) + return false; + + if(!(GK1 % dst_vector_lengths[4] == 0)) + return false; + } + } + + // check blockwise GEMM + { + const int BM10 = std::accumulate(compile_param.BM10BN10ThreadClusterBM10Xs.begin(), + compile_param.BM10BN10ThreadClusterBM10Xs.end(), + 1, + std::multiplies{}); + + const int BN10 = std::accumulate(compile_param.BM10BN10ThreadClusterBN10Xs.begin(), + compile_param.BM10BN10ThreadClusterBN10Xs.end(), + 1, + std::multiplies{}); + + if(!(compile_param.BlockSize == BM10 * BN10)) + return false; + + const int BM = GM0 * GM11; + const int BN = GN0 * GN11; + + const int BM1 = BM10 * BM11; + const int BN1 = BN10 * BN11; + + if(!(BM % BM1 == 0 && BN % BN1 == 0)) + return false; + + const int BM0 = BM / BM1; + const int BN0 = BN / 
BN1; + + // blockwise GEMM currently only supports BM0 == 2 && BN0 == 2 + if(!(BM0 == 2 && BN0 == 2)) + return false; + + if(!(compile_param.GK0PerBlock % compile_param.BK0PerThread == 0)) + return false; + } + + // check C threadwise copy + { + // {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim + const int dst_vector_len_gn11 = compile_param.CThreadTransferDstScalarPerVector; + + // check slice length vs Dst vector length: + if(!(BN11 % dst_vector_len_gn11 == 0 && GN11 % dst_vector_len_gn11 == 0)) + return false; + + // check Dst memory layout related vectorization: + if(!((Ho * Wo) % compile_param.CThreadTransferDstScalarPerVector == 0)) + return false; + } + + return true; + }; + + static int GetBlockSize(const ConvolutionProblemDescriptor&, + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) + { + return compile_param.BlockSize; + } + + static int GetGridSize(const ConvolutionProblemDescriptor& conv_problem_desc, + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) + { + const int N = conv_problem_desc.N; + const int K = conv_problem_desc.K; + const int Ho = conv_problem_desc.Ho; + const int Wo = conv_problem_desc.Wo; + + const int N0 = compile_param.GN0; + const int N1 = N / N0; + + const int GM1 = K; + const int GN1 = N1 * Ho * Wo; + + const int GM11 = compile_param.GM1PerBlockGM11; + const int GN11 = compile_param.GN1PerBlockGN11; + + const int GM10 = GM1 / GM11; + const int GN10 = GN1 / GN11; + + return GM10 * GN10; + } + + static std::size_t GetWorkSpaceSize(const ConvolutionProblemDescriptor&, + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw&) + { + // workspace is used to save the transformed tensor descriptors created by the prepare kernel + return 4096L; + } + + static std::size_t GetMaxWorkSpaceSize(const ConvolutionProblemDescriptor&) { return 4096L; } + + static auto GetTunableList() + { + return generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw(); + } +}; + +} // namespace ck_driver 
+#endif diff --git a/host/driver_online/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..58fe588ad9 --- /dev/null +++ b/host/driver_online/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,51 @@ +#ifndef CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP +#define CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP + +struct tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw +{ + int BlockSize; + + int MPerBlock; + int NPerBlock; + int KPerBlock; + + int M1PerThread; + int N1PerThread; + int KPerThread; + + int M1N1ThreadClusterM10; + int M1N1ThreadClusterN10; + int M1N1ThreadClusterM11; + int M1N1ThreadClusterN11; + + std::array ABlockTransferThreadSliceLengths_K_M0_M1; + std::array ABlockTransferThreadClusterLengths_K_M0_M1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + int ABlockTransferSrcVectorDim; + int ABlockTransferSrcScalarPerVector; + int ABlockTransferDstScalarPerVector_M1; + bool AThreadTransferSrcResetCoordinateAfterRun; + + std::array BBlockTransferThreadSliceLengths_K_N0_N1; + std::array BBlockTransferThreadClusterLengths_K_N0_N1; + std::array BBlockTransferThreadClusterArrangeOrder; + std::array BBlockTransferSrcAccessOrder; + int BBlockTransferSrcVectorDim; + int BBlockTransferSrcScalarPerVector; + int BBlockTransferDstScalarPerVector_N1; + bool BThreadTransferSrcResetCoordinateAfterRun; + + std::array CThreadTransferSrcDstAccessOrder; + int CThreadTransferSrcDstVectorDim; + int CThreadTransferDstScalarPerVector; +}; + +static tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw + default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw = { + 256, 128, 128, 8, 4, 4, 1, + 8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0}, + {2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128}, + {0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2}, + 5, 1}; +#endif diff --git 
a/host/driver_online/include/conv_tunable_fwd_v4r4_nchw_kcyx_nkhw.hpp b/host/driver_online/include/conv_tunable_fwd_v4r4_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 05ee9846b8..0000000000 --- a/host/driver_online/include/conv_tunable_fwd_v4r4_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef CONV_TUNABLE_FWD_V4R4_NCHW_KCYX_NKHW_HPP -#define CONV_TUNABLE_FWD_V4R4_NCHW_KCYX_NKHW_HPP - -struct tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw -{ - int32_t BlockSize; - - int32_t MPerBlock; - int32_t NPerBlock; - int32_t KPerBlock; - - int32_t M1PerThread; - int32_t N1PerThread; - int32_t KPerThread; - - int32_t M1N1ThreadClusterM10; - int32_t M1N1ThreadClusterN10; - int32_t M1N1ThreadClusterM11; - int32_t M1N1ThreadClusterN11; - - std::array ABlockTransferThreadSliceLengths_K_M0_M1; - std::array ABlockTransferThreadClusterLengths_K_M0_M1; - std::array ABlockTransferThreadClusterArrangeOrder; - std::array ABlockTransferSrcAccessOrder; - int32_t ABlockTransferSrcVectorDim; - int32_t ABlockTransferSrcScalarPerVector; - int32_t ABlockTransferDstScalarPerVector_M1; - bool AThreadTransferSrcResetCoordinateAfterRun; - - std::array BBlockTransferThreadSliceLengths_K_N0_N1; - std::array BBlockTransferThreadClusterLengths_K_N0_N1; - std::array BBlockTransferThreadClusterArrangeOrder; - std::array BBlockTransferSrcAccessOrder; - int32_t BBlockTransferSrcVectorDim; - int32_t BBlockTransferSrcScalarPerVector; - int32_t BBlockTransferDstScalarPerVector_N1; - bool BThreadTransferSrcResetCoordinateAfterRun; - - std::array CThreadTransferSrcDstAccessOrder; - int32_t CThreadTransferSrcDstVectorDim; - int32_t CThreadTransferDstScalarPerVector; -}; - -static tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw = { - 256, 128, 128, 8, 4, 4, 1, - 8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0}, - {2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128}, - {0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2}, - 5, 1}; -#endif diff --git 
a/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp index 7681438d95..97ce326346 100644 --- a/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -3,40 +3,40 @@ struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw { - int32_t BlockSize; + int BlockSize; - int32_t MPerBlock; - int32_t NPerBlock; - int32_t KPerBlock; + int MPerBlock; + int NPerBlock; + int KPerBlock; - int32_t MPerWave; - int32_t NPerWave; - int32_t K1; + int MPerWave; + int NPerWave; + int K1; - int32_t MRepeat; - int32_t NRepeat; + int MRepeat; + int NRepeat; - std::array ABlockTransferThreadSliceLengths_K0_M_K1; - std::array ABlockTransferThreadClusterLengths_K0_M_K1; - std::array ABlockTransferThreadClusterArrangeOrder; - std::array ABlockTransferSrcAccessOrder; - int32_t ABlockTransferSrcVectorDim; - int32_t ABlockTransferSrcScalarPerVector; - int32_t ABlockTransferDstScalarPerVector_K1; + std::array ABlockTransferThreadSliceLengths_K0_M_K1; + std::array ABlockTransferThreadClusterLengths_K0_M_K1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + int ABlockTransferSrcVectorDim; + int ABlockTransferSrcScalarPerVector; + int ABlockTransferDstScalarPerVector_K1; bool AThreadTransferSrcResetCoordinateAfterRun; - std::array BBlockTransferThreadSliceLengths_K0_N_K1; - std::array BBlockTransferThreadClusterLengths_K0_N_K1; - std::array BBlockTransferThreadClusterArrangeOrder; - std::array BBlockTransferSrcAccessOrder; - int32_t BBlockTransferSrcVectorDim; - int32_t BBlockTransferSrcScalarPerVector; - int32_t BBlockTransferDstScalarPerVector_K1; + std::array BBlockTransferThreadSliceLengths_K0_N_K1; + std::array BBlockTransferThreadClusterLengths_K0_N_K1; + std::array BBlockTransferThreadClusterArrangeOrder; + std::array 
BBlockTransferSrcAccessOrder; + int BBlockTransferSrcVectorDim; + int BBlockTransferSrcScalarPerVector; + int BBlockTransferDstScalarPerVector_K1; bool BThreadTransferSrcResetCoordinateAfterRun; - std::array CThreadTransferSrcDstAccessOrder; - int32_t CThreadTransferSrcDstVectorDim; - int32_t CThreadTransferDstScalarPerVector; + std::array CThreadTransferSrcDstAccessOrder; + int CThreadTransferSrcDstVectorDim; + int CThreadTransferDstScalarPerVector; }; static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw diff --git a/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp index a4fd8095c4..263c21a13b 100644 --- a/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -3,40 +3,40 @@ struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk { - int32_t BlockSize; + int BlockSize; - int32_t MPerBlock; - int32_t NPerBlock; - int32_t KPerBlock; + int MPerBlock; + int NPerBlock; + int KPerBlock; - int32_t MPerWave; - int32_t NPerWave; - int32_t K1; + int MPerWave; + int NPerWave; + int K1; - int32_t MRepeat; - int32_t NRepeat; + int MRepeat; + int NRepeat; - std::array ABlockTransferThreadSliceLengths_K0_M_K1; - std::array ABlockTransferThreadClusterLengths_K0_M_K1; - std::array ABlockTransferThreadClusterArrangeOrder; - std::array ABlockTransferSrcAccessOrder; - int32_t ABlockTransferSrcVectorDim; - int32_t ABlockTransferSrcScalarPerVector; - int32_t ABlockTransferDstScalarPerVector_K1; + std::array ABlockTransferThreadSliceLengths_K0_M_K1; + std::array ABlockTransferThreadClusterLengths_K0_M_K1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + int ABlockTransferSrcVectorDim; + int ABlockTransferSrcScalarPerVector; + int ABlockTransferDstScalarPerVector_K1; bool AThreadTransferSrcResetCoordinateAfterRun; - std::array 
BBlockTransferThreadSliceLengths_K0_N_K1; - std::array BBlockTransferThreadClusterLengths_K0_N_K1; - std::array BBlockTransferThreadClusterArrangeOrder; - std::array BBlockTransferSrcAccessOrder; - int32_t BBlockTransferSrcVectorDim; - int32_t BBlockTransferSrcScalarPerVector; - int32_t BBlockTransferDstScalarPerVector_K1; + std::array BBlockTransferThreadSliceLengths_K0_N_K1; + std::array BBlockTransferThreadClusterLengths_K0_N_K1; + std::array BBlockTransferThreadClusterArrangeOrder; + std::array BBlockTransferSrcAccessOrder; + int BBlockTransferSrcVectorDim; + int BBlockTransferSrcScalarPerVector; + int BBlockTransferDstScalarPerVector_K1; bool BThreadTransferSrcResetCoordinateAfterRun; - std::array CThreadTransferSrcDstAccessOrder; - int32_t CThreadTransferSrcDstVectorDim; - int32_t CThreadTransferDstScalarPerVector; + std::array CThreadTransferSrcDstAccessOrder; + int CThreadTransferSrcDstVectorDim; + int CThreadTransferDstScalarPerVector; }; static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk diff --git a/host/driver_online/include/conv_tunable_fwd_v6r1_nchw_kcyx_nkhw.hpp b/host/driver_online/include/conv_tunable_fwd_v6r1_nchw_kcyx_nkhw.hpp deleted file mode 100644 index f307e22f53..0000000000 --- a/host/driver_online/include/conv_tunable_fwd_v6r1_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef CONV_TUNABLE_FWD_V6R1_NCHW_KCYX_NKHW_HPP -#define CONV_TUNABLE_FWD_V6R1_NCHW_KCYX_NKHW_HPP - -struct tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw -{ - int32_t BlockSize = 256; - - int32_t GN0 = 4; - int32_t GK1 = 1; - - int32_t GM1PerBlockGM11 = 128; - int32_t GN1PerBlockGN11 = 32; - int32_t GK0PerBlock = 8; - - int32_t BM1PerThreadBM11 = 4; - int32_t BN1PerThreadBN11 = 4; - int32_t BK0PerThread = 1; - - int32_t BM10BN10ThreadClusterBM100 = 2; - int32_t BM10BN10ThreadClusterBN100 = 2; - int32_t BM10BN10ThreadClusterBM101 = 8; - int32_t BM10BN10ThreadClusterBN101 = 8; - - std::array ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = {4, 1, 1, 1, 1}; - 
std::array ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = { - 2, 1, 1, 128, 1}; - std::array ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = { - 4, 1, 1, 1, 1}; - std::array ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = { - 1, 1, 1, 1, 1}; - - std::array BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = {1, 4, 1, 1, 1}; - std::array BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = { - 8, 1, 1, 32, 1}; - std::array BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = { - 1, 1, 1, 1, 1}; - std::array BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = { - 1, 1, 1, 1, 1}; - - int32_t CThreadTransferDstScalarPerVector = 1; -}; -#endif diff --git a/host/driver_online/include/convolution_problem_descriptor.hpp b/host/driver_online/include/convolution_problem_descriptor.hpp new file mode 100644 index 0000000000..df9c110e70 --- /dev/null +++ b/host/driver_online/include/convolution_problem_descriptor.hpp @@ -0,0 +1,79 @@ +#ifndef CONVOLUTION_PROBLEM_DESCRIPTOR +#define CONVOLUTION_PROBLEM_DESCRIPTOR + +namespace ck_driver { + +struct ConvolutionProblemDescriptor +{ + ConvolutionProblemDescriptor() = default; + + ConvolutionProblemDescriptor(int N_, + int K_, + int C_, + int Y_, + int X_, + int Hi_, + int Wi_, + int Ho_, + int Wo_, + int ConvStrideH_, + int ConvStrideW_, + int ConvDilationH_, + int ConvDilationW_, + int InLeftPadH_, + int InLeftPadW_, + int InRightPadH_, + int InRightPadW_, + ck::DataTypeEnum_t InDataTypeEnum_, + ck::DataTypeEnum_t WeiDataTypeEnum_, + ck::DataTypeEnum_t OutDataTypeEnum_) + : N{N_}, + K{K_}, + C{C_}, + Y{Y_}, + X{X_}, + Hi{Hi_}, + Wi{Wi_}, + Ho{Ho_}, + Wo{Wo_}, + ConvStrideH{ConvStrideH_}, + ConvStrideW{ConvStrideW_}, + ConvDilationH{ConvDilationH_}, + ConvDilationW{ConvDilationW_}, + InLeftPadH{InLeftPadH_}, + InLeftPadW{InLeftPadW_}, + InRightPadH{InRightPadH_}, + InRightPadW{InRightPadW_}, + InDataTypeEnum{InDataTypeEnum_}, + WeiDataTypeEnum{WeiDataTypeEnum_}, + 
OutDataTypeEnum{OutDataTypeEnum_} + { + } + + int N; + int K; + int C; + int Y; + int X; + int Hi; + int Wi; + int Ho; + int Wo; + int ConvStrideH; + int ConvStrideW; + int ConvDilationH; + int ConvDilationW; + int InLeftPadH; + int InLeftPadW; + int InRightPadH; + int InRightPadW; + + ck::DataTypeEnum_t InDataTypeEnum; + ck::DataTypeEnum_t WeiDataTypeEnum; + ck::DataTypeEnum_t OutDataTypeEnum; + + std::size_t CalculateFlop() const { return 2L * N * K * C * Y * X * Ho * Wo; } +}; + +} // namespace ck_driver +#endif diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp similarity index 92% rename from host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp rename to host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp index f852c4dc6f..628bb6d96d 100644 --- a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp @@ -1,3 +1,4 @@ +#pragma once #include "device.hpp" #include "host_tensor.hpp" #include "handle.hpp" @@ -5,24 +6,26 @@ #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "conv_tunable_fwd_v4r4_nchw_kcyx_nkhw.hpp" +#include "conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp" namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw { template static std::string get_network_config_string_from_types() { + using namespace ck; + std::string out; - out += static_cast(Driver::get_typeid_from_type()) + - static_cast(Driver::get_typeid_from_type()) + - static_cast(Driver::get_typeid_from_type()); + out 
+= std::to_string(get_datatype_enum_from_type::value) + "_" + + std::to_string(get_datatype_enum_from_type::value) + "_" + + std::to_string(get_datatype_enum_from_type::value); return (out); }; static std::string -get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw* pt) +get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* pt) { std::string out("TUN_"); @@ -95,17 +98,20 @@ get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_nchw_kcyx template static std::string get_definition_string_from_types() { + using namespace ck; + std::string out; - out += " -DCK_PARAM_IN_WEI_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()) + - " -DCK_PARAM_CONV_COMPTYPE=" + std::to_string(Driver::get_typeid_from_type()) + - " -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()); + out += + " -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + + " -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + + " -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value); return (out); }; static std::string -get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw* pt) +get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* pt) { std::string out; @@ -209,7 +215,7 @@ template -void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( +void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( olCompile::Handle* handle, const InLengths& in_n_c_hi_wi_lengths, const WeiLengths& wei_k_c_y_x_lengths, @@ -221,10 +227,11 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw const Tensor& in_n_c_hi_wi, const Tensor& wei_k_c_y_x, Tensor& out_n_k_ho_wo, - const tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw* tunable, + const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* 
tunable, ck::index_t nrepeat) { using namespace ck; + using namespace ck_driver; using namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw; using size_t = std::size_t; @@ -288,8 +295,9 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw const std::vector vgd1 = {static_cast(tunable->BlockSize), 1, 1}; const std::vector vgd2 = {static_cast(grid_size * tunable->BlockSize), 1, 1}; - std::string program_name = "dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp"; - std::string algo_name = "implicit_gemm_conv_fwd_v4r4_nchw"; + std::string program_name = + "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp"; + std::string algo_name = "implicit_gemm_conv_fwd_v4r4_dlops_nchw"; std::string param = " -std=c++17 "; std::string network_config; @@ -311,7 +319,7 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw KernelTimer timer1, timer2; std::string kernel_name; - kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw_prepare"; + kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare"; auto network_config_1 = network_config + "_1"; timer1.Start(); @@ -337,7 +345,7 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf); timer1.End(); - kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw"; + kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw"; auto network_config_2 = network_config + "_2"; timer2.Start(); @@ -356,8 +364,14 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw } { - auto ave_time1 = Driver::get_effective_average(kernel1_times); - auto ave_time2 = Driver::get_effective_average(kernel2_times); + auto ave_time1 = + std::accumulate( + std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus{}) / + (nrepeat - 1); + auto ave_time2 = + 
std::accumulate( + std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus{}) / + (nrepeat - 1); const auto N = in_n_c_hi_wi_lengths[I0]; const auto C = in_n_c_hi_wi_lengths[I1]; diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp index 703f8592b8..1e213b92e1 100644 --- a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -11,11 +11,13 @@ namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw { template static std::string get_network_config_string_from_types() { + using namespace ck; + std::string out; - out += static_cast(Driver::get_typeid_from_type()) + - static_cast(Driver::get_typeid_from_type()) + - static_cast(Driver::get_typeid_from_type()); + out += std::to_string(get_datatype_enum_from_type::value) + "_" + + std::to_string(get_datatype_enum_from_type::value) + "_" + + std::to_string(get_datatype_enum_from_type::value); return (out); }; @@ -93,11 +95,14 @@ get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nc template static std::string get_definition_string_from_types() { + using namespace ck; + std::string out; - out += " -DCK_PARAM_IN_WEI_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()) + - " -DCK_PARAM_CONV_COMPTYPE=" + std::to_string(Driver::get_typeid_from_type()) + - " -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()); + out += + " -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + + " -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + + " -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value); return (out); }; @@ 
-222,6 +227,7 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc ck::index_t nrepeat) { using namespace ck; + using namespace ck_driver; using namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw; using size_t = std::size_t; @@ -349,8 +355,14 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc } { - auto ave_time1 = Driver::get_effective_average(kernel1_times); - auto ave_time2 = Driver::get_effective_average(kernel2_times); + auto ave_time1 = + std::accumulate( + std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus{}) / + (nrepeat - 1); + auto ave_time2 = + std::accumulate( + std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus{}) / + (nrepeat - 1); const auto N = in_n_c_hi_wi_lengths[I0]; const auto C = in_n_c_hi_wi_lengths[I1]; diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp index 2f4787d350..8eed1a9934 100644 --- a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -12,11 +12,13 @@ namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk { template static std::string get_network_config_string_from_types() { + using namespace ck; + std::string out; - out += static_cast(Driver::get_typeid_from_type()) + - static_cast(Driver::get_typeid_from_type()) + - static_cast(Driver::get_typeid_from_type()); + out += std::to_string(get_datatype_enum_from_type::value) + "_" + + std::to_string(get_datatype_enum_from_type::value) + "_" + + std::to_string(get_datatype_enum_from_type::value); return (out); }; @@ -94,11 +96,14 @@ get_network_config_string_from_tunable(const 
tunable_dyn_conv_fwd_v4r4_xdlops_nh template static std::string get_definition_string_from_types() { + using namespace ck; + std::string out; - out += " -DCK_PARAM_IN_WEI_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()) + - " -DCK_PARAM_CONV_COMPTYPE=" + std::to_string(Driver::get_typeid_from_type()) + - " -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()); + out += + " -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + + " -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + + " -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value); return (out); }; @@ -302,15 +307,16 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky std::vector kernel1_times; std::vector kernel2_times; - KernelTimer timer1, timer2; - std::string kernel_name; - - kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare"; - auto network_config_1 = network_config + "_1"; - - timer1.Start(); for(index_t i = 0; i < nrepeat; ++i) { + KernelTimer timer1, timer2; + std::string kernel_name; + + kernel_name = + "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare"; + auto network_config_1 = network_config + "_1"; + + timer1.Start(); handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)( static_cast(in_n_hi_wi_c_lengths[I0]), static_cast(in_n_hi_wi_c_lengths[I1]), @@ -331,15 +337,12 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky b_k0_n_k1_grid_desc_dev_buf, c_m0_m1_m2_n_grid_desc_dev_buf, c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf); - } - timer1.End(); + timer1.End(); - kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk"; - auto network_config_2 = network_config + "_2"; + kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk"; + auto network_config_2 = 
network_config + "_2"; - timer2.Start(); - for(index_t i = 0; i < nrepeat; ++i) - { + timer2.Start(); handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)( reinterpret_cast(in_n_hi_wi_c_dev_buf.GetDeviceBuffer()), reinterpret_cast(wei_k_y_x_c_dev_buf.GetDeviceBuffer()), @@ -348,12 +351,21 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_ky (const void*)(b_k0_n_k1_grid_desc_dev_buf), (const void*)(c_m0_m1_m2_n_grid_desc_dev_buf), (const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf)); + timer2.End(); + + kernel1_times.push_back(timer1.GetElapsedTime()); + kernel2_times.push_back(timer2.GetElapsedTime()); } - timer2.End(); { - auto ave_time1 = timer1.GetElapsedTime() / nrepeat; - auto ave_time2 = timer2.GetElapsedTime() / nrepeat; + auto ave_time1 = + std::accumulate( + std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus{}) / + (nrepeat - 1); + auto ave_time2 = + std::accumulate( + std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus{}) / + (nrepeat - 1); const auto N = in_n_hi_wi_c_lengths[I0]; const auto C = in_n_hi_wi_c_lengths[I3]; diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..260c94ee0e --- /dev/null +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,182 @@ +#pragma once +#include "device.hpp" +#include "host_tensor.hpp" +#include "handle.hpp" +#include "online_driver_common.hpp" +#include "convolution_problem_descriptor.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" +#include 
"conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp" + +template +void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( + olCompile::Handle* handle, + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const ck_driver::CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param, + ck::index_t nrepeat) +{ + using namespace ck; + using namespace ck_driver; + using size_t = std::size_t; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + ConvolutionProblemDescriptor conv_problem_desc{in_n_c_hi_wi_lengths[I0], + out_n_k_ho_wo_lengths[I1], + in_n_c_hi_wi_lengths[I1], + wei_k_c_y_x_lengths[I2], + wei_k_c_y_x_lengths[I3], + in_n_c_hi_wi_lengths[I2], + in_n_c_hi_wi_lengths[I3], + out_n_k_ho_wo_lengths[I2], + out_n_k_ho_wo_lengths[I3], + conv_strides[I0], + conv_strides[I1], + conv_dilations[I0], + conv_dilations[I1], + in_left_pads[I0], + in_left_pads[I1], + in_right_pads[I0], + in_right_pads[I1], + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value}; + + if(!ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::IsValidCompileParameter(conv_problem_desc, + compile_param)) + { + throw std::runtime_error("wrong! 
IsValidCompileParameter fail"); + } + + DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + // workspace is used to save transformed tensor descriptors created by the prepare kernel + DeviceMem workspace_dev_buf( + ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetWorkSpaceSize(conv_problem_desc, compile_param)); + + const auto block_size = std::size_t( + ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetBlockSize(conv_problem_desc, compile_param)); + + const auto grid_size = std::size_t( + ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetGridSize(conv_problem_desc, compile_param)); + + const std::vector vld1 = {1, 1, 1}; + const std::vector vgd1 = {1, 1, 1}; + + const std::vector vld2 = {static_cast(block_size), 1, 1}; + const std::vector vgd2 = {static_cast(grid_size * block_size), 1, 1}; + + std::string program_name = + "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp"; + std::string algo_name = "implicit_gemm_conv_fwd_v6r1_dlops_nchw"; + + std::string compile_param_string = " -std=c++17 " + compile_param.GetCompileParameterString(); + std::string network_config = compile_param_string; + + std::vector kernel1_times; + std::vector kernel2_times; + + for(index_t i = 0; i < nrepeat; ++i) + { + KernelTimer timer1, timer2; + std::string kernel_name; + + kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare"; + auto network_config_1 = network_config + "_1"; + + timer1.Start(); + handle->AddKernel(algo_name, + network_config_1, + program_name, + kernel_name, + vld1, + vgd1, + compile_param_string)(static_cast(in_n_c_hi_wi_lengths[I0]), + 
static_cast(in_n_c_hi_wi_lengths[I1]), + static_cast(in_n_c_hi_wi_lengths[I2]), + static_cast(in_n_c_hi_wi_lengths[I3]), + static_cast(wei_k_c_y_x_lengths[I0]), + static_cast(wei_k_c_y_x_lengths[I2]), + static_cast(wei_k_c_y_x_lengths[I3]), + conv_strides[I0], + conv_strides[I1], + conv_dilations[I0], + conv_dilations[I1], + in_left_pads[I0], + in_left_pads[I1], + in_right_pads[I0], + in_right_pads[I1], + (void*)(workspace_dev_buf.GetDeviceBuffer())); + timer1.End(); + + kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw"; + auto network_config_2 = network_config + "_2"; + + timer2.Start(); + handle->AddKernel(algo_name, + network_config_2, + program_name, + kernel_name, + vld2, + vgd2, + compile_param_string)( + reinterpret_cast(wei_k_c_y_x_dev_buf.GetDeviceBuffer()), + reinterpret_cast(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()), + reinterpret_cast(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()), + (const void*)(workspace_dev_buf.GetDeviceBuffer())); + timer2.End(); + + kernel1_times.push_back(timer1.GetElapsedTime()); + kernel2_times.push_back(timer2.GetElapsedTime()); + } + + { + auto ave_time1 = + std::accumulate( + std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus{}) / + (nrepeat - 1); + auto ave_time2 = + std::accumulate( + std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus{}) / + (nrepeat - 1); + + float perf = (float)(conv_problem_desc.CalculateFlop()) / + (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2); + + std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", " + << ave_time2 << "), " << perf << " TFlop/s" << std::endl; + }; + + // copy result back to host + out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp deleted 
file mode 100644 index 2ee2680f5c..0000000000 --- a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,425 +0,0 @@ -#include "device.hpp" -#include "host_tensor.hpp" -#include "handle.hpp" -#include "online_driver_common.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" -#include "conv_tunable_fwd_v6r1_nchw_kcyx_nkhw.hpp" - -namespace detail_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw { - -template -static std::string get_network_config_string_from_types() -{ - std::string out("DAT_"); - - out += static_cast(Driver::get_typeid_from_type()) + - static_cast(Driver::get_typeid_from_type()) + - static_cast(Driver::get_typeid_from_type()); - - return (out); -}; - -static std::string -get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw& tunable) -{ - std::string out("TUN_"); - - out += std::to_string(tunable.BlockSize) + "_"; - - out += std::to_string(tunable.GN0) + "x" + std::to_string(tunable.GK1) + "_"; - - out += std::to_string(tunable.GM1PerBlockGM11) + "x" + std::to_string(tunable.GN1PerBlockGN11) + - "x" + std::to_string(tunable.GK0PerBlock) + "_"; - - out += std::to_string(tunable.BM1PerThreadBM11) + "x" + - std::to_string(tunable.BN1PerThreadBN11) + "x" + std::to_string(tunable.BK0PerThread) + - "_"; - - out += std::to_string(tunable.BM10BN10ThreadClusterBM100) + "x" + - std::to_string(tunable.BM10BN10ThreadClusterBN100) + "x" + - std::to_string(tunable.BM10BN10ThreadClusterBM101) + "x" + - std::to_string(tunable.BM10BN10ThreadClusterBN101) + "_"; - - out += std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0]) + "x" + - std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1]) + "x" + - std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2]) + "x" + - 
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3]) + "x" + - std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4]) + "_"; - - out += - std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0]) + "x" + - std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1]) + "x" + - std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2]) + "x" + - std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3]) + "x" + - std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4]) + "_"; - - out += std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) + - "x" + - std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) + - "x" + - std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) + - "x" + - std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) + - "x" + - std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]) + - "_"; - - out += std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) + - "x" + - std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) + - "x" + - std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) + - "x" + - std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) + - "x" + - std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]) + - "_"; - - out += std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0]) + "x" + - std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1]) + "x" + - std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2]) + "x" + - 
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3]) + "x" + - std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4]) + "_"; - - out += - std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0]) + "x" + - std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1]) + "x" + - std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2]) + "x" + - std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3]) + "x" + - std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4]) + "_"; - - out += std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) + - "x" + - std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) + - "x" + - std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) + - "x" + - std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) + - "x" + - std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]) + - "_"; - - out += std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) + - "x" + - std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) + - "x" + - std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) + - "x" + - std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) + - "x" + - std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]) + - "_"; - - out += std::to_string(tunable.CThreadTransferDstScalarPerVector); - - return (out); -}; - -template -static std::string get_definition_string_from_types() -{ - std::string out; - - out += " -DCK_PARAM_IN_WEI_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()) + - " -DCK_PARAM_ACC_DATATYPE=" + 
std::to_string(Driver::get_typeid_from_type()) + - " -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()); - - return (out); -}; - -static std::string -get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw& tunable) -{ - std::string out; - - out += " -DCK_PARAM_BlockSize=" + std::to_string(tunable.BlockSize); - - out += " -DCK_PARAM_GN0=" + std::to_string(tunable.GN0); - out += " -DCK_PARAM_GK1=" + std::to_string(tunable.GK1); - - out += " -DCK_PARAM_GM1PerBlockGM11=" + std::to_string(tunable.GM1PerBlockGM11) + - " -DCK_PARAM_GN1PerBlockGN11=" + std::to_string(tunable.GN1PerBlockGN11) + - " -DCK_PARAM_GK0PerBlock=" + std::to_string(tunable.GK0PerBlock); - - out += " -DCK_PARAM_BM1PerThreadBM11=" + std::to_string(tunable.BM1PerThreadBM11) + - " -DCK_PARAM_BN1PerThreadBN11=" + std::to_string(tunable.BN1PerThreadBN11) + - " -DCK_PARAM_BK0PerThread=" + std::to_string(tunable.BK0PerThread); - - out += " -DCK_PARAM_BM10BN10ThreadClusterBM100=" + - std::to_string(tunable.BM10BN10ThreadClusterBM100) + - " -DCK_PARAM_BM10BN10ThreadClusterBN100=" + - std::to_string(tunable.BM10BN10ThreadClusterBN100) + - " -DCK_PARAM_BM10BN10ThreadClusterBM101=" + - std::to_string(tunable.BM10BN10ThreadClusterBM101) + - " -DCK_PARAM_BM10BN10ThreadClusterBN101=" + - std::to_string(tunable.BM10BN10ThreadClusterBN101); - - out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=" + - std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," + - std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," + - std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," + - std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," + - std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4]); - - out += - " -DCK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1=" + - 
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," + - std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," + - std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," + - std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," + - std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4]); - - out += " -DCK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" + - std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) + - "," + - std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) + - "," + - std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) + - "," + - std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) + - "," + - std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]); - - out += " -DCK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" + - std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) + - "," + - std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) + - "," + - std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) + - "," + - std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) + - "," + - std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]); - - out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1=" + - std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," + - std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," + - std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," + - 
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," + - std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4]); - - out += - " -DCK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1=" + - std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," + - std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," + - std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," + - std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," + - std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4]); - - out += " -DCK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" + - std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) + - "," + - std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) + - "," + - std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) + - "," + - std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) + - "," + - std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]); - - out += " -DCK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" + - std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) + - "," + - std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) + - "," + - std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) + - "," + - std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) + - "," + - std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]); - - out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" + - std::to_string(tunable.CThreadTransferDstScalarPerVector); - - return 
(out); -}; - -} // namespace detail_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw - -template -void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( - olCompile::Handle* handle, - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - const tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw& tunable, - ck::index_t nrepeat) -{ - using namespace ck; - using namespace detail_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw; - using size_t = std::size_t; - - //////////////////////////////////////////////////////////////////////////////////////////////////////////// - // The follow codes are only used for computing the grid_size, hasMainKBlockLoop, - // hasDoubleTailKBlockLoop - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); - - const auto descs = - transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - tunable.GN0, - tunable.GK1); - - const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; - const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; - - const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); - const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); - const auto GK = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); - - const index_t 
grid_size = (GM1 / tunable.GM1PerBlockGM11) * (GN1 / tunable.GN1PerBlockGN11); - const bool hasMainKBlockLoop = ((GK + tunable.GK0PerBlock) / (2 * tunable.GK0PerBlock) > 1); - const bool hasDoubleTailKBlockLoop = ((GK / tunable.GK0PerBlock) % 2 == 0); - - /////////////////////////////////////////////////////////////////////////////////////////////////////////// - - // these buffers are usually provided by the user application - DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - // these are workspace buffers that should be expressed to the user by the corresponding - // workspace API - DeviceMem workspace_buf(4096); - - void* a_grid_desc_gk0_gm0_gm10_gm11_gk1_dev_buf = workspace_buf.GetDeviceBuffer(); - void* b_grid_desc_gk0_gn0_gn10_gn11_gk1_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 1024); - void* c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 2048); - void* c_grid_block_cluster_blockid_to_gm10_gn10_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 3072); - - const std::vector vld = {static_cast(tunable.BlockSize), 1, 1}; - const std::vector vgd1 = {static_cast(tunable.BlockSize), 1, 1}; - const std::vector vgd2 = {static_cast(grid_size * tunable.BlockSize), 1, 1}; - - std::string program_name = "dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.cpp"; - std::string algo_name = "implicit_gemm_conv_fwd_v6r1_nchw"; - - std::string param = " -std=c++17 "; - std::string network_config; - - param += get_definition_string_from_types() + - " 
-DCK_PARAM_HAS_MAIN_KBLOCK_LOOP=" + std::to_string(hasMainKBlockLoop) + - " -DCK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP=" + std::to_string(hasDoubleTailKBlockLoop) + - get_definition_string_from_tunable(tunable); - - network_config = get_network_config_string_from_types() + "_" + - std::to_string(hasDoubleTailKBlockLoop) + "_" + - get_network_config_string_from_tunable(tunable); - - std::vector kernel1_times; - std::vector kernel2_times; - - for(index_t i = 0; i < nrepeat; ++i) - { - KernelTimer timer1, timer2; - std::string kernel_name; - - kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw_prepare"; - auto network_config_1 = network_config + "_1"; - - timer1.Start(); - handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)( - static_cast(in_n_c_hi_wi_lengths[I0]), - static_cast(in_n_c_hi_wi_lengths[I1]), - static_cast(in_n_c_hi_wi_lengths[I2]), - static_cast(in_n_c_hi_wi_lengths[I3]), - static_cast(wei_k_c_y_x_lengths[I0]), - static_cast(wei_k_c_y_x_lengths[I2]), - static_cast(wei_k_c_y_x_lengths[I3]), - conv_strides[I0], - conv_strides[I1], - conv_dilations[I0], - conv_dilations[I1], - in_left_pads[I0], - in_left_pads[I1], - in_right_pads[I0], - in_right_pads[I1], - a_grid_desc_gk0_gm0_gm10_gm11_gk1_dev_buf, - b_grid_desc_gk0_gn0_gn10_gn11_gk1_dev_buf, - c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1_dev_buf, - c_grid_block_cluster_blockid_to_gm10_gn10_dev_buf); - timer2.End(); - - kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw"; - auto network_config_2 = network_config + "_2"; - - timer2.Start(); - handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)( - reinterpret_cast(wei_k_c_y_x_dev_buf.GetDeviceBuffer()), - reinterpret_cast(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()), - reinterpret_cast(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()), - (const void*)(a_grid_desc_gk0_gm0_gm10_gm11_gk1_dev_buf), - (const void*)(b_grid_desc_gk0_gn0_gn10_gn11_gk1_dev_buf), 
- (const void*)(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1_dev_buf), - (const void*)(c_grid_block_cluster_blockid_to_gm10_gn10_dev_buf)); - timer2.End(); - - kernel1_times.push_back(timer1.GetElapsedTime()); - kernel2_times.push_back(timer2.GetElapsedTime()); - } - - { - auto ave_time1 = Driver::get_effective_average(kernel1_times); - auto ave_time2 = Driver::get_effective_average(kernel2_times); - - const auto N = in_n_c_hi_wi_lengths[I0]; - const auto C = in_n_c_hi_wi_lengths[I1]; - - const auto K = out_n_k_ho_wo_lengths[I1]; - const auto Ho = out_n_k_ho_wo_lengths[I2]; - const auto Wo = out_n_k_ho_wo_lengths[I3]; - - const auto Y = wei_k_c_y_x_lengths[I2]; - const auto X = wei_k_c_y_x_lengths[I3]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2); - - std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", " - << ave_time2 << "), " << perf << " TFlop/s" << std::endl; - }; - - // copy result back to host - out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/host/driver_online/include/online_driver_common.hpp b/host/driver_online/include/online_driver_common.hpp index 383bf4c6a4..472ffb52dc 100644 --- a/host/driver_online/include/online_driver_common.hpp +++ b/host/driver_online/include/online_driver_common.hpp @@ -1,114 +1,44 @@ -#ifndef OLC_DRIVER_COMMON_HPP -#define OLC_DRIVER_COMMON_HPP +#ifndef ONLINE_DRIVER_COMMON_HPP +#define ONLINE_DRIVER_COMMON_HPP -#include -#include -#include +namespace ck_driver { -// this enumerate should be synchronized with include/miopen.h -typedef enum { - appHalf = 0, - appFloat = 1, - appInt32 = 2, - appInt8 = 3, - appInt8x4 = 4, - appBFloat16 = 5, - appDouble = 6, -} appDataType_t; - -namespace Driver { - -template -struct get_type_from_type_enum +// greatest common divisor, aka highest common factor +inline int gcd(int x, int y) { - using type = float; -}; - -template <> -struct 
get_type_from_type_enum -{ - using type = half_float::half; -}; - -template <> -struct get_type_from_type_enum -{ - using type = float; -}; - -template <> -struct get_type_from_type_enum -{ - using type = double; -}; - -template <> -struct get_type_from_type_enum -{ - using type = int; -}; - -static inline int get_typeid_from_type_enum(appDataType_t t) -{ - switch(t) + if(x < 0) { - case appHalf: return (static_cast('H')); - case appFloat: return (static_cast('F')); - case appBFloat16: return (static_cast('B')); - case appDouble: return (static_cast('D')); - case appInt8: - case appInt8x4: - case appInt32: return (static_cast('O')); - default: throw std::runtime_error("Only float, half, bfloat16 data type is supported."); break; - }; -}; - -template -static inline int get_typeid_from_type() -{ - throw std::runtime_error("Unsupported typeid conversion for this type!"); -}; - -template <> -inline int get_typeid_from_type() -{ - return (static_cast('F')); -}; - -template <> -inline int get_typeid_from_type() -{ - return (static_cast('H')); -}; - -template <> -inline int get_typeid_from_type() -{ - return (static_cast('D')); -}; - -static inline float get_effective_average(std::vector& values) -{ - assert(!values.empty()); - - if(values.size() == 1) - return (values[0]); + return gcd(-x, y); + } + else if(y < 0) + { + return gcd(x, -y); + } + else if(x == y || x == 0) + { + return y; + } + else if(y == 0) + { + return x; + } + else if(x > y) + { + return gcd(x % y, y); + } else { - float sum = 0.0f; - float maxVal = 0.0f; + return gcd(x, y % x); + } +} - for(const auto val : values) - { - if(maxVal < val) - maxVal = val; - sum += val; - }; - - return ((sum - maxVal) / (values.size() - 1)); - }; -}; - -} // namespace Driver +template = 2, bool>::type = false> +auto gcd(X x, Ys... 
ys) +{ + return gcd(x, gcd(ys...)); +} +} // namespace ck_driver #endif diff --git a/host/online_compilation/CMakeLists.txt b/host/online_compilation/CMakeLists.txt index 7bbfc65288..c764917c28 100644 --- a/host/online_compilation/CMakeLists.txt +++ b/host/online_compilation/CMakeLists.txt @@ -77,6 +77,7 @@ message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}") ## HIP_COMPILER_FLAGS will be used for on-line compiling of the HIP kernels add_definitions("-DHIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}") +set(HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS} ${HIP_ONLINE_COMPILER_FLAGS}") file(GLOB_RECURSE COMPOSABLE_KERNEL_INCLUDE_1 "${PROJECT_SOURCE_DIR}/composable_kernel/include/*/*.hpp") file(GLOB COMPOSABLE_KERNEL_INCLUDE_2 "${PROJECT_SOURCE_DIR}/external/rocm/include/bfloat16_dev.hpp") diff --git a/script/cmake-rocm.sh b/script/cmake-rocm.sh index 9a02b68e07..e65c53ce1e 100755 --- a/script/cmake-rocm.sh +++ b/script/cmake-rocm.sh @@ -6,21 +6,16 @@ rm -rf CMakeFiles MY_PROJECT_SOURCE=../../../ MY_PROJECT_INSTALL=../install.dir -cmake \ --D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ --D CMAKE_BUILD_TYPE=Release \ --D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_PREFIX_PATH=/opt/rocm \ --D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ +cmake \ +-D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX906 -O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ +-D HIP_ONLINE_COMPILER_FLAGS="-DCK_AMD_GPU_GFX906" \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ ${MY_PROJECT_SOURCE} -#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0" \ -#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps=$CWD" \ -#-D 
CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0" \ -#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps=$CWD" \ -#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -v -gline-tables-only -save-temps=$CWD" \ - #CXX_FLAG_TMP=-Weverything # -Wno-c++98-compat \ # -Wno-c++98-compat-pedantic \