mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
[CK_TILE] working version
This commit is contained in:
@@ -311,9 +311,9 @@ struct CShuffleEpilogue
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
using SFC = space_filling_curve<sequence<kNPerBlock, kMPerBlock>,
|
||||
sequence<1, 0>,
|
||||
sequence<NPerIterationShuffle, MPerIterationShuffle>>;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
|
||||
@@ -667,7 +667,7 @@ struct CShuffleEpilogue
|
||||
const ScaleN& scale_n = {})
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
|
||||
//print(LdsTileDistr);
|
||||
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
@@ -736,8 +736,8 @@ struct CShuffleEpilogue
|
||||
static_assert(GetVectorSizeC() > 1, "VectorSizeC is not greater than 1!");
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
MPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
|
||||
@@ -21,7 +21,7 @@ struct TileGemmTraits
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
|
||||
// TODO this can't be hardcoded here! Should be in policy!
|
||||
static constexpr int _VectorSize = 16;
|
||||
static constexpr int _VectorSize = 2;
|
||||
|
||||
using AsLayout = AsLayout_;
|
||||
using BsLayout = BsLayout_;
|
||||
@@ -49,7 +49,7 @@ struct TileGemmUniversalTraits
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
static constexpr int _VectorSize = 16;
|
||||
static constexpr int _VectorSize = 2;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using AsLayout = AsLayout_;
|
||||
|
||||
@@ -29,6 +29,7 @@ using NonPersistent = std::false_type;
|
||||
using I16 = ck_tile::number<16>;
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I64 = ck_tile::number<64>;
|
||||
using I128 = ck_tile::number<128>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
|
||||
// clang-format off
|
||||
@@ -78,10 +79,21 @@ using KernelTypesMemWmma = ::testing::Types<
|
||||
>;
|
||||
|
||||
using KernelTypesCompV3 = ::testing::Types<
|
||||
<<<<<<< HEAD
|
||||
std::tuple< Row, Row, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
|
||||
//std::tuple< Row, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
=======
|
||||
std::tuple< Row, Row, Col, F16, F16, F32, F16, I128, I128, I64, I32, I32, I16, Intrawave, CompV3>//,
|
||||
//std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
>>>>>>> 0a9ceadca ([CK_TILE] working version)
|
||||
//std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
//std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
|
||||
|
||||
@@ -334,10 +334,45 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::index_t stride_C =
|
||||
ck_tile::get_default_stride(M, N, StrideC, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{})));
|
||||
auto f_host_tensor_descriptor = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({col, row}, {stride, 1_uz});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is zero, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout),
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
@@ -348,9 +383,12 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
std::cout << "c_m_n_dev_result: ";
|
||||
c_m_n_dev_result.print_first_n(std::cout) << '\n';
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{1, 2, 11939}(a_m_k);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{1, 2, 11940}(b_k_n);
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(a_m_k);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(b_k_n);
|
||||
|
||||
//ck_tile::FillConstant<ADataType>{1}(a_m_k);
|
||||
//ck_tile::FillConstant<BDataType>{2}(b_k_n);
|
||||
// FillConstant
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
Reference in New Issue
Block a user