mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
[CK_TILE] Enable ck_tile tests on gfx11 and gfx12 (#2821)
* [CK_TILE] Enable ck_tile test on gfx11 & gfx12
* revert an unnecessary change
* enable pk_int4 on gfx11 & gfx12
* revert .pre-commit-config.yaml
[ROCm/composable_kernel commit: b0ee317d83]
This commit is contained in:
@@ -181,15 +181,15 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK0PerLane = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
kABK0PerLane,
|
||||
GemmConfig::K_Warp_Tile / divisor / kABK0PerLane});
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
|
||||
@@ -314,15 +314,15 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK0PerLane = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
kABK0PerLane,
|
||||
GemmConfig::K_Warp_Tile / divisor / kABK0PerLane});
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
|
||||
@@ -45,15 +45,15 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK0PerLane = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = FlatmmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
kABK0PerLane,
|
||||
FlatmmConfig::K_Warp_Tile / divisor / kABK0PerLane});
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
|
||||
@@ -70,9 +70,9 @@ struct WmmaTraitsBase<gfx12_t, ADType, BDType, CDType>
|
||||
static constexpr index_t kRepeat = 1;
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABK0PerLane = 2;
|
||||
static constexpr index_t kABK0PerLane = 1;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABK1PerLane = 4;
|
||||
static constexpr index_t kABK1PerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
|
||||
@@ -18,7 +18,7 @@ function(create_tile_add_rmsnorm2d_rdquant_fwd SUFFIX)
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
endfunction()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
create_tile_add_rmsnorm2d_rdquant_fwd("fp16")
|
||||
create_tile_add_rmsnorm2d_rdquant_fwd("bf16")
|
||||
else()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_batched_gemm test_batched_gemm.cpp)
|
||||
endif()
|
||||
|
||||
@@ -27,21 +27,41 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
struct GemmWarpConfig_Mfma
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
struct GemmWarpConfig_Wmma
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
template <typename GemmWarpConfig, typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
constexpr ck_tile::index_t M_Tile = GemmWarpConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmWarpConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmWarpConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
@@ -255,9 +275,13 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
BatchCount};
|
||||
|
||||
invoke_batched_gemm<ALayout, BLayout, CLayout>(args,
|
||||
ck_tile::stream_config{nullptr, false});
|
||||
#if CK_TILE_USE_WMMA
|
||||
invoke_batched_gemm<GemmWarpConfig_Wmma, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, false});
|
||||
#else
|
||||
invoke_batched_gemm<GemmWarpConfig_Mfma, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, false});
|
||||
#endif
|
||||
|
||||
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideC =" << StrideC
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx950")
|
||||
add_gtest_executable(test_ck_tile_batched_transpose test_batched_transpose.cpp)
|
||||
set_property(TARGET test_ck_tile_batched_transpose PROPERTY CXX_STANDARD 20)
|
||||
else()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_tuple_apply test_tuple_apply.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_ck_tile_tuple_apply PRIVATE utility)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp)
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility)
|
||||
|
||||
@@ -106,7 +106,7 @@ class TestCkTileElementwise : public ::testing::Test
|
||||
ck_tile::index_t grid_size =
|
||||
(total_m_elements + TestElementWiseShape::kBlockM - 1) / TestElementWiseShape::kBlockM;
|
||||
dim3 grid(grid_size, 1, 1);
|
||||
dim3 block(TestElementWiseShape::kBlockSize, 1, 1);
|
||||
dim3 block = dim3(ew_kernel.BlockSize());
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log
|
||||
|
||||
@@ -12,16 +12,16 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
|
||||
-enable-noalias-to-md-conversion=0
|
||||
)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp)
|
||||
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12")
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_int8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_pk_int4 test_gemm_pipeline_universal_pk_int4.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_pk_int4 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp)
|
||||
@@ -30,37 +30,47 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_int8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_pk_int4 test_gemm_pipeline_universal_pk_int4.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_pk_int4 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
|
||||
# On Radeon devices, build the WMMA version instead
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_fp16 test_gemm_pipeline_universal_fp16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE --save-temps -Wno-gnu-line-marker)
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_bf16 test_gemm_pipeline_universal_bf16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_fp16 test_gemm_pipeline_basic_fp16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target ")
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx11|gfx12")
|
||||
# On Radeon devices, build the WMMA version instead
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target test_ck_tile_gemm_pipeline")
|
||||
endif()
|
||||
|
||||
@@ -13,6 +13,28 @@
|
||||
#include "test_gemm_pipeline_smoke_util.hpp"
|
||||
#include "test_gemm_pipeline_smoke_run_test.inc"
|
||||
|
||||
struct GemmConfig_Mfma : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
struct GemmConfig_Wmma : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -38,17 +60,17 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// This part comes from the Codegen
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
@@ -130,7 +152,10 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
template <typename GemmConfig,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
bool run_gemm_test_prec_type(std::string a_layout,
|
||||
std::string b_layout,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
@@ -142,12 +167,12 @@ bool run_gemm_test_prec_type(std::string a_layout,
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -160,22 +185,22 @@ bool run_gemm_test_prec_type(std::string a_layout,
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -185,7 +210,7 @@ bool run_gemm_test_prec_type(std::string a_layout,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename APrecType, typename BPrecType, typename CPrecType>
|
||||
template <typename GemmConfig, typename APrecType, typename BPrecType, typename CPrecType>
|
||||
bool run_gemm_test(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -195,7 +220,8 @@ bool run_gemm_test(int argc, char* argv[])
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
return run_gemm_test_prec_type<APrecType, BPrecType, CPrecType>(a_layout, b_layout, arg_parser);
|
||||
return run_gemm_test_prec_type<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
@@ -255,8 +281,15 @@ int run_gemm_combinations()
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
is_success = run_gemm_test<APrecType, BPrecType, CPrecType>(ARG_COUNT, argv) &&
|
||||
#if CK_TILE_USE_WMMA
|
||||
is_success = run_gemm_test<GemmConfig_Wmma, APrecType, BPrecType, CPrecType>(
|
||||
ARG_COUNT, argv) &&
|
||||
is_success;
|
||||
#else
|
||||
is_success = run_gemm_test<GemmConfig_Mfma, APrecType, BPrecType, CPrecType>(
|
||||
ARG_COUNT, argv) &&
|
||||
is_success;
|
||||
#endif
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
|
||||
@@ -220,6 +220,27 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
|
||||
@@ -325,6 +325,13 @@ int run_gemm_combinations()
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
is_success = run_gemm_test<GemmConfigComputeV3_WMMA<CPrecType>,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType>(ARG_COUNT, argv) &&
|
||||
is_success;
|
||||
#else
|
||||
is_success = run_gemm_test<GemmConfigComputeV3<CPrecType>,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
@@ -335,6 +342,7 @@ int run_gemm_combinations()
|
||||
BPrecType,
|
||||
CPrecType>(ARG_COUNT, argv) &&
|
||||
is_success;
|
||||
#endif
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
add_gtest_executable(test_gemm_multi_d_cshuffle test_gemm_multi_d_cshuffle.cpp)
|
||||
add_gtest_executable(test_gemm_multi_d_default2d test_gemm_multi_d_default2d.cpp)
|
||||
target_compile_definitions(test_gemm_multi_d_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -86,7 +86,28 @@ class TestCkTileGemmMultiD : public ::testing::Test
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
|
||||
template <typename ADataType,
|
||||
struct GemmWarpConfig_Mfma
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
struct GemmWarpConfig_Wmma
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
template <typename GemmWarpConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
@@ -99,17 +120,17 @@ class TestCkTileGemmMultiD : public ::testing::Test
|
||||
void invoke_gemm_multi_d(const ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
constexpr ck_tile::index_t M_Tile = GemmWarpConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmWarpConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmWarpConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
@@ -359,8 +380,9 @@ class TestCkTileGemmMultiD : public ::testing::Test
|
||||
StrideB,
|
||||
stridesDs,
|
||||
StrideE});
|
||||
|
||||
invoke_gemm_multi_d<ADataType,
|
||||
#if CK_TILE_USE_WMMA
|
||||
invoke_gemm_multi_d<GemmWarpConfig_Wmma,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
@@ -370,6 +392,19 @@ class TestCkTileGemmMultiD : public ::testing::Test
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDElementWiseFn>(args, ck_tile::stream_config{nullptr, false});
|
||||
#else
|
||||
invoke_gemm_multi_d<GemmWarpConfig_Mfma,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDElementWiseFn>(args, ck_tile::stream_config{nullptr, false});
|
||||
#endif
|
||||
|
||||
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideE =" << StrideE
|
||||
|
||||
@@ -12,7 +12,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
|
||||
-enable-noalias-to-md-conversion=0
|
||||
)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_wp test_gemm_pipeline_wp.cpp)
|
||||
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_wp PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -31,8 +31,10 @@ using F8Types = std::tuple<Row, Col, Row, F8, F8, F32, F16, Default, WeightPresh
|
||||
|
||||
using KernelTypesWeightPreshuffle = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffle>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle>,
|
||||
F8Types
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle>
|
||||
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
|
||||
, F8Types
|
||||
#endif
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -63,6 +63,23 @@ struct config
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
|
||||
};
|
||||
|
||||
template <typename Datatype>
|
||||
struct config_wmma
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(Datatype);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmPipeline : public ::testing::Test
|
||||
{
|
||||
@@ -79,13 +96,12 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
using GemmConfig = config<ADataType>;
|
||||
|
||||
static constexpr bool Persistent =
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 9, std::false_type>::value;
|
||||
// TODO: expose tile size through test t-param ?
|
||||
|
||||
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
||||
template <typename GemmConfig, bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// TODO: This should be parameterized in tests
|
||||
@@ -253,6 +269,48 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
k_batches_ = {1};
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
template <bool PadM = true, bool PadN = true, bool PadK = true, bool Preshuffle = false>
|
||||
void Run(const int M,
|
||||
const int N,
|
||||
@@ -263,11 +321,17 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
{
|
||||
for(auto kb : k_batches_)
|
||||
{
|
||||
RunSingle<PadM, PadN, PadK, Preshuffle>(M, N, K, StrideA, StrideB, StrideC, kb);
|
||||
#if CK_TILE_USE_WMMA
|
||||
RunSingle<config_wmma<ADataType>, PadM, PadN, PadK, Preshuffle>(
|
||||
M, N, K, StrideA, StrideB, StrideC, kb);
|
||||
#else
|
||||
RunSingle<config<ADataType>, PadM, PadN, PadK, Preshuffle>(
|
||||
M, N, K, StrideA, StrideB, StrideC, kb);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
||||
template <typename GemmConfig, bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
||||
void RunSingle(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
@@ -327,16 +391,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<BDataType> t_view({N / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
K / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
|
||||
std::copy(b_k_n.begin(), b_k_n.end(), t_view.begin());
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host =
|
||||
ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
@@ -354,7 +409,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
stride_B,
|
||||
stride_C};
|
||||
|
||||
invoke_gemm<PadM, PadN, PadK, Preshuffle>(args, ck_tile::stream_config{nullptr, false});
|
||||
invoke_gemm<GemmConfig, PadM, PadN, PadK, Preshuffle>(
|
||||
args, ck_tile::stream_config{nullptr, false});
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm test_grouped_gemm.cpp)
|
||||
endif()
|
||||
|
||||
@@ -31,7 +31,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
using PersistentType = std::tuple_element_t<7, Tuple>;
|
||||
static constexpr bool Persistent = PersistentType::value;
|
||||
|
||||
struct GroupedGemKernelParam
|
||||
struct GroupedGemKernelParam_Mfma
|
||||
{
|
||||
static const bool kPadM = false;
|
||||
static const bool kPadN = false;
|
||||
@@ -51,13 +51,24 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
static const ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma
|
||||
{
|
||||
static const ck_tile::index_t M_Tile = 128;
|
||||
static const ck_tile::index_t N_Tile = 128;
|
||||
static const ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 16;
|
||||
static const ck_tile::index_t N_Warp_Tile = 16;
|
||||
static const ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
@@ -200,7 +211,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
@@ -460,15 +471,27 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
invoke_grouped_gemm_persistent<ALayout, BLayout, CLayout>(
|
||||
#if CK_TILE_USE_WMMA
|
||||
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Wmma, ALayout, BLayout, CLayout>(
|
||||
stream, group_count, kargs_ptr, splitk);
|
||||
#else
|
||||
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
|
||||
stream, group_count, kargs_ptr, splitk);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
#if CK_TILE_USE_WMMA
|
||||
invoke_grouped_gemm<GroupedGemKernelParam_Wmma, ALayout, BLayout, CLayout>(
|
||||
gemm_descs,
|
||||
ck_tile::stream_config{nullptr, false, 1},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
#else
|
||||
invoke_grouped_gemm<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
|
||||
gemm_descs,
|
||||
ck_tile::stream_config{nullptr, false, 1},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
#endif
|
||||
}
|
||||
|
||||
// Copy results back to host for validation
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_tile_image_to_column test_tile_image_to_column.cpp)
|
||||
endif()
|
||||
|
||||
@@ -14,7 +14,7 @@ function(create_tile_layernorm2d_fwd SUFFIX)
|
||||
target_compile_options(${TEST_CK_TILE_LAYERNORM2D_FWD} PRIVATE ${TEST_CK_TILE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
|
||||
endfunction()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd")
|
||||
set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
function (add_moe_smoothquant_test TARGET_NAME MAIN_SRC)
|
||||
message(DEBUG "adding ${TARGET_NAME}")
|
||||
add_gtest_executable(${TARGET_NAME} ${MAIN_SRC})
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Currently ck_tile is only built on gfx90a, gfx942 and gfx950
|
||||
if(GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx950" OR GPU_TARGETS MATCHES "gfx90a")
|
||||
# Currently ck_tile is only built on gfx90a, gfx942, gfx950, gfx11 and gfx12
|
||||
if(GPU_TARGETS MATCHES "gfx942|gfx950|gfx90a|gfx11|gfx12")
|
||||
|
||||
function(add_moe_sorting_test EXECUTABLE USE_2D_BUF)
|
||||
add_gtest_executable(${EXECUTABLE} test_moe_sorting.cpp moe_sorting_api.cpp)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
|
||||
function(add_permute_test TARGET_NAME MAIN_SRC)
|
||||
add_gtest_executable(${TARGET_NAME} ${MAIN_SRC})
|
||||
|
||||
@@ -17,9 +17,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if !CK_TILE_USE_WMMA
|
||||
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
|
||||
#include "alternative_impl/matrix_core_swizzle.hpp"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace detail {
|
||||
template <int bytes>
|
||||
@@ -193,6 +195,7 @@ class TestCkTilePermute : public ::testing::Test
|
||||
|
||||
return permute<DataType>(a, stream_config);
|
||||
};
|
||||
#if !CK_TILE_USE_WMMA
|
||||
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
|
||||
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
|
||||
if((perm == std::string("0,1,4,2,5,3,6") || perm == std::string("0,1,2,4,5,3,6") ||
|
||||
@@ -278,6 +281,7 @@ class TestCkTilePermute : public ::testing::Test
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#endif
|
||||
{
|
||||
run_permute();
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_reduce2d test_reduce2d.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_ck_tile_reduce2d PRIVATE utility)
|
||||
|
||||
@@ -59,7 +59,7 @@ class TestCkTileReduce : public ::testing::Test
|
||||
using Kernel = ck_tile::Reduce<Problem>;
|
||||
|
||||
// Launch configuration
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
ck_tile::index_t kGridSize =
|
||||
|
||||
@@ -14,7 +14,7 @@ function(create_tile_rmsnorm2d_fwd SUFFIX)
|
||||
target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS})
|
||||
endfunction()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
set(RMSNORM2D_FWD_KNOWN_APIS "fwd;bwd")
|
||||
set(RMSNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${RMSNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
function (add_smoothquant_test TARGET_NAME MAIN_SRC)
|
||||
message(DEBUG "adding ${TARGET_NAME}")
|
||||
|
||||
|
||||
@@ -10,8 +10,7 @@ function(add_tile_topk_softmax_test SUFFIX)
|
||||
target_compile_options(${TEST_NAME} PRIVATE ${TEST_TOPK_SOFTMAX_COMPILE_OPTIONS})
|
||||
endfunction()
|
||||
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_tile_topk_softmax_test(fp16)
|
||||
add_tile_topk_softmax_test(bf16)
|
||||
else()
|
||||
|
||||
Reference in New Issue
Block a user