[CK_TILE] Port hw independent changes from internal repo to develop branch (#3301)

* [CK_TILE] Port hw independent changes from internal repo to develop branch

It includes PR#96, #114, #120, #121.

* correct rebase error

[ROCm/composable_kernel commit: fc7bf0ab1c]
This commit is contained in:
linqunAMD
2025-12-13 01:28:37 +08:00
committed by GitHub
parent 1f2421c944
commit c6ab08a491
22 changed files with 465 additions and 310 deletions

View File

@@ -130,7 +130,7 @@ auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None)
constexpr index_t kMPerBlock = Problem::kMPerBlock;
constexpr index_t kNPerBlock = Problem::kNPerBlock;
constexpr index_t kBlockSize = Problem::kBlockSize;
index_t kBlockSize = ck_tile::is_wave32() ? Problem::kBlockSize / 2 : Problem::kBlockSize;
std::cout << "Running CShuffleEpilogue test with M=" << M << ", N=" << N
<< ", MPerBlock=" << kMPerBlock << ", NPerBlock=" << kNPerBlock

View File

@@ -7,7 +7,7 @@ if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12")
add_gtest_executable(test_ck_tile_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp)
add_gtest_executable(test_ck_tile_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp)
target_compile_definitions(test_ck_tile_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -20,20 +20,21 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using KernelTypes = ::testing::Types<
// Has cshuffle epilogue enabled
// A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
#endif
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>
>;
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes);

View File

@@ -20,17 +20,19 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using KernelTypes = ::testing::Types<
// Has cshuffle epilogue disabled
// A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
#endif
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>
>;
// clang-format on

View File

@@ -23,6 +23,28 @@ static constexpr inline auto is_row_major(Layout layout_)
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
return 16;
#else
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
#endif
}
template <typename A0DataType,
typename B0DataType,
typename AccDataType,
@@ -103,17 +125,25 @@ class TestCkTileGemmMultiABD : public ::testing::Test
DsDataType::size()>& args,
const ck_tile::stream_config& s)
{
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
#if CK_TILE_USE_WMMA
using ADataType =
ck_tile::remove_cvref_t<std::tuple_element_t<ck_tile::number<0>{}, AsDataType>>;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<ADataType, N_Warp_Tile>();
#else
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif
constexpr bool DoubleSmemBuffer = false;

View File

@@ -13,6 +13,28 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
return 16;
#else
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
#endif
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
@@ -80,7 +102,7 @@ struct config_wmma
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, M_Warp_Tile>();
};
template <typename Tuple>

View File

@@ -9,7 +9,7 @@ endif()
# Use standard asm for rtn bf16 conversion instead of turncate
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
if(GPU_TARGETS MATCHES "gfx94|gfx95")
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12")
add_gtest_executable(test_ck_tile_grouped_gemm_multi_d test_grouped_gemm_multi_d.cpp)
target_compile_options(test_ck_tile_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -29,9 +29,6 @@ template <typename ALayout_,
int M_Warp_val_,
int N_Warp_val_,
int K_Warp_val_,
int M_Warp_Tile_val_,
int N_Warp_Tile_val_,
int K_Warp_Tile_val_,
bool DoubleSmemBuffer_val_,
ck_tile::GemmPipelineScheduler Scheduler_val_,
PipelineType Pipeline_val_,
@@ -50,15 +47,21 @@ struct KernelConfig
using EDataType = EDataType_;
using DsDataType = ck_tile::tuple<D0DataType_, D1DataType_>;
static constexpr int M_Tile_ = M_Tile_val_;
static constexpr int N_Tile_ = N_Tile_val_;
static constexpr int K_Tile_ = K_Tile_val_;
static constexpr int M_Warp_ = M_Warp_val_;
static constexpr int N_Warp_ = N_Warp_val_;
static constexpr int K_Warp_ = K_Warp_val_;
static constexpr int M_Warp_Tile_ = M_Warp_Tile_val_;
static constexpr int N_Warp_Tile_ = N_Warp_Tile_val_;
static constexpr int K_Warp_Tile_ = K_Warp_Tile_val_;
static constexpr int M_Tile_ = M_Tile_val_;
static constexpr int N_Tile_ = N_Tile_val_;
static constexpr int K_Tile_ = K_Tile_val_;
static constexpr int M_Warp_ = M_Warp_val_;
static constexpr int N_Warp_ = N_Warp_val_;
static constexpr int K_Warp_ = K_Warp_val_;
#if CK_TILE_USE_WMMA
static constexpr int M_Warp_Tile_ = 16;
static constexpr int N_Warp_Tile_ = 16;
static constexpr int K_Warp_Tile_ = 16;
#else
static constexpr int M_Warp_Tile_ = 32;
static constexpr int N_Warp_Tile_ = 32;
static constexpr int K_Warp_Tile_ = (M_Warp_val_ == 2) ? 16 : 8;
#endif
static constexpr bool DoubleSmemBuffer_ = DoubleSmemBuffer_val_;
static constexpr auto Scheduler_ = Scheduler_val_;
static constexpr PipelineType Pipeline_ = Pipeline_val_;
@@ -68,21 +71,21 @@ struct KernelConfig
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline, Persistent
// ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, DoubleSmemBuffer, Scheduler, Pipeline, Persistent
// FP16 A/B/D/E
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4
// BF16 A/B/D/E
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4
>;
// clang-format on

View File

@@ -6,7 +6,7 @@ if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx94|gfx95")
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_gtest_executable(test_ck_tile_grouped_gemm_preshuffle test_grouped_gemm_preshuffle.cpp)
target_compile_options(test_ck_tile_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -50,16 +50,16 @@ struct KernelConfig
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent ,M_Tile, N_Tile, K_Tile, BlockPerCu
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>,
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>,
#endif
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, BF16, BF16, F32, BF16, False, 128, 128, 128, 2>,

View File

@@ -14,6 +14,9 @@
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if CK_TILE_USE_WMMA
return 16;
#else
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
@@ -25,6 +28,7 @@ constexpr ck_tile::index_t get_k_warp_tile_flatmm()
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
#endif
}
template <typename Tuple>
@@ -101,13 +105,40 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view(
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / N_Warp_Tile,
N_Warp_Tile,
k_ / K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view(
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename ALayout, typename BLayout, typename CLayout>
@@ -115,6 +146,11 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
const ck_tile::stream_config& s,
void* kargs_ptr)
{
constexpr ck_tile::index_t WaveSize = 32;
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile);
constexpr bool SupportVectorSize16 =
(M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0;
constexpr int VectorSize = SupportVectorSize16 ? 16 : 8;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -137,7 +173,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
/*UseStructuredSparsity*/ false,
/*Persistent*/ false,
/*NumWaveGroups*/ 1,
/*Preshuffle*/ true>;
/*Preshuffle*/ true,
VectorSize>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
@@ -210,6 +247,12 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
const ck_tile::stream_config& s,
void* kargs_ptr)
{
constexpr ck_tile::index_t WaveSize = 32;
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile);
constexpr bool SupportVectorSize16 =
(M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0;
constexpr int VectorSize = SupportVectorSize16 ? 16 : 8;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
@@ -230,7 +273,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
/*UseStructuredSparsity*/ false,
/*Persistent*/ true, // Enable persistent mode
/*NumWaveGroups*/ 1,
/*Preshuffle*/ true>;
/*Preshuffle*/ true,
VectorSize>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,