From 6159c579becf86da2297547f3dd21388070f7425 Mon Sep 17 00:00:00 2001 From: "Po-Yen, Chen" Date: Mon, 4 Jul 2022 22:55:14 +0800 Subject: [PATCH] Use language construct to choose between types --- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 33 ++++++++++--------- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 33 ++++++++++--------- .../gpu/grid/gridwise_gemm_xdlops_v3r2.hpp | 13 +++++--- 3 files changed, 44 insertions(+), 35 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index cd0a7743c5..8092258435 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -16,6 +16,8 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include + namespace ck { template ; -#if 0 - using GridwiseGemmPipe = GridwiseGemmPipeline_v1; -#else - using GridwiseGemmPipe = GridwiseGemmPipeline_v2; -#endif + static constexpr std::size_t GridwiseGemmPipelineVersion = 2; + + using GridwiseGemmPipe = typename std::tuple_element< + GridwiseGemmPipelineVersion, + std::tuple, GridwiseGemmPipeline_v2>>:: + type; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index a11e955387..2b34f2655c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -15,6 +15,8 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include + namespace ck { template ; -#if 0 - using GridwiseGemmPipe = GridwiseGemmPipeline_v1; -#else - using GridwiseGemmPipe = GridwiseGemmPipeline_v2; -#endif + static constexpr std::size_t GridwiseGemmPipelineVersion = 2; + + using GridwiseGemmPipe = typename std::tuple_element< + GridwiseGemmPipelineVersion, + std::tuple, GridwiseGemmPipeline_v2>>:: + type; __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index 347af0b17f..8c9a715bdc 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -16,6 +16,8 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include + namespace ck { template ; -#if 0 - using GridwiseGemmPipe = GridwiseGemmPipeline_v1; -#else - using GridwiseGemmPipe = GridwiseGemmPipeline_v2; -#endif + static constexpr std::size_t GridwiseGemmPipelineVersion = 2; + + using GridwiseGemmPipe = typename std::tuple_element< + GridwiseGemmPipelineVersion, + std::tuple, GridwiseGemmPipeline_v2>>:: + type; __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() {