diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 9312362845..2e5908270b 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -120,23 +120,25 @@ struct BlockToCTileMap_M00_N0_M01Adapt __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default; - __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) = - default; - __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) = - default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt( + const BlockToCTileMap_M00_N0_M01Adapt&) = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt( + BlockToCTileMap_M00_N0_M01Adapt&&) = default; __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt& operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default; __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt& operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default; - __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) + __host__ + __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) : M_(M), N_(N), M01_(M01) { } template - __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01 = 8) + __host__ + __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8) : BlockToCTileMap_M00_N0_M01Adapt( c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) { @@ -151,13 +153,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt } template - __host__ __device__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) { return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); } template - __host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + __host__ __device__ constexpr bool + CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } @@ -231,7 +235,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt template __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, - const CTileDim& /* c_tile_dim */) const + const CTileDim& /* c_tile_dim */) const { return true; // always valid provided that user gets grid size from CalculateGridSize() } diff --git a/library/src/jit_library/include/ck/host/common.hpp b/library/src/jit_library/include/ck/host/common.hpp index 8b2ceacc68..9a6995eb8f 100644 --- a/library/src/jit_library/include/ck/host/common.hpp +++ b/library/src/jit_library/include/ck/host/common.hpp @@ -17,7 +17,8 @@ struct Solution std::size_t grid_size; }; -enum class DataType { +enum class DataType +{ Half, Float, Int8, @@ -26,7 +27,7 @@ enum class DataType { std::string ToString(DataType dt); -std::unordered_map> GetHeaders(); +std::unordered_map> GetHeaders(); std::size_t integer_divide_ceil(std::size_t x, std::size_t y); diff --git a/library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp b/library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp index c73715e1ae..0d935192cf 100644 --- a/library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp +++ b/library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp @@ -11,45 +11,44 @@ #include #include "ck/host/common.hpp" - namespace ck { namespace host { namespace device_gemm_multiple_d { struct Problem { - std::size_t M = 0; - std::size_t N = 0; - std::size_t K = 0; - bool TransA = false; - bool TransB = false; - bool TransE = false; - std::vector DsTrans = {}; - DataType ADataType = DataType::Half; - DataType BDataType = DataType::Half; - DataType EDataType = DataType::Half; + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + bool TransA = false; + bool TransB = false; + bool TransE = false; + std::vector DsTrans = {}; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType EDataType = DataType::Half; std::vector DsDataType = {}; - std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; - std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; - std::string CDEElementOp = "ck::Tuple<>"; + std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string CDEElementOp = "ck::Tuple<>"; - static const std::size_t ds_layout_idx = 3; - static const std::size_t ds_data_type_idx = 9; - static const std::size_t e_data_type_idx = 10; - static const std::size_t a_elementwise_op_idx = 11; - static const std::size_t b_elementwise_op_idx = 12; + static const std::size_t ds_layout_idx = 3; + static const std::size_t ds_data_type_idx = 9; + static const std::size_t e_data_type_idx = 10; + static const std::size_t a_elementwise_op_idx = 11; + static const std::size_t b_elementwise_op_idx = 12; static const std::size_t ds_elementwise_op_idx = 13; - static const std::size_t gemm_spec_idx = 14; - static const std::size_t block_size_idx = 16; - static const std::size_t m_per_block_idx = 17; - static const std::size_t n_per_block_idx = 18; - static const std::size_t k_per_block_idx = 19; + static const std::size_t gemm_spec_idx = 14; + static const std::size_t block_size_idx = 16; + static const std::size_t m_per_block_idx = 17; + static const std::size_t n_per_block_idx = 18; + static const std::size_t k_per_block_idx = 19; std::string GetIncludeHeader() const; std::vector GetSolutions(const std::string& arch) const; -private: + private: std::vector GetInstances(const std::string& arch) const; Solution MakeSolution(std::size_t idx, const std::string& arch) const; diff --git a/library/src/jit_library/src/common.cpp b/library/src/jit_library/src/common.cpp index 067b60df68..f270e8be21 100644 --- a/library/src/jit_library/src/common.cpp +++ b/library/src/jit_library/src/common.cpp @@ -8,16 +8,17 @@ namespace host { std::string ToString(DataType dt) { - switch (dt) { - case DataType::Float: return "float"; - case DataType::Half: return "ck::half_t"; - case DataType::Int8: return "int8_t"; - case DataType::Int32: return "int32_t"; + switch(dt) + { + case DataType::Float: return "float"; + case DataType::Half: return "ck::half_t"; + case DataType::Int8: return "int8_t"; + case DataType::Int32: return "int32_t"; } throw std::runtime_error("Incorrect data type"); } -std::unordered_map> GetHeaders() +std::unordered_map> GetHeaders() { return ck_headers(); } diff --git a/library/src/jit_library/src/device_gemm_multiple_d.cpp b/library/src/jit_library/src/device_gemm_multiple_d.cpp index 33d15af664..09393760fe 100644 --- a/library/src/jit_library/src/device_gemm_multiple_d.cpp +++ b/library/src/jit_library/src/device_gemm_multiple_d.cpp @@ -8,12 +8,12 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { -std::string GetGemmSpec(const std::size_t m, - const std::size_t n, +std::string GetGemmSpec(const std::size_t m, + const std::size_t n, const std::size_t k, const std::size_t m_per_block, const std::size_t n_per_block, - const std::size_t k_per_block) + const std::size_t k_per_block) { std::string spec = ""; if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) @@ -28,13 +28,12 @@ std::string GetGemmSpec(const std::size_t m, return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; } -std::size_t GetGridSize(const std::size_t m, - const std::size_t n, - const std::size_t m_per_block, - const std::size_t n_per_block) +std::size_t GetGridSize(const std::size_t m, + const std::size_t n, + const std::size_t m_per_block, + const std::size_t n_per_block) { - return integer_divide_ceil(m, m_per_block) * - integer_divide_ceil(n, n_per_block); + return integer_divide_ceil(m, m_per_block) * integer_divide_ceil(n, n_per_block); } const std::unordered_set& get_xdlop_archs() @@ -47,7 +46,7 @@ std::vector Problem::GetInstances(const std::string& arch) const { std::vector instances; const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8; - if (get_xdlop_archs().find(arch) != get_xdlop_archs().end()) + if(get_xdlop_archs().find(arch) != get_xdlop_archs().end()) { ck::host::instance::gemm_add_add_fastgelu_instances all_instances{}; if(TransA and TransB) @@ -65,27 +64,28 @@ std::vector Problem::GetInstances(const std::string& arch) const std::string MakeLayoutTuple(const std::vector& layouts) { std::string layout_tuple = "ck::Tuple<"; - auto it = layouts.begin(); + auto it = layouts.begin(); while(it != layouts.end()) { - layout_tuple += *it ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor"; + layout_tuple += + *it ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor"; it = std::next(it); - if (it != layouts.end()) + if(it != layouts.end()) layout_tuple += ", "; } - + return layout_tuple + ">"; } std::string MakeTypeTuple(const std::vector& types) { std::string type_tuple = "ck::Tuple<"; - auto it = types.begin(); + auto it = types.begin(); while(it != types.end()) { type_tuple += ToString(*it); it = std::next(it); - if (it != types.end()) + if(it != types.end()) type_tuple += ", "; } return type_tuple + ">"; @@ -97,43 +97,46 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const std::istringstream iss(template_str); std::vector params(std::istream_iterator{iss}, std::istream_iterator()); - - if (ADataType == DataType::Int8 and BDataType == DataType::Int8) + + if(ADataType == DataType::Int8 and BDataType == DataType::Int8) { // Change CBlockTransfer ScalarPerVector if Ds contains other types - if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; })) + if(std::any_of( + DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; })) { params[params.size() - 3] = "8"; } - if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; })) + if(std::any_of( + DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; })) { params[params.size() - 3] = "4"; } } - params[a_elementwise_op_idx] = AElementOp; - params[b_elementwise_op_idx] = BElementOp; - params[ds_layout_idx] = MakeLayoutTuple(DsTrans); - params[ds_data_type_idx] = MakeTypeTuple(DsDataType); + params[a_elementwise_op_idx] = AElementOp; + params[b_elementwise_op_idx] = BElementOp; + params[ds_layout_idx] = MakeLayoutTuple(DsTrans); + params[ds_data_type_idx] = MakeTypeTuple(DsDataType); params[ds_elementwise_op_idx] = CDEElementOp; - params[e_data_type_idx] = ToString(EDataType); - auto block_size_str = params[block_size_idx]; - auto m_per_block_str = params[m_per_block_idx]; - auto n_per_block_str = params[n_per_block_idx]; - auto k_per_block_str = params[k_per_block_idx]; + params[e_data_type_idx] = ToString(EDataType); + auto block_size_str = params[block_size_idx]; + auto m_per_block_str = params[m_per_block_idx]; + auto n_per_block_str = params[n_per_block_idx]; + auto k_per_block_str = params[k_per_block_idx]; const std::size_t block_size = std::stoi(block_size_str); const std::size_t m_per_block = std::stoi(m_per_block_str); const std::size_t n_per_block = std::stoi(n_per_block_str); const std::size_t k_per_block = std::stoi(k_per_block_str); const std::size_t grid_size = GetGridSize(M, N, m_per_block, n_per_block); - params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block); + params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block); - std::string str = std::accumulate(params.begin() + 1, params.end(), std::string{}, - [](const std::string& a, const std::string& b) { - return a.empty() ? b : a + ", " + b; - }); + std::string str = std::accumulate( + params.begin() + 1, + params.end(), + std::string{}, + [](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; }); str = params.front() + "< " + str + ">"; - + return Solution{str, block_size, grid_size}; } @@ -146,7 +149,7 @@ std::vector Problem::GetSolutions(const std::string& arch) const { std::vector solutions; const std::size_t num_instances = GetInstances(arch).size(); - for (std::size_t i = 0; i < num_instances; ++i) + for(std::size_t i = 0; i < num_instances; ++i) { solutions.push_back(MakeSolution(i, arch)); } @@ -154,7 +157,6 @@ std::vector Problem::GetSolutions(const std::string& arch) const return solutions; } - } // namespace device_gemm_multiple_d } // namespace host } // namespace ck diff --git a/test/jit_library/jit_library.cpp b/test/jit_library/jit_library.cpp index 136eb60ddd..2e05cd3a50 100644 --- a/test/jit_library/jit_library.cpp +++ b/test/jit_library/jit_library.cpp @@ -3,36 +3,48 @@ bool test_Problem() { - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - false, - true, - false, - {}, - ck::host::DataType::Half, - ck::host::DataType::Half, - ck::host::DataType::Half, - {}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + false, + true, + false, + {}, + ck::host::DataType::Half, + ck::host::DataType::Half, + ck::host::DataType::Half, + {}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; - const auto include_header = problem.GetIncludeHeader(); - const auto solutions = problem.GetSolutions("gfx90a"); - const auto& solution = solutions.at(0); - const auto template_str = solution.template_str; - const auto grid_size = solution.grid_size; - const auto block_size = solution.block_size; + const auto include_header = problem.GetIncludeHeader(); + const auto solutions = problem.GetSolutions("gfx90a"); + const auto& solution = solutions.at(0); + const auto template_str = solution.template_str; + const auto grid_size = solution.grid_size; + const auto block_size = solution.block_size; bool pass = true; - pass &= include_header == "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"; + pass &= include_header == + "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"; pass &= solutions.size() == 42; - pass &= template_str == "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, ck::half_t, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, 1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>"; - pass &= grid_size == 2; + pass &= template_str == + "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< " + "ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, " + "ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, " + "ck::half_t, ck::tensor_operation::element_wise::Passthrough, " + "ck::tensor_operation::element_wise::Passthrough, " + "ck::tensor_operation::element_wise::Passthrough, " + "ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, " + "8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, " + "8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, " + "1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>"; + pass &= grid_size == 2; pass &= block_size == 256; - + return pass; } @@ -40,46 +52,48 @@ bool test_GetGemmSpec() { bool pass = true; { - //PadMNK - auto problem = ck::host::device_gemm_multiple_d::Problem{255, - 255, - 255, - false, - true, - false, - {}, - ck::host::DataType::Half, - ck::host::DataType::Half, - ck::host::DataType::Half, - {}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; - const auto solutions = problem.GetSolutions("gfx90a"); - const auto& solution = solutions.at(0); - const auto template_str = solution.template_str; + // PadMNK + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 255, + 255, + 255, + false, + true, + false, + {}, + ck::host::DataType::Half, + ck::host::DataType::Half, + ck::host::DataType::Half, + {}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; + const auto solutions = problem.GetSolutions("gfx90a"); + const auto& solution = solutions.at(0); + const auto template_str = solution.template_str; pass &= template_str.find("GemmSpecialization::MNKPadding") != std::string::npos; } { - //Default - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - false, - true, - false, - {}, - ck::host::DataType::Half, - ck::host::DataType::Half, - ck::host::DataType::Half, - {}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; - const auto solutions = problem.GetSolutions("gfx90a"); - const auto& solution = solutions.at(0); - const auto template_str = solution.template_str; + // Default + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + false, + true, + false, + {}, + ck::host::DataType::Half, + ck::host::DataType::Half, + ck::host::DataType::Half, + {}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; + const auto solutions = problem.GetSolutions("gfx90a"); + const auto& solution = solutions.at(0); + const auto template_str = solution.template_str; pass &= template_str.find("GemmSpecialization::Default") != std::string::npos; } @@ -91,147 +105,155 @@ bool test_GetInstances() { bool pass = true; { - //Col Col Fp16 - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - true, - true, - false, - {}, - ck::host::DataType::Half, - ck::host::DataType::Half, - ck::host::DataType::Half, - {}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; + // Col Col Fp16 + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + true, + true, + false, + {}, + ck::host::DataType::Half, + ck::host::DataType::Half, + ck::host::DataType::Half, + {}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; pass &= problem.GetSolutions("gfx90a").size() == 51; } { - //Col Row Fp16 - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - true, - false, - false, - {}, - ck::host::DataType::Half, - ck::host::DataType::Half, - ck::host::DataType::Half, - {}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; + // Col Row Fp16 + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + true, + false, + false, + {}, + ck::host::DataType::Half, + ck::host::DataType::Half, + ck::host::DataType::Half, + {}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; pass &= problem.GetSolutions("gfx90a").size() == 51; } { - //Row Col Fp16 - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - false, - true, - false, - {}, - ck::host::DataType::Half, - ck::host::DataType::Half, - ck::host::DataType::Half, - {}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; + // Row Col Fp16 + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + false, + true, + false, + {}, + ck::host::DataType::Half, + ck::host::DataType::Half, + ck::host::DataType::Half, + {}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; pass &= problem.GetSolutions("gfx90a").size() == 42; } { - //Row Row Int8 - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - false, - false, - false, - {}, - ck::host::DataType::Int8, - ck::host::DataType::Int8, - ck::host::DataType::Half, - {}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; + // Row Row Int8 + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + false, + false, + false, + {}, + ck::host::DataType::Int8, + ck::host::DataType::Int8, + ck::host::DataType::Half, + {}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; pass &= problem.GetSolutions("gfx90a").size() == 48; } { - //Col Col Int8 - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - true, - true, - false, - {}, - ck::host::DataType::Int8, - ck::host::DataType::Int8, - ck::host::DataType::Half, - {}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; + // Col Col Int8 + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + true, + true, + false, + {}, + ck::host::DataType::Int8, + ck::host::DataType::Int8, + ck::host::DataType::Half, + {}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; pass &= problem.GetSolutions("gfx90a").size() == 48; } { - //Col Row Int8 - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - true, - false, - false, - {}, - ck::host::DataType::Int8, - ck::host::DataType::Int8, - ck::host::DataType::Half, - {}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; + // Col Row Int8 + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + true, + false, + false, + {}, + ck::host::DataType::Int8, + ck::host::DataType::Int8, + ck::host::DataType::Half, + {}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; pass &= problem.GetSolutions("gfx90a").size() == 48; } { - //Row Col Int8 - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - false, - true, - false, - {}, - ck::host::DataType::Int8, - ck::host::DataType::Int8, - ck::host::DataType::Half, - {}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; + // Row Col Int8 + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + false, + true, + false, + {}, + ck::host::DataType::Int8, + ck::host::DataType::Int8, + ck::host::DataType::Half, + {}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; pass &= problem.GetSolutions("gfx90a").size() == 39; } { - //Row Row Int8 - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - false, - false, - false, - {}, - ck::host::DataType::Int8, - ck::host::DataType::Int8, - ck::host::DataType::Half, - {}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; + // Row Row Int8 + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + false, + false, + false, + {}, + ck::host::DataType::Int8, + ck::host::DataType::Int8, + ck::host::DataType::Half, + {}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; pass &= problem.GetSolutions("gfx90a").size() == 48; } @@ -243,45 +265,50 @@ bool test_MakeLayoutsTuple() bool pass = true; { // Empty Tuple - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - false, - false, - false, - {}, - ck::host::DataType::Half, - ck::host::DataType::Half, - ck::host::DataType::Half, - {ck::host::DataType::Half}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; - const auto solutions = problem.GetSolutions("gfx90a"); - const auto& solution = solutions.at(0); - const auto template_str = solution.template_str; + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + false, + false, + false, + {}, + ck::host::DataType::Half, + ck::host::DataType::Half, + ck::host::DataType::Half, + {ck::host::DataType::Half}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; + const auto solutions = problem.GetSolutions("gfx90a"); + const auto& solution = solutions.at(0); + const auto template_str = solution.template_str; pass &= template_str.find("ck::Tuple<>") != std::string::npos; } { // RowColRow Tuple - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - false, - false, - false, - {false, true, false}, - ck::host::DataType::Half, - ck::host::DataType::Half, - ck::host::DataType::Half, - {ck::host::DataType::Half}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; - const auto solutions = problem.GetSolutions("gfx90a"); - const auto& solution = solutions.at(0); - const auto template_str = solution.template_str; - pass &= template_str.find("ck::Tuple") != std::string::npos; + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + false, + false, + false, + {false, true, false}, + ck::host::DataType::Half, + ck::host::DataType::Half, + ck::host::DataType::Half, + {ck::host::DataType::Half}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; + const auto solutions = problem.GetSolutions("gfx90a"); + const auto& solution = solutions.at(0); + const auto template_str = solution.template_str; + pass &= template_str.find( + "ck::Tuple") != + std::string::npos; } return pass; @@ -292,44 +319,46 @@ bool test_MakeTypeTuple() bool pass = true; { // Empty Tuple - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - false, - false, - false, - {true}, - ck::host::DataType::Half, - ck::host::DataType::Half, - ck::host::DataType::Half, - {}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; - const auto solutions = problem.GetSolutions("gfx90a"); - const auto& solution = solutions.at(0); - const auto template_str = solution.template_str; + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + false, + false, + false, + {true}, + ck::host::DataType::Half, + ck::host::DataType::Half, + ck::host::DataType::Half, + {}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; + const auto solutions = problem.GetSolutions("gfx90a"); + const auto& solution = solutions.at(0); + const auto template_str = solution.template_str; pass &= template_str.find("ck::Tuple<>") != std::string::npos; } { // Half Int8 Tuple - auto problem = ck::host::device_gemm_multiple_d::Problem{256, - 256, - 256, - false, - false, - false, - {}, - ck::host::DataType::Half, - ck::host::DataType::Half, - ck::host::DataType::Half, - {ck::host::DataType::Half, ck::host::DataType::Int8}, - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough", - "ck::tensor_operation::element_wise::Passthrough"}; - const auto solutions = problem.GetSolutions("gfx90a"); - const auto& solution = solutions.at(0); - const auto template_str = solution.template_str; + auto problem = ck::host::device_gemm_multiple_d::Problem{ + 256, + 256, + 256, + false, + false, + false, + {}, + ck::host::DataType::Half, + ck::host::DataType::Half, + ck::host::DataType::Half, + {ck::host::DataType::Half, ck::host::DataType::Int8}, + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough", + "ck::tensor_operation::element_wise::Passthrough"}; + const auto solutions = problem.GetSolutions("gfx90a"); + const auto& solution = solutions.at(0); + const auto template_str = solution.template_str; pass &= template_str.find("ck::Tuple") != std::string::npos; } return pass;