mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
[CK_TILE] Port hw independent changes from internal repo to develop branch (#3301)
* [CK_TILE] Port hardware-independent changes from the internal repo to the develop branch. It includes PR #96, #114, #120, and #121.
* Correct rebase error.
This commit is contained in:
@@ -6,7 +6,7 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_preshuffle test_grouped_gemm_preshuffle.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
@@ -50,16 +50,16 @@ struct KernelConfig
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent ,M_Tile, N_Tile, K_Tile, BlockPerCu
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>,
|
||||
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 128, 128, 128, 2>,
|
||||
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>,
|
||||
|
||||
#endif
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 128, 128, 128, 2>,
|
||||
|
||||
@@ -14,6 +14,9 @@
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
@@ -25,6 +28,7 @@ constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
@@ -101,13 +105,40 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / N_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
k_ / K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
@@ -115,6 +146,11 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t WaveSize = 32;
|
||||
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile);
|
||||
constexpr bool SupportVectorSize16 =
|
||||
(M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0;
|
||||
constexpr int VectorSize = SupportVectorSize16 ? 16 : 8;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
@@ -137,7 +173,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
/*UseStructuredSparsity*/ false,
|
||||
/*Persistent*/ false,
|
||||
/*NumWaveGroups*/ 1,
|
||||
/*Preshuffle*/ true>;
|
||||
/*Preshuffle*/ true,
|
||||
VectorSize>;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
@@ -210,6 +247,12 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t WaveSize = 32;
|
||||
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile);
|
||||
constexpr bool SupportVectorSize16 =
|
||||
(M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0;
|
||||
constexpr int VectorSize = SupportVectorSize16 ? 16 : 8;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
@@ -230,7 +273,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
/*UseStructuredSparsity*/ false,
|
||||
/*Persistent*/ true, // Enable persistent mode
|
||||
/*NumWaveGroups*/ 1,
|
||||
/*Preshuffle*/ true>;
|
||||
/*Preshuffle*/ true,
|
||||
VectorSize>;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
|
||||
Reference in New Issue
Block a user