mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Stream-K Tile Partitioner Base Class with Tests
To better align with the original Stream-K paper, this change implements a new Stream-K tile partitioner base class. This class will handle the Stream-K setup that is common to both a persistent and non-persistent DP section. A later change will implement derived classes to handle the differences between persistent and non-persistent DP. This change also includes unit tests for the base tile partitioner.
This commit is contained in:
committed by
Emily Martins
parent
2d1c9e28e2
commit
f87f768d16
@@ -4,7 +4,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
#TODO: support all arches
|
||||
#TODO: current stream-k c-shuffle only supports C layout as R
|
||||
#TODO: current c-shuffle only supports C layout as R
|
||||
add_gtest_executable(test_ck_tile_streamk_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
@@ -116,6 +116,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# )
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
else()
|
||||
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
|
||||
endif()
|
||||
|
||||
349
test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp
Normal file
349
test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp
Normal file
@@ -0,0 +1,349 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_streamk_tile_partitioner_common.hpp"
|
||||
|
||||
TEST(StreamKTilePartitionerBaseConstructor, SKOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerBaseExpected expected_values{
|
||||
2, 0, 3, 4, 1, 2, 1, 0, 2, Config::GRID, Config::N};
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseConstructor, DPOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDPOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerBaseExpected expected_values{
|
||||
0, 6, 0, 0, 0, 2, 0, 12, 6, Config::GRID, Config::N};
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseConstructor, DP2TileSK)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerBaseExpected expected_values{
|
||||
4, 3, 3, 8, 2, 2, 2, 6, 7, Config::GRID, Config::N};
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseConstructor, EdgeCase)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigEdgeCase;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerBaseExpected expected_values{
|
||||
0, 1, 0, 0, 0, 2, 0, 2, 1, Config::GRID, Config::N};
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)), 0);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Reduction>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
ck_tile::index_t expected_partials_size =
|
||||
sizeof(float) * Config::M_TILE * Config::N_TILE * Config::GRID;
|
||||
ck_tile::index_t expected_flags_size = sizeof(ck_tile::index_t) * Config::GRID;
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)),
|
||||
expected_partials_size + expected_flags_size);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetLocalIter, GetLocalIter)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel = KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::DeviceMem local_iter_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t iter = 3;
|
||||
ck_tile::index_t tile_iter = 2;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(
|
||||
iter, tile_iter, Config::UNUSED, local_iter_dev.GetDeviceBuffer(), nullptr, Config::UNUSED);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate result
|
||||
ck_tile::index_t local_iter;
|
||||
local_iter_dev.FromDevice(&local_iter);
|
||||
EXPECT_EQ(local_iter, iter - tile_iter);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetLocalIterEnd, MinIsTileIterEnd)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel = KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER_END>;
|
||||
// Test parameters
|
||||
ck_tile::DeviceMem local_iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t tile_iter = 6;
|
||||
ck_tile::index_t iter_end = 9;
|
||||
ck_tile::index_t tile_iter_end = 8;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(tile_iter,
|
||||
iter_end,
|
||||
tile_iter_end,
|
||||
local_iter_end_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
Config::UNUSED);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t local_iter_end;
|
||||
local_iter_end_dev.FromDevice(&local_iter_end);
|
||||
EXPECT_EQ(local_iter_end, tile_iter_end - tile_iter);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetLocalIterEnd, MinIsIterEnd)
|
||||
{
|
||||
// Types
|
||||
// Note: For this test, the Config is used for types only, the function get_locatl_iter_end is
|
||||
// static; thus, the test parameters are independent of the Config in this case.
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel = KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER_END>;
|
||||
// Test parameters
|
||||
ck_tile::DeviceMem local_iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t tile_iter = 12;
|
||||
ck_tile::index_t iter_end = 13;
|
||||
ck_tile::index_t tile_iter_end = 14;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(tile_iter,
|
||||
iter_end,
|
||||
tile_iter_end,
|
||||
local_iter_end_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
Config::UNUSED);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t local_iter_end;
|
||||
local_iter_end_dev.FromDevice(&local_iter_end);
|
||||
EXPECT_EQ(local_iter_end, iter_end - tile_iter);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetTileBoundaries, GetTileBoundaries)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_TILE_BOUNDARIES>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem tile_iter_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem tile_iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t tile_idx = 1;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER,
|
||||
Config::PLACEHOLDER,
|
||||
tile_idx,
|
||||
tile_iter_dev.GetDeviceBuffer(),
|
||||
tile_iter_end_dev.GetDeviceBuffer(),
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t tile_iter, tile_iter_end;
|
||||
tile_iter_dev.FromDevice(&tile_iter);
|
||||
tile_iter_end_dev.FromDevice(&tile_iter_end);
|
||||
// There are 2 iters per tile. Thus, for tile_idx 1, we expect 2 and 4 to be the start and end,
|
||||
// respectively.
|
||||
EXPECT_EQ(tile_iter, 2);
|
||||
EXPECT_EQ(tile_iter_end, 4);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetTileIndex, GetTileIndex)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel = KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_TILE_INDEX>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem tile_idx_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t iter = 8;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(iter,
|
||||
Config::UNUSED,
|
||||
Config::UNUSED,
|
||||
tile_idx_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t tile_idx;
|
||||
tile_idx_dev.FromDevice(&tile_idx);
|
||||
// Since there are 2 iters per tile, iter 8 maps to tile_idx 4.
|
||||
EXPECT_EQ(tile_idx, 4);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetIterBoundaries, ZeroExtraItersBeforeMe)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_ITER_BOUNDARIES>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem iter_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t cta_idx = 0;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER,
|
||||
Config::PLACEHOLDER,
|
||||
cta_idx,
|
||||
iter_dev.GetDeviceBuffer(),
|
||||
iter_end_dev.GetDeviceBuffer(),
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t iter, iter_end;
|
||||
iter_dev.FromDevice(&iter);
|
||||
iter_end_dev.FromDevice(&iter_end);
|
||||
EXPECT_EQ(iter, 6);
|
||||
EXPECT_EQ(iter_end, 9);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetIterBoundaries, NonZeroExtraItersBeforeMe)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_ITER_BOUNDARIES>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem iter_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t cta_idx = 1;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER,
|
||||
Config::PLACEHOLDER,
|
||||
cta_idx,
|
||||
iter_dev.GetDeviceBuffer(),
|
||||
iter_end_dev.GetDeviceBuffer(),
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t iter, iter_end;
|
||||
iter_dev.FromDevice(&iter);
|
||||
iter_end_dev.FromDevice(&iter_end);
|
||||
EXPECT_EQ(iter, 9);
|
||||
EXPECT_EQ(iter_end, 12);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetIterBoundaries, MinIsExtraIters)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_ITER_BOUNDARIES>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem iter_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t cta_idx = 2;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER,
|
||||
Config::PLACEHOLDER,
|
||||
cta_idx,
|
||||
iter_dev.GetDeviceBuffer(),
|
||||
iter_end_dev.GetDeviceBuffer(),
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t iter, iter_end;
|
||||
iter_dev.FromDevice(&iter);
|
||||
iter_end_dev.FromDevice(&iter_end);
|
||||
EXPECT_EQ(iter, 12);
|
||||
EXPECT_EQ(iter_end, 14);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigLargerCTensor;
|
||||
ck_tile::index_t m_macro_tiles = Config::M / Config::M_TILE;
|
||||
ck_tile::index_t n_macro_tiles = Config::N / Config::N_TILE;
|
||||
ck_tile::index_t tile_idx = 0;
|
||||
|
||||
for(ck_tile::index_t row = 0; row < m_macro_tiles; ++row)
|
||||
{
|
||||
for(ck_tile::index_t col = 0; col < n_macro_tiles; ++col)
|
||||
{
|
||||
test_get_output_tile_index(tile_idx, ck_tile::make_tuple(row, col));
|
||||
++tile_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
enum StreamKTilePartitionerBaseMethodId
|
||||
{
|
||||
GET_LOCAL_ITER,
|
||||
GET_LOCAL_ITER_END,
|
||||
GET_TILE_BOUNDARIES,
|
||||
GET_TILE_INDEX,
|
||||
GET_ITER_BOUNDARIES,
|
||||
GET_OUTPUT_TILE_INDEX
|
||||
};
|
||||
|
||||
// Base kernel wrapper class to facilitate testing class device functions.
|
||||
template <typename T = ck_tile::index_t>
|
||||
struct KernelWrapper
|
||||
{
|
||||
static constexpr ck_tile::index_t kBlockSize = 1;
|
||||
|
||||
struct KernelArgs
|
||||
{
|
||||
ck_tile::index_t arg1;
|
||||
ck_tile::index_t arg2;
|
||||
ck_tile::index_t arg3;
|
||||
void* result1;
|
||||
void* result2;
|
||||
T tile_partitioner;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static KernelArgs MakeKernelArgs(ck_tile::index_t arg1,
|
||||
ck_tile::index_t arg2,
|
||||
ck_tile::index_t arg3,
|
||||
void* result1,
|
||||
void* result2,
|
||||
T tile_partitioner)
|
||||
{
|
||||
return KernelArgs{arg1, arg2, arg3, result1, result2, tile_partitioner};
|
||||
}
|
||||
};
|
||||
|
||||
// Specialized derived class to support unique operator() functions. There is one template
|
||||
// specialization per member in the StreamKTilePartitionerBaseMethodId enum.
|
||||
template <typename TilePartitioner, StreamKTilePartitionerBaseMethodId Id>
|
||||
struct KernelWrapperSpecialized : public KernelWrapper<>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner, StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER>
|
||||
: public KernelWrapper<>
|
||||
{
|
||||
using Base = KernelWrapper<>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(Base::KernelArgs kargs)
|
||||
{
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) =
|
||||
TilePartitioner::get_local_iter(kargs.arg1, kargs.arg2);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_TILE_BOUNDARIES>
|
||||
: public KernelWrapper<TilePartitioner>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<TilePartitioner>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
|
||||
{
|
||||
kargs.tile_partitioner.get_tile_boundaries(kargs.arg1, kargs.arg2, kargs.arg3);
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) = kargs.arg1;
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result2)) = kargs.arg2;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_ITER_BOUNDARIES>
|
||||
: public KernelWrapper<TilePartitioner>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<TilePartitioner>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
|
||||
{
|
||||
kargs.tile_partitioner.get_iter_boundaries(kargs.arg1, kargs.arg2, kargs.arg3);
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) = kargs.arg1;
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result2)) = kargs.arg2;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER_END>
|
||||
: public KernelWrapper<>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<>;
|
||||
CK_TILE_DEVICE void operator()(Base::KernelArgs kargs)
|
||||
{
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) =
|
||||
TilePartitioner::get_local_iter_end(kargs.arg1, kargs.arg2, kargs.arg3);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner, StreamKTilePartitionerBaseMethodId::GET_TILE_INDEX>
|
||||
: public KernelWrapper<TilePartitioner>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<TilePartitioner>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
|
||||
{
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) =
|
||||
kargs.tile_partitioner.get_tile_index(kargs.arg1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_OUTPUT_TILE_INDEX>
|
||||
: public KernelWrapper<TilePartitioner>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<TilePartitioner>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
|
||||
{
|
||||
auto [im, in] = kargs.tile_partitioner.get_output_tile_index(kargs.arg1);
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) = im;
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result2)) = in;
|
||||
}
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseExpected
|
||||
{
|
||||
ck_tile::index_t sk_tiles_;
|
||||
ck_tile::index_t dp_tiles_;
|
||||
ck_tile::index_t sk_ctas_;
|
||||
ck_tile::index_t total_sk_iters_;
|
||||
ck_tile::index_t iters_per_sk_cta_;
|
||||
ck_tile::index_t iters_per_tile_;
|
||||
ck_tile::index_t extra_iters_;
|
||||
ck_tile::index_t total_dp_iters_;
|
||||
ck_tile::index_t num_tiles_;
|
||||
ck_tile::index_t grid_;
|
||||
ck_tile::index_t n_;
|
||||
};
|
||||
|
||||
template <typename GemmShape>
|
||||
void validate_streamk_base_constructor(
|
||||
StreamKTilePartitionerBaseExpected& expected_values,
|
||||
ck_tile::StreamKTilePartitionerBase<GemmShape>& tile_partitioner)
|
||||
{
|
||||
EXPECT_EQ(tile_partitioner.get_sk_tiles(), expected_values.sk_tiles_);
|
||||
EXPECT_EQ(tile_partitioner.get_dp_tiles(), expected_values.dp_tiles_);
|
||||
EXPECT_EQ(tile_partitioner.get_sk_ctas(), expected_values.sk_ctas_);
|
||||
EXPECT_EQ(tile_partitioner.get_total_sk_iters(), expected_values.total_sk_iters_);
|
||||
EXPECT_EQ(tile_partitioner.get_iters_per_sk_cta(), expected_values.iters_per_sk_cta_);
|
||||
EXPECT_EQ(tile_partitioner.get_extra_iters(), expected_values.extra_iters_);
|
||||
EXPECT_EQ(tile_partitioner.get_iters_per_tile(), expected_values.iters_per_tile_);
|
||||
EXPECT_EQ(tile_partitioner.get_total_dp_iters(), expected_values.total_dp_iters_);
|
||||
EXPECT_EQ(tile_partitioner.get_num_tiles(), expected_values.num_tiles_);
|
||||
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
|
||||
EXPECT_EQ(tile_partitioner.get_n(), expected_values.n_);
|
||||
}
|
||||
|
||||
struct StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t PLACEHOLDER = -1;
|
||||
static constexpr ck_tile::index_t UNUSED = -1;
|
||||
};
|
||||
|
||||
// Note: for the configs below, we only use BlockTiles in the TileGemmShape. We do not use
|
||||
// BlockWarps or WarpTile.
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 28;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 3;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
static constexpr ck_tile::index_t K_TILE = 8;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigDPOnly : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 12;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 3;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 2;
|
||||
static constexpr ck_tile::index_t K_TILE = 8;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 4;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 3;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 2;
|
||||
static constexpr ck_tile::index_t K_TILE = 8;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
|
||||
static constexpr ck_tile::index_t M = 4;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 4;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
static constexpr ck_tile::index_t K_TILE = 8;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigLargerCTensor : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
// This config has 3 macro tiles in the M dimension and 4 macro tiles in the N dimension.
|
||||
// This facilitates testing the get_output_tile_index method.
|
||||
|
||||
static constexpr ck_tile::index_t M = 12;
|
||||
static constexpr ck_tile::index_t N = 16;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 4;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
static constexpr ck_tile::index_t K_TILE = 8;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
void test_get_output_tile_index(ck_tile::index_t tile_idx,
|
||||
ck_tile::tuple<ck_tile::index_t, ck_tile::index_t> expected_2d_idx)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigLargerCTensor;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_OUTPUT_TILE_INDEX>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem im_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem in_dev(sizeof(ck_tile::index_t));
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(tile_idx,
|
||||
Config::UNUSED,
|
||||
Config::UNUSED,
|
||||
im_dev.GetDeviceBuffer(),
|
||||
in_dev.GetDeviceBuffer(),
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
const auto [im_expected, in_expected] = expected_2d_idx;
|
||||
ck_tile::index_t im, in;
|
||||
im_dev.FromDevice(&im);
|
||||
in_dev.FromDevice(&in);
|
||||
EXPECT_EQ(im, im_expected);
|
||||
EXPECT_EQ(in, in_expected);
|
||||
}
|
||||
Reference in New Issue
Block a user