Merge commit 'f5c2f09036cdc22dc8944719215dd47003c50a24' into develop

This commit is contained in:
assistant-librarian[bot]
2026-01-24 00:38:47 +00:00
parent e2e058bcbc
commit 6a21c125a0
6 changed files with 115 additions and 16 deletions

View File

@@ -42,7 +42,8 @@ struct StreamKTilePartitionerBase
CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept;
/**
* @brief Calculates the total space needed for the flags buffer.
* @brief Calculates the total space needed for the flags buffer whose total byte size is
* 128B-aligned.
*
* @return index_t The number of bytes needed for the flags buffer.
*/

View File

@@ -58,7 +58,10 @@ CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags_buffer_size()
const noexcept
{
return sizeof(index_t) * sk_ctas_;
constexpr index_t alignment = 128;
const index_t required_bytes = sizeof(index_t) * sk_ctas_;
const index_t padded_bytes = ck_tile::integer_least_multiple(required_bytes, alignment);
return padded_bytes;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>

View File

@@ -23,10 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
#TODO: support all arches
#TODO: current c-shuffle only supports C layout as R
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
# TODO: Renable once transient bug for reduction is resolved.
# add_gtest_executable(test_ck_tile_streamk_reduction
# ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
# test_gemm_streamk_util.cpp)
add_gtest_executable(test_ck_tile_streamk_reduction
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
test_gemm_streamk_util.cpp)
add_gtest_executable(test_ck_tile_streamk_smoke
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp

View File

@@ -262,20 +262,40 @@ class TestCkTileStreamK : public ::testing::Test
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
// Calculate reference GEMM on the GPU
ck_tile::HostTensor<CDataType> c_m_n_dev_ref(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
c_m_n_host_ref.SetZero();
ck_tile::DeviceMem ref_c_m_n_dev_buf(c_m_n_dev_ref.get_element_space_size_in_bytes());
ref_c_m_n_dev_buf.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref);
ADataType* a_m_k_dev_ref_ptr = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* b_k_n_dev_ref_ptr = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* c_m_n_dev_ref_ptr = static_cast<CDataType*>(ref_c_m_n_dev_buf.GetDeviceBuffer());
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(a_m_k_dev_ref_ptr,
b_k_n_dev_ref_ptr,
c_m_n_dev_ref_ptr,
M,
N,
K,
stride_A,
stride_B,
stride_C);
ref_c_m_n_dev_buf.FromDevice(c_m_n_dev_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
*std::max_element(c_m_n_dev_ref.mData.begin(), c_m_n_dev_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, num_accumulations_per_tile, max_accumulated_value);
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
c_m_n_dev_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));

View File

@@ -51,6 +51,39 @@ TEST(StreamKTilePartitionerBaseConstructor, EdgeCase)
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
}
TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsLessThan128Bytes)
{
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
ck_tile::StreamKReductionStrategy::Reduction>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128);
}
TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsEqual128Bytes)
{
using Config = StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes;
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
ck_tile::StreamKReductionStrategy::Reduction>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128);
}
TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsGreaterThan128Bytes)
{
using Config = StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes;
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
ck_tile::StreamKReductionStrategy::Reduction>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 256);
}
TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy)
{
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
@@ -71,7 +104,9 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy)
ck_tile::index_t expected_partials_size =
sizeof(float) * Config::M_TILE * Config::N_TILE * Config::GRID;
ck_tile::index_t expected_flags_size = sizeof(ck_tile::index_t) * Config::GRID;
// Since GRID is 3, the final padded flags array must be 128B to ensure the total byte size of
// the flags array is 128B-aligned.
ck_tile::index_t expected_flags_size = 128;
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)),
expected_partials_size + expected_flags_size);

View File

@@ -198,9 +198,11 @@ struct StreamKTilePartitionerBaseConfig
struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitionerBaseConfig
{
static constexpr ck_tile::index_t M = 28;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t M = 28;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
// The minimum number of bytes needed for the flags array is GRID * 4B = 3 * 4B = 12B. To ensure
// the total byte size of the array is 128B-aligned, the flags array must be 128B.
static constexpr ck_tile::index_t GRID = 3;
static constexpr ck_tile::index_t M_TILE = 4;
@@ -212,6 +214,45 @@ struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitioner
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};
struct StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes
: public StreamKTilePartitionerBaseConfig
{
static constexpr ck_tile::index_t M = 28;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 32;
// The minimum number of bytes needed for the flags array is GRID * 4B = 32 * 4B = 128B. So, the
// number of bytes for the flags array should be 128B.
static constexpr ck_tile::index_t GRID = 32;
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 4;
static constexpr ck_tile::index_t K_TILE = 1;
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};
struct StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes
: public StreamKTilePartitionerBaseConfig
{
static constexpr ck_tile::index_t M = 28;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 33;
// The minimum number of bytes needed for the flags array is GRID * 4B = 33 * 4B = 132B. So, the
// number of bytes for the flags array should be 2 * 128B = 256B to ensure the total byte size
// of the array is 128B-aligned.
static constexpr ck_tile::index_t GRID = 33;
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 4;
static constexpr ck_tile::index_t K_TILE = 1;
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};
struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile
: public StreamKTilePartitionerBaseConfig
{