Use language construct to choose between types

This commit is contained in:
Po-Yen, Chen
2022-07-04 22:55:14 +08:00
parent 3edcd5fc23
commit 6159c579be
3 changed files with 44 additions and 35 deletions

View File

@@ -16,6 +16,8 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <tuple>
namespace ck {
template <typename GridwiseGemm,
@@ -60,16 +62,16 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
@@ -135,11 +137,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
#if 0
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
#else
using GridwiseGemmPipe = GridwiseGemmPipeline_v2;
#endif
// Compile-time selector for the GEMM pipeline implementation used by this struct.
// The value is used directly as an index into the std::tuple below:
//   0 -> char (placeholder only; presumably there so that version numbers are 1-based)
//   1 -> GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>
//   2 -> GridwiseGemmPipeline_v2
// With the current value of 2, GridwiseGemmPipe aliases GridwiseGemmPipeline_v2.
static constexpr std::size_t GridwiseGemmPipelineVersion = 2;
// `typename` is required because the tuple's element list depends on the
// NumGemmKPrefetchStage template parameter, making `::type` a dependent name.
using GridwiseGemmPipe = typename std::tuple_element<
GridwiseGemmPipelineVersion,
std::tuple<char, GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>, GridwiseGemmPipeline_v2>>::
type;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{

View File

@@ -15,6 +15,8 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <tuple>
namespace ck {
template <typename GridwiseGemm,
@@ -59,16 +61,16 @@ __global__ void
c_element_op,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = block_2_ctile_map;
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
@@ -127,11 +129,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
#if 0
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
#else
using GridwiseGemmPipe = GridwiseGemmPipeline_v2;
#endif
// Compile-time selector for the GEMM pipeline implementation used by this struct.
// The value is used directly as an index into the std::tuple below:
//   0 -> char (placeholder only; presumably there so that version numbers are 1-based)
//   1 -> GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>
//   2 -> GridwiseGemmPipeline_v2
// With the current value of 2, GridwiseGemmPipe aliases GridwiseGemmPipeline_v2.
static constexpr std::size_t GridwiseGemmPipelineVersion = 2;
// `typename` is required because the tuple's element list depends on the
// NumGemmKPrefetchStage template parameter, making `::type` a dependent name.
using GridwiseGemmPipe = typename std::tuple_element<
GridwiseGemmPipelineVersion,
std::tuple<char, GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>, GridwiseGemmPipeline_v2>>::
type;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{

View File

@@ -16,6 +16,8 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <tuple>
namespace ck {
template <typename GridwiseGemm,
@@ -141,11 +143,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
#if 0
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
#else
using GridwiseGemmPipe = GridwiseGemmPipeline_v2;
#endif
// Compile-time selector for the GEMM pipeline implementation used by this struct.
// The value is used directly as an index into the std::tuple below:
//   0 -> char (placeholder only; presumably there so that version numbers are 1-based)
//   1 -> GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>
//   2 -> GridwiseGemmPipeline_v2
// With the current value of 2, GridwiseGemmPipe aliases GridwiseGemmPipeline_v2.
static constexpr std::size_t GridwiseGemmPipelineVersion = 2;
// `typename` is required because the tuple's element list depends on the
// NumGemmKPrefetchStage template parameter, making `::type` a dependent name.
using GridwiseGemmPipe = typename std::tuple_element<
GridwiseGemmPipelineVersion,
std::tuple<char, GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>, GridwiseGemmPipeline_v2>>::
type;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{