mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
[CK_TILE] TEST vector stores c col layout part1
This commit is contained in:
@@ -421,10 +421,11 @@ struct CShuffleEpilogue
|
||||
template <index_t iAccess, typename OAccTile, typename LdsTile>
|
||||
CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
|
||||
{
|
||||
constexpr auto idx_y_start = SFC::get_index(number<iAccess>{});
|
||||
constexpr auto idx_start = SFC::get_index(number<iAccess>{});
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
// SFC with (N,M) dims and (1,0) access order returns indices in (M, N) iteration order
|
||||
constexpr auto mIter = number<idx_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
@@ -736,8 +737,8 @@ struct CShuffleEpilogue
|
||||
static_assert(GetVectorSizeC() > 1, "VectorSizeC is not greater than 1!");
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<kBlockSize,
|
||||
NPerIterationShuffle,
|
||||
MPerIterationShuffle,
|
||||
YPerIterationShuffle,
|
||||
XPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
|
||||
@@ -80,6 +80,7 @@ using KernelTypesMemWmma = ::testing::Types<
|
||||
|
||||
using KernelTypesCompV3 = ::testing::Types<
|
||||
<<<<<<< HEAD
|
||||
<<<<<<< HEAD
|
||||
<<<<<<< HEAD
|
||||
std::tuple< Row, Row, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
|
||||
//std::tuple< Row, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
@@ -134,9 +135,12 @@ using KernelTypesCompV3 = ::testing::Types<
|
||||
=======
|
||||
std::tuple< Row, Row, Col, F16, F16, F32, F16, I128, I128, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
=======
|
||||
std::tuple< Row, Row, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
>>>>>>> 05ce4e524 ([CK_TILE] TEST vector stores c col layout part1)
|
||||
std::tuple< Row, Col, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
<<<<<<< HEAD
|
||||
std::tuple< Col, Row, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
@@ -145,6 +149,19 @@ using KernelTypesCompV3 = ::testing::Types<
|
||||
std::tuple< Col, Row, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
|
||||
>>>>>>> 12c48382b ([CK_TILE] working version and tests)
|
||||
=======
|
||||
std::tuple< Col, Col, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
|
||||
//
|
||||
//std::tuple< Row, Row, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Row, Col, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Col, Row, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Col, Col, Col, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//
|
||||
//std::tuple< Row, Row, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Row, Col, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Col, Row, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Col, Col, Col, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
|
||||
>>>>>>> 05ce4e524 ([CK_TILE] TEST vector stores c col layout part1)
|
||||
>;
|
||||
|
||||
using KernelTypesCompV3Wmma = ::testing::Types<
|
||||
|
||||
@@ -343,20 +343,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_host_tensor_descriptor_out = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({col, row}, {stride, 1_uz});
|
||||
}
|
||||
@@ -388,14 +374,18 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
<<<<<<< HEAD
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
=======
|
||||
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
||||
>>>>>>> 05ce4e524 ([CK_TILE] TEST vector stores c col layout part1)
|
||||
|
||||
std::cout << "a_m_k: ";
|
||||
a_m_k.print_first_n(std::cout) << '\n';
|
||||
std::cout << "b_k_n: ";
|
||||
b_k_n.print_first_n(std::cout) << '\n';
|
||||
std::cout << "c_m_n_dev_result: ";
|
||||
c_m_n_dev_result.print_first_n(std::cout) << '\n';
|
||||
//std::cout << "a_m_k: ";
|
||||
//a_m_k.print_first_n(std::cout) << '\n';
|
||||
//std::cout << "b_k_n: ";
|
||||
//b_k_n.print_first_n(std::cout) << '\n';
|
||||
//std::cout << "c_m_n_dev_result: ";
|
||||
//c_m_n_dev_result.print_first_n(std::cout) << '\n';
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(a_m_k);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(b_k_n);
|
||||
@@ -443,12 +433,12 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
|
||||
std::cout << "a_m_k: ";
|
||||
a_m_k.print_first_n(std::cout) << '\n';
|
||||
std::cout << "b_k_n: ";
|
||||
b_k_n.print_first_n(std::cout) << '\n';
|
||||
std::cout << "c_m_n_dev_result: ";
|
||||
c_m_n_dev_result.print_first_n(std::cout) << '\n';
|
||||
//std::cout << "a_m_k: ";
|
||||
//a_m_k.print_first_n(std::cout) << '\n';
|
||||
//std::cout << "b_k_n: ";
|
||||
//b_k_n.print_first_n(std::cout) << '\n';
|
||||
//std::cout << "c_m_n_dev_result: ";
|
||||
//c_m_n_dev_result.print_first_n(std::cout) << '\n';
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
|
||||
Reference in New Issue
Block a user