mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Merge commit 'b0ee317d83b77741022997265d4125697e7f7804' into develop
This commit is contained in:
@@ -12,16 +12,16 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
|
||||
-enable-noalias-to-md-conversion=0
|
||||
)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp)
|
||||
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12")
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_int8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_pk_int4 test_gemm_pipeline_universal_pk_int4.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_pk_int4 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp)
|
||||
@@ -30,37 +30,47 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_int8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_pk_int4 test_gemm_pipeline_universal_pk_int4.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_pk_int4 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
|
||||
# On Radeon devices, build the WMMA version instead
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_fp16 test_gemm_pipeline_universal_fp16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE --save-temps -Wno-gnu-line-marker)
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_bf16 test_gemm_pipeline_universal_bf16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_fp16 test_gemm_pipeline_basic_fp16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target ")
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx11|gfx12")
|
||||
# On Radeon devices, build the WMMA version instead
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target test_ck_tile_gemm_pipeline")
|
||||
endif()
|
||||
|
||||
@@ -13,6 +13,28 @@
|
||||
#include "test_gemm_pipeline_smoke_util.hpp"
|
||||
#include "test_gemm_pipeline_smoke_run_test.inc"
|
||||
|
||||
struct GemmConfig_Mfma : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
struct GemmConfig_Wmma : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -38,17 +60,17 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// This part comes from the Codegen
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
@@ -130,7 +152,10 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
template <typename GemmConfig,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
bool run_gemm_test_prec_type(std::string a_layout,
|
||||
std::string b_layout,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
@@ -142,12 +167,12 @@ bool run_gemm_test_prec_type(std::string a_layout,
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -160,22 +185,22 @@ bool run_gemm_test_prec_type(std::string a_layout,
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -185,7 +210,7 @@ bool run_gemm_test_prec_type(std::string a_layout,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename APrecType, typename BPrecType, typename CPrecType>
|
||||
template <typename GemmConfig, typename APrecType, typename BPrecType, typename CPrecType>
|
||||
bool run_gemm_test(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -195,7 +220,8 @@ bool run_gemm_test(int argc, char* argv[])
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
return run_gemm_test_prec_type<APrecType, BPrecType, CPrecType>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_test_prec_type<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
@@ -255,8 +281,15 @@ int run_gemm_combinations()
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
is_success = run_gemm_test<APrecType, BPrecType, CPrecType>(ARG_COUNT, argv) &&
|
||||
#if CK_TILE_USE_WMMA
|
||||
is_success = run_gemm_test<GemmConfig_Wmma, APrecType, BPrecType, CPrecType>(
|
||||
ARG_COUNT, argv) &&
|
||||
is_success;
|
||||
#else
|
||||
is_success = run_gemm_test<GemmConfig_Mfma, APrecType, BPrecType, CPrecType>(
|
||||
ARG_COUNT, argv) &&
|
||||
is_success;
|
||||
#endif
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
|
||||
@@ -220,6 +220,27 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
|
||||
@@ -325,6 +325,13 @@ int run_gemm_combinations()
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
is_success = run_gemm_test<GemmConfigComputeV3_WMMA<CPrecType>,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType>(ARG_COUNT, argv) &&
|
||||
is_success;
|
||||
#else
|
||||
is_success = run_gemm_test<GemmConfigComputeV3<CPrecType>,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
@@ -335,6 +342,7 @@ int run_gemm_combinations()
|
||||
BPrecType,
|
||||
CPrecType>(ARG_COUNT, argv) &&
|
||||
is_success;
|
||||
#endif
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user