mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
[CK_TILE] working version and tests
This commit is contained in:
@@ -79,6 +79,7 @@ using KernelTypesMemWmma = ::testing::Types<
|
||||
>;
|
||||
|
||||
using KernelTypesCompV3 = ::testing::Types<
|
||||
<<<<<<< HEAD
|
||||
<<<<<<< HEAD
|
||||
std::tuple< Row, Row, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
|
||||
//std::tuple< Row, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
@@ -130,6 +131,20 @@ using KernelTypesCompV3 = ::testing::Types<
|
||||
//std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Col, Col, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
|
||||
=======
|
||||
std::tuple< Row, Row, Col, F16, F16, F32, F16, I128, I128, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
|
||||
>>>>>>> 12c48382b ([CK_TILE] working version and tests)
|
||||
>;
|
||||
|
||||
using KernelTypesCompV3Wmma = ::testing::Types<
|
||||
|
||||
@@ -343,6 +343,20 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_host_tensor_descriptor_out = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({col, row}, {stride, 1_uz});
|
||||
}
|
||||
@@ -386,9 +400,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(a_m_k);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(b_k_n);
|
||||
|
||||
//ck_tile::FillConstant<ADataType>{1}(a_m_k);
|
||||
//ck_tile::FillConstant<BDataType>{2}(b_k_n);
|
||||
// FillConstant
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
Reference in New Issue
Block a user