[CK_TILE] working version

This commit is contained in:
Aleksander Dudek
2025-10-22 09:45:30 -05:00
parent c1e3b96609
commit 7785fdb3a6
4 changed files with 63 additions and 13 deletions

View File

@@ -311,9 +311,9 @@ struct CShuffleEpilogue
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
using SFC = space_filling_curve<sequence<kNPerBlock, kMPerBlock>,
sequence<1, 0>,
sequence<NPerIterationShuffle, MPerIterationShuffle>>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
@@ -667,7 +667,7 @@ struct CShuffleEpilogue
const ScaleN& scale_n = {})
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
//print(LdsTileDistr);
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
@@ -736,8 +736,8 @@ struct CShuffleEpilogue
static_assert(GetVectorSizeC() > 1, "VectorSizeC is not greater than 1!");
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
MPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;

View File

@@ -21,7 +21,7 @@ struct TileGemmTraits
static constexpr bool kPadK = kPadK_;
// TODO this can't be hardcoded here! Should be in policy!
static constexpr int _VectorSize = 16;
static constexpr int _VectorSize = 2;
using AsLayout = AsLayout_;
using BsLayout = BsLayout_;
@@ -49,7 +49,7 @@ struct TileGemmUniversalTraits
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
static constexpr int _VectorSize = 16;
static constexpr int _VectorSize = 2;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
using AsLayout = AsLayout_;

View File

@@ -29,6 +29,7 @@ using NonPersistent = std::false_type;
using I16 = ck_tile::number<16>;
using I32 = ck_tile::number<32>;
using I64 = ck_tile::number<64>;
using I128 = ck_tile::number<128>;
using I256 = ck_tile::number<256>;
// clang-format off
@@ -78,10 +79,21 @@ using KernelTypesMemWmma = ::testing::Types<
>;
using KernelTypesCompV3 = ::testing::Types<
<<<<<<< HEAD
std::tuple< Row, Row, Col, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>
//std::tuple< Row, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
//std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
//std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
=======
std::tuple< Row, Row, Col, F16, F16, F32, F16, I128, I128, I64, I32, I32, I16, Intrawave, CompV3>//,
//std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
//std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
//std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
//std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
//std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
//std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
>>>>>>> 0a9ceadca ([CK_TILE] working version)
//std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
//std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
//std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,

View File

@@ -334,10 +334,45 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::index_t stride_C =
ck_tile::get_default_stride(M, N, StrideC, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{})));
auto f_host_tensor_descriptor = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({col, row}, {stride, 1_uz});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
{
// give a chance if stride is zero, return a default packed stride
if constexpr(std::is_same_v<decltype(layout),
ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
@@ -348,9 +383,12 @@ class TestCkTileGemmPipeline : public ::testing::Test
std::cout << "c_m_n_dev_result: ";
c_m_n_dev_result.print_first_n(std::cout) << '\n';
ck_tile::FillUniformDistributionIntegerValue<ADataType>{1, 2, 11939}(a_m_k);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{1, 2, 11940}(b_k_n);
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(a_m_k);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(b_k_n);
//ck_tile::FillConstant<ADataType>{1}(a_m_k);
//ck_tile::FillConstant<BDataType>{2}(b_k_n);
// FillConstant
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());