mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Merge commit 'f5c2f09036cdc22dc8944719215dd47003c50a24' into develop
This commit is contained in:
@@ -42,7 +42,8 @@ struct StreamKTilePartitionerBase
|
||||
CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the total space needed for the flags buffer.
|
||||
* @brief Calculates the total space needed for the flags buffer whose total byte size is
|
||||
* 128B-aligned.
|
||||
*
|
||||
* @return index_t The number of bytes needed for the flags buffer.
|
||||
*/
|
||||
|
||||
@@ -58,7 +58,10 @@ CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags_buffer_size()
|
||||
const noexcept
|
||||
{
|
||||
return sizeof(index_t) * sk_ctas_;
|
||||
constexpr index_t alignment = 128;
|
||||
const index_t required_bytes = sizeof(index_t) * sk_ctas_;
|
||||
const index_t padded_bytes = ck_tile::integer_least_multiple(required_bytes, alignment);
|
||||
return padded_bytes;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
|
||||
@@ -23,10 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
|
||||
#TODO: support all arches
|
||||
#TODO: current c-shuffle only supports C layout as R
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
# TODO: Renable once transient bug for reduction is resolved.
|
||||
# add_gtest_executable(test_ck_tile_streamk_reduction
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
|
||||
# test_gemm_streamk_util.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_reduction
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
|
||||
test_gemm_streamk_util.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp
|
||||
|
||||
@@ -262,20 +262,40 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
// Calculate reference GEMM on the GPU
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_ref(
|
||||
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
||||
c_m_n_host_ref.SetZero();
|
||||
ck_tile::DeviceMem ref_c_m_n_dev_buf(c_m_n_dev_ref.get_element_space_size_in_bytes());
|
||||
ref_c_m_n_dev_buf.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
ADataType* a_m_k_dev_ref_ptr = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* b_k_n_dev_ref_ptr = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* c_m_n_dev_ref_ptr = static_cast<CDataType*>(ref_c_m_n_dev_buf.GetDeviceBuffer());
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(a_m_k_dev_ref_ptr,
|
||||
b_k_n_dev_ref_ptr,
|
||||
c_m_n_dev_ref_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
ref_c_m_n_dev_buf.FromDevice(c_m_n_dev_ref.data());
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
*std::max_element(c_m_n_dev_ref.mData.begin(), c_m_n_dev_ref.mData.end());
|
||||
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, num_accumulations_per_tile, max_accumulated_value);
|
||||
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
c_m_n_dev_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
@@ -51,6 +51,39 @@ TEST(StreamKTilePartitionerBaseConstructor, EdgeCase)
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsLessThan128Bytes)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Reduction>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsEqual128Bytes)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Reduction>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsGreaterThan128Bytes)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Reduction>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 256);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
@@ -71,7 +104,9 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy)
|
||||
|
||||
ck_tile::index_t expected_partials_size =
|
||||
sizeof(float) * Config::M_TILE * Config::N_TILE * Config::GRID;
|
||||
ck_tile::index_t expected_flags_size = sizeof(ck_tile::index_t) * Config::GRID;
|
||||
// Since GRID is 3, the final padded flags array must be 128B to ensure the total byte size of
|
||||
// the flags array is 128B-aligned.
|
||||
ck_tile::index_t expected_flags_size = 128;
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)),
|
||||
expected_partials_size + expected_flags_size);
|
||||
|
||||
@@ -198,9 +198,11 @@ struct StreamKTilePartitionerBaseConfig
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 28;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t M = 28;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
// The minimum number of bytes needed for the flags array is GRID * 4B = 3 * 4B = 12B. To ensure
|
||||
// the total byte size of the array is 128B-aligned, the flags array must be 128B.
|
||||
static constexpr ck_tile::index_t GRID = 3;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
@@ -212,6 +214,45 @@ struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitioner
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes
|
||||
: public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 28;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 32;
|
||||
// The minimum number of bytes needed for the flags array is GRID * 4B = 32 * 4B = 128B. So, the
|
||||
// number of bytes for the flags array should be 128B.
|
||||
static constexpr ck_tile::index_t GRID = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
static constexpr ck_tile::index_t K_TILE = 1;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes
|
||||
: public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 28;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 33;
|
||||
// The minimum number of bytes needed for the flags array is GRID * 4B = 33 * 4B = 132B. So, the
|
||||
// number of bytes for the flags array should be 2 * 128B = 256B to ensure the total byte size
|
||||
// of the array is 128B-aligned.
|
||||
static constexpr ck_tile::index_t GRID = 33;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
static constexpr ck_tile::index_t K_TILE = 1;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile
|
||||
: public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user