mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Formatting
This commit is contained in:
@@ -120,23 +120,25 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
|
||||
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
|
||||
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) =
|
||||
default;
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) =
|
||||
default;
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
|
||||
const BlockToCTileMap_M00_N0_M01Adapt&) = default;
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
|
||||
BlockToCTileMap_M00_N0_M01Adapt&&) = default;
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
|
||||
operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
|
||||
operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
|
||||
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
|
||||
__host__
|
||||
__device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
|
||||
: M_(M), N_(N), M01_(M01)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01 = 8)
|
||||
__host__
|
||||
__device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01 = 8)
|
||||
: BlockToCTileMap_M00_N0_M01Adapt(
|
||||
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
|
||||
{
|
||||
@@ -151,13 +153,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
__host__ __device__ static constexpr index_t
|
||||
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
__host__ __device__ constexpr bool
|
||||
CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
@@ -231,7 +235,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
|
||||
const CTileDim& /* c_tile_dim */) const
|
||||
const CTileDim& /* c_tile_dim */) const
|
||||
{
|
||||
return true; // always valid provided that user gets grid size from CalculateGridSize()
|
||||
}
|
||||
|
||||
@@ -17,7 +17,8 @@ struct Solution
|
||||
std::size_t grid_size;
|
||||
};
|
||||
|
||||
enum class DataType {
|
||||
enum class DataType
|
||||
{
|
||||
Half,
|
||||
Float,
|
||||
Int8,
|
||||
@@ -26,7 +27,7 @@ enum class DataType {
|
||||
|
||||
std::string ToString(DataType dt);
|
||||
|
||||
std::unordered_map<std::string, std::pair<const char*,const char*>> GetHeaders();
|
||||
std::unordered_map<std::string, std::pair<const char*, const char*>> GetHeaders();
|
||||
|
||||
std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
|
||||
|
||||
|
||||
@@ -11,45 +11,44 @@
|
||||
#include <numeric>
|
||||
#include "ck/host/common.hpp"
|
||||
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace device_gemm_multiple_d {
|
||||
|
||||
struct Problem
|
||||
{
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransE = false;
|
||||
std::vector<bool> DsTrans = {};
|
||||
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType EDataType = DataType::Half;
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransE = false;
|
||||
std::vector<bool> DsTrans = {};
|
||||
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType EDataType = DataType::Half;
|
||||
std::vector<DataType> DsDataType = {};
|
||||
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
|
||||
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
|
||||
std::string CDEElementOp = "ck::Tuple<>";
|
||||
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
|
||||
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
|
||||
std::string CDEElementOp = "ck::Tuple<>";
|
||||
|
||||
static const std::size_t ds_layout_idx = 3;
|
||||
static const std::size_t ds_data_type_idx = 9;
|
||||
static const std::size_t e_data_type_idx = 10;
|
||||
static const std::size_t a_elementwise_op_idx = 11;
|
||||
static const std::size_t b_elementwise_op_idx = 12;
|
||||
static const std::size_t ds_layout_idx = 3;
|
||||
static const std::size_t ds_data_type_idx = 9;
|
||||
static const std::size_t e_data_type_idx = 10;
|
||||
static const std::size_t a_elementwise_op_idx = 11;
|
||||
static const std::size_t b_elementwise_op_idx = 12;
|
||||
static const std::size_t ds_elementwise_op_idx = 13;
|
||||
static const std::size_t gemm_spec_idx = 14;
|
||||
static const std::size_t block_size_idx = 16;
|
||||
static const std::size_t m_per_block_idx = 17;
|
||||
static const std::size_t n_per_block_idx = 18;
|
||||
static const std::size_t k_per_block_idx = 19;
|
||||
static const std::size_t gemm_spec_idx = 14;
|
||||
static const std::size_t block_size_idx = 16;
|
||||
static const std::size_t m_per_block_idx = 17;
|
||||
static const std::size_t n_per_block_idx = 18;
|
||||
static const std::size_t k_per_block_idx = 19;
|
||||
|
||||
std::string GetIncludeHeader() const;
|
||||
|
||||
std::vector<Solution> GetSolutions(const std::string& arch) const;
|
||||
|
||||
private:
|
||||
private:
|
||||
std::vector<std::string> GetInstances(const std::string& arch) const;
|
||||
|
||||
Solution MakeSolution(std::size_t idx, const std::string& arch) const;
|
||||
|
||||
@@ -8,16 +8,17 @@ namespace host {
|
||||
|
||||
std::string ToString(DataType dt)
|
||||
{
|
||||
switch (dt) {
|
||||
case DataType::Float: return "float";
|
||||
case DataType::Half: return "ck::half_t";
|
||||
case DataType::Int8: return "int8_t";
|
||||
case DataType::Int32: return "int32_t";
|
||||
switch(dt)
|
||||
{
|
||||
case DataType::Float: return "float";
|
||||
case DataType::Half: return "ck::half_t";
|
||||
case DataType::Int8: return "int8_t";
|
||||
case DataType::Int32: return "int32_t";
|
||||
}
|
||||
throw std::runtime_error("Incorrect data type");
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::pair<const char*,const char*>> GetHeaders()
|
||||
std::unordered_map<std::string, std::pair<const char*, const char*>> GetHeaders()
|
||||
{
|
||||
return ck_headers();
|
||||
}
|
||||
|
||||
@@ -8,12 +8,12 @@ namespace ck {
|
||||
namespace host {
|
||||
namespace device_gemm_multiple_d {
|
||||
|
||||
std::string GetGemmSpec(const std::size_t m,
|
||||
const std::size_t n,
|
||||
std::string GetGemmSpec(const std::size_t m,
|
||||
const std::size_t n,
|
||||
const std::size_t k,
|
||||
const std::size_t m_per_block,
|
||||
const std::size_t n_per_block,
|
||||
const std::size_t k_per_block)
|
||||
const std::size_t k_per_block)
|
||||
{
|
||||
std::string spec = "";
|
||||
if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
|
||||
@@ -28,13 +28,12 @@ std::string GetGemmSpec(const std::size_t m,
|
||||
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
|
||||
}
|
||||
|
||||
std::size_t GetGridSize(const std::size_t m,
|
||||
const std::size_t n,
|
||||
const std::size_t m_per_block,
|
||||
const std::size_t n_per_block)
|
||||
std::size_t GetGridSize(const std::size_t m,
|
||||
const std::size_t n,
|
||||
const std::size_t m_per_block,
|
||||
const std::size_t n_per_block)
|
||||
{
|
||||
return integer_divide_ceil(m, m_per_block) *
|
||||
integer_divide_ceil(n, n_per_block);
|
||||
return integer_divide_ceil(m, m_per_block) * integer_divide_ceil(n, n_per_block);
|
||||
}
|
||||
|
||||
const std::unordered_set<std::string>& get_xdlop_archs()
|
||||
@@ -47,7 +46,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
|
||||
{
|
||||
std::vector<std::string> instances;
|
||||
const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8;
|
||||
if (get_xdlop_archs().find(arch) != get_xdlop_archs().end())
|
||||
if(get_xdlop_archs().find(arch) != get_xdlop_archs().end())
|
||||
{
|
||||
ck::host::instance::gemm_add_add_fastgelu_instances all_instances{};
|
||||
if(TransA and TransB)
|
||||
@@ -65,27 +64,28 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
|
||||
std::string MakeLayoutTuple(const std::vector<bool>& layouts)
|
||||
{
|
||||
std::string layout_tuple = "ck::Tuple<";
|
||||
auto it = layouts.begin();
|
||||
auto it = layouts.begin();
|
||||
while(it != layouts.end())
|
||||
{
|
||||
layout_tuple += *it ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor";
|
||||
layout_tuple +=
|
||||
*it ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor";
|
||||
it = std::next(it);
|
||||
if (it != layouts.end())
|
||||
if(it != layouts.end())
|
||||
layout_tuple += ", ";
|
||||
}
|
||||
|
||||
|
||||
return layout_tuple + ">";
|
||||
}
|
||||
|
||||
std::string MakeTypeTuple(const std::vector<DataType>& types)
|
||||
{
|
||||
std::string type_tuple = "ck::Tuple<";
|
||||
auto it = types.begin();
|
||||
auto it = types.begin();
|
||||
while(it != types.end())
|
||||
{
|
||||
type_tuple += ToString(*it);
|
||||
it = std::next(it);
|
||||
if (it != types.end())
|
||||
if(it != types.end())
|
||||
type_tuple += ", ";
|
||||
}
|
||||
return type_tuple + ">";
|
||||
@@ -97,43 +97,46 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
|
||||
std::istringstream iss(template_str);
|
||||
std::vector<std::string> params(std::istream_iterator<std::string>{iss},
|
||||
std::istream_iterator<std::string>());
|
||||
|
||||
if (ADataType == DataType::Int8 and BDataType == DataType::Int8)
|
||||
|
||||
if(ADataType == DataType::Int8 and BDataType == DataType::Int8)
|
||||
{
|
||||
// Change CBlockTransfer ScalarPerVector if Ds contains other types
|
||||
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; }))
|
||||
if(std::any_of(
|
||||
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; }))
|
||||
{
|
||||
params[params.size() - 3] = "8";
|
||||
}
|
||||
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; }))
|
||||
if(std::any_of(
|
||||
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; }))
|
||||
{
|
||||
params[params.size() - 3] = "4";
|
||||
}
|
||||
}
|
||||
|
||||
params[a_elementwise_op_idx] = AElementOp;
|
||||
params[b_elementwise_op_idx] = BElementOp;
|
||||
params[ds_layout_idx] = MakeLayoutTuple(DsTrans);
|
||||
params[ds_data_type_idx] = MakeTypeTuple(DsDataType);
|
||||
params[a_elementwise_op_idx] = AElementOp;
|
||||
params[b_elementwise_op_idx] = BElementOp;
|
||||
params[ds_layout_idx] = MakeLayoutTuple(DsTrans);
|
||||
params[ds_data_type_idx] = MakeTypeTuple(DsDataType);
|
||||
params[ds_elementwise_op_idx] = CDEElementOp;
|
||||
params[e_data_type_idx] = ToString(EDataType);
|
||||
auto block_size_str = params[block_size_idx];
|
||||
auto m_per_block_str = params[m_per_block_idx];
|
||||
auto n_per_block_str = params[n_per_block_idx];
|
||||
auto k_per_block_str = params[k_per_block_idx];
|
||||
params[e_data_type_idx] = ToString(EDataType);
|
||||
auto block_size_str = params[block_size_idx];
|
||||
auto m_per_block_str = params[m_per_block_idx];
|
||||
auto n_per_block_str = params[n_per_block_idx];
|
||||
auto k_per_block_str = params[k_per_block_idx];
|
||||
const std::size_t block_size = std::stoi(block_size_str);
|
||||
const std::size_t m_per_block = std::stoi(m_per_block_str);
|
||||
const std::size_t n_per_block = std::stoi(n_per_block_str);
|
||||
const std::size_t k_per_block = std::stoi(k_per_block_str);
|
||||
const std::size_t grid_size = GetGridSize(M, N, m_per_block, n_per_block);
|
||||
params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block);
|
||||
params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block);
|
||||
|
||||
std::string str = std::accumulate(params.begin() + 1, params.end(), std::string{},
|
||||
[](const std::string& a, const std::string& b) {
|
||||
return a.empty() ? b : a + ", " + b;
|
||||
});
|
||||
std::string str = std::accumulate(
|
||||
params.begin() + 1,
|
||||
params.end(),
|
||||
std::string{},
|
||||
[](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; });
|
||||
str = params.front() + "< " + str + ">";
|
||||
|
||||
|
||||
return Solution{str, block_size, grid_size};
|
||||
}
|
||||
|
||||
@@ -146,7 +149,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
|
||||
{
|
||||
std::vector<Solution> solutions;
|
||||
const std::size_t num_instances = GetInstances(arch).size();
|
||||
for (std::size_t i = 0; i < num_instances; ++i)
|
||||
for(std::size_t i = 0; i < num_instances; ++i)
|
||||
{
|
||||
solutions.push_back(MakeSolution(i, arch));
|
||||
}
|
||||
@@ -154,7 +157,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
|
||||
return solutions;
|
||||
}
|
||||
|
||||
|
||||
} // namespace device_gemm_multiple_d
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
|
||||
@@ -3,36 +3,48 @@
|
||||
|
||||
bool test_Problem()
|
||||
{
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
|
||||
const auto include_header = problem.GetIncludeHeader();
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
const auto grid_size = solution.grid_size;
|
||||
const auto block_size = solution.block_size;
|
||||
const auto include_header = problem.GetIncludeHeader();
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
const auto grid_size = solution.grid_size;
|
||||
const auto block_size = solution.block_size;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
pass &= include_header == "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp";
|
||||
pass &= include_header ==
|
||||
"ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp";
|
||||
pass &= solutions.size() == 42;
|
||||
pass &= template_str == "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, ck::half_t, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, 1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>";
|
||||
pass &= grid_size == 2;
|
||||
pass &= template_str ==
|
||||
"ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< "
|
||||
"ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, "
|
||||
"ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, "
|
||||
"ck::half_t, ck::tensor_operation::element_wise::Passthrough, "
|
||||
"ck::tensor_operation::element_wise::Passthrough, "
|
||||
"ck::tensor_operation::element_wise::Passthrough, "
|
||||
"ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, "
|
||||
"8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, "
|
||||
"8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, "
|
||||
"1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>";
|
||||
pass &= grid_size == 2;
|
||||
pass &= block_size == 256;
|
||||
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
@@ -40,46 +52,48 @@ bool test_GetGemmSpec()
|
||||
{
|
||||
bool pass = true;
|
||||
{
|
||||
//PadMNK
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{255,
|
||||
255,
|
||||
255,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
// PadMNK
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
255,
|
||||
255,
|
||||
255,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
|
||||
pass &= template_str.find("GemmSpecialization::MNKPadding") != std::string::npos;
|
||||
}
|
||||
{
|
||||
//Default
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
// Default
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
|
||||
pass &= template_str.find("GemmSpecialization::Default") != std::string::npos;
|
||||
}
|
||||
@@ -91,147 +105,155 @@ bool test_GetInstances()
|
||||
{
|
||||
bool pass = true;
|
||||
{
|
||||
//Col Col Fp16
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
// Col Col Fp16
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
pass &= problem.GetSolutions("gfx90a").size() == 51;
|
||||
}
|
||||
{
|
||||
//Col Row Fp16
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
// Col Row Fp16
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
pass &= problem.GetSolutions("gfx90a").size() == 51;
|
||||
}
|
||||
{
|
||||
//Row Col Fp16
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
// Row Col Fp16
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
pass &= problem.GetSolutions("gfx90a").size() == 42;
|
||||
}
|
||||
{
|
||||
//Row Row Int8
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
// Row Row Int8
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
pass &= problem.GetSolutions("gfx90a").size() == 48;
|
||||
}
|
||||
{
|
||||
//Col Col Int8
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
// Col Col Int8
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
pass &= problem.GetSolutions("gfx90a").size() == 48;
|
||||
}
|
||||
{
|
||||
//Col Row Int8
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
// Col Row Int8
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
pass &= problem.GetSolutions("gfx90a").size() == 48;
|
||||
}
|
||||
{
|
||||
//Row Col Int8
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
// Row Col Int8
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
pass &= problem.GetSolutions("gfx90a").size() == 39;
|
||||
}
|
||||
{
|
||||
//Row Row Int8
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
// Row Row Int8
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Int8,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
pass &= problem.GetSolutions("gfx90a").size() == 48;
|
||||
}
|
||||
|
||||
@@ -243,45 +265,50 @@ bool test_MakeLayoutsTuple()
|
||||
bool pass = true;
|
||||
{
|
||||
// Empty Tuple
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{ck::host::DataType::Half},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{ck::host::DataType::Half},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
pass &= template_str.find("ck::Tuple<>") != std::string::npos;
|
||||
}
|
||||
{
|
||||
// RowColRow Tuple
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{false, true, false},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{ck::host::DataType::Half},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
pass &= template_str.find("ck::Tuple<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>") != std::string::npos;
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{false, true, false},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{ck::host::DataType::Half},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
pass &= template_str.find(
|
||||
"ck::Tuple<ck::tensor_layout::gemm::RowMajor, "
|
||||
"ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>") !=
|
||||
std::string::npos;
|
||||
}
|
||||
|
||||
return pass;
|
||||
@@ -292,44 +319,46 @@ bool test_MakeTypeTuple()
|
||||
bool pass = true;
|
||||
{
|
||||
// Empty Tuple
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{true},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{true},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
pass &= template_str.find("ck::Tuple<>") != std::string::npos;
|
||||
}
|
||||
{
|
||||
// Half Int8 Tuple
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{ck::host::DataType::Half, ck::host::DataType::Int8},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
auto problem = ck::host::device_gemm_multiple_d::Problem{
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{},
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
ck::host::DataType::Half,
|
||||
{ck::host::DataType::Half, ck::host::DataType::Int8},
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough",
|
||||
"ck::tensor_operation::element_wise::Passthrough"};
|
||||
const auto solutions = problem.GetSolutions("gfx90a");
|
||||
const auto& solution = solutions.at(0);
|
||||
const auto template_str = solution.template_str;
|
||||
pass &= template_str.find("ck::Tuple<ck::half_t, int8_t>") != std::string::npos;
|
||||
}
|
||||
return pass;
|
||||
|
||||
Reference in New Issue
Block a user