diff --git a/example/ck_tile/40_streamk_gemm/CMakeLists.txt b/example/ck_tile/40_streamk_gemm/CMakeLists.txt index 3539dee05b..3b285a54b5 100644 --- a/example/ck_tile/40_streamk_gemm/CMakeLists.txt +++ b/example/ck_tile/40_streamk_gemm/CMakeLists.txt @@ -1,5 +1,10 @@ if(GPU_TARGETS MATCHES "gfx9") add_executable(tile_example_streamk_gemm_basic EXCLUDE_FROM_ALL streamk_gemm_basic.cpp) + set(EXAMPLE_GEMM_COMPILE_OPTIONS) + if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + endif() + target_compile_options(tile_example_streamk_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile streamk gemm tests for current target") endif() diff --git a/example/ck_tile/40_streamk_gemm/README.md b/example/ck_tile/40_streamk_gemm/README.md index d2ff7eabc0..fe9eb0c4f8 100644 --- a/example/ck_tile/40_streamk_gemm/README.md +++ b/example/ck_tile/40_streamk_gemm/README.md @@ -28,10 +28,10 @@ args: -stride_b tensor B stride (default:0) -stride_c tensor C stride (default:0) -v validation strategy. 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1) - -prec data type. fp16/bf16 (default:fp16) + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) -warmup number of iterations before benchmarking the kernel (default:50) -repeat number of iterations to benchmark the kernel (default:100) -timer timing mode. gpu:gpu timer, cpu:cpu timer (default:gpu) -init data initialization strategy. 0:random, 1:linear, 2:constant(1) (default:0) -flush_cache flush the cache before running the kernel (default:true) -``` \ No newline at end of file +``` diff --git a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp index e698539eea..abcca7eaec 100644 --- a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp +++ b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp @@ -75,6 +75,18 @@ struct DataTypeTraits static constexpr const char* name = "bf16"; }; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -94,7 +106,7 @@ auto create_args(int argc, char* argv[]) .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "fp16", "data type. fp16/bf16") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmarking the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp index 40709e38e2..8ec409023d 100644 --- a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp +++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp @@ -56,7 +56,7 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, GemmUniversalTraits, GemmConfig::Scheduler>; - using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, TypeConfig>( a_layout, b_layout, argc, argv); } + else if(data_type == "fp8") + { + using TypeConfig = StreamKGemmTypeConfig; + return run_gemm_example_prec_type, TypeConfig>( + a_layout, b_layout, argc, argv); + } + else if(data_type == "bf8") + { + using TypeConfig = StreamKGemmTypeConfig; + return run_gemm_example_prec_type, TypeConfig>( + a_layout, b_layout, argc, argv); + } else { throw std::runtime_error("Unsupported data type for this operation !!!"); diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 5e178e3669..a6cfcde86e 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -28,3 +28,4 @@ add_subdirectory(38_block_scale_gemm) add_subdirectory(39_copy) add_subdirectory(40_streamk_gemm) add_subdirectory(41_batched_contraction) + diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index eba411e271..150181d0d7 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -1,3 +1,17 @@ +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS + -mllvm + -enable-noalias-to-md-conversion=0 +) +set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) + # Currently test_ck_tile_streamk is only built on gfx9 if(GPU_TARGETS MATCHES "gfx9") @@ -6,23 +20,33 @@ if(GPU_TARGETS MATCHES "gfx9") #TODO: support all arches #TODO: current c-shuffle only supports C layout as R add_gtest_executable(test_ck_tile_streamk_smoke - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp #${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp #${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp #${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp #${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp #${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp #${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp #${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp #${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp + + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp ) # TODO: enable extended tests after tolerances for atomic reductions are addressed. # add_gtest_executable(test_ck_tile_streamk_extended @@ -129,6 +153,7 @@ if(GPU_TARGETS MATCHES "gfx9") ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp test_gemm_streamk_reboot_util.cpp) + target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping test_ck_tile_streamk tests for current target") endif() diff --git a/test/ck_tile/gemm_streamk/smoke_tests/bf8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/bf8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp new file mode 100644 index 0000000000..f2406d652d --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/bf8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +#define TEST_SUITE_PARAMS BF8_CCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent +#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS) + +DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS); + +#include "test_gemm_streamk_cases.inc" diff --git a/test/ck_tile/gemm_streamk/smoke_tests/bf8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/bf8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp new file mode 100644 index 0000000000..e961d7c35b --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/bf8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +#define TEST_SUITE_PARAMS BF8_CRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent +#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS) + +DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS); + +#include "test_gemm_streamk_cases.inc" diff --git a/test/ck_tile/gemm_streamk/smoke_tests/bf8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/bf8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp new file mode 100644 index 0000000000..93b7c57b3a --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/bf8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +#define TEST_SUITE_PARAMS BF8_RCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent +#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS) + +DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS); + +#include "test_gemm_streamk_cases.inc" diff --git a/test/ck_tile/gemm_streamk/smoke_tests/bf8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/bf8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp new file mode 100644 index 0000000000..64fa12d226 --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/bf8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +#define TEST_SUITE_PARAMS BF8_RRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent +#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS) + +DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS); + +#include "test_gemm_streamk_cases.inc" diff --git a/test/ck_tile/gemm_streamk/smoke_tests/f8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/f8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp new file mode 100644 index 0000000000..cdf4e8306e --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/f8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +#define TEST_SUITE_PARAMS F8_CCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent +#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS) + +DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS); + +#include "test_gemm_streamk_cases.inc" diff --git a/test/ck_tile/gemm_streamk/smoke_tests/f8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/f8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp new file mode 100644 index 0000000000..5edbec9c45 --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/f8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +#define TEST_SUITE_PARAMS F8_CRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent +#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS) + +DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS); + +#include "test_gemm_streamk_cases.inc" diff --git a/test/ck_tile/gemm_streamk/smoke_tests/f8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/f8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp new file mode 100644 index 0000000000..4e3a9eaa25 --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/f8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +#define TEST_SUITE_PARAMS F8_RCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent +#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS) + +DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS); + +#include "test_gemm_streamk_cases.inc" diff --git a/test/ck_tile/gemm_streamk/smoke_tests/f8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/f8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp new file mode 100644 index 0000000000..ab2eabb442 --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/f8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +#define TEST_SUITE_PARAMS F8_RRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent +#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS) + +DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS); + +#include "test_gemm_streamk_cases.inc" diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp index 578eb31189..73e44d5cfd 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp @@ -14,6 +14,8 @@ using F16 = ck_tile::half_t; using F32 = float; using BF16 = ck_tile::bf16_t; +using F8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -79,7 +81,7 @@ struct Layouts // MxNxK MxNxK M N K M N K // // The example options for each field are: -// - DATA_TYPE: F16, BF16 +// - DATA_TYPE: F16, BF16, F8, BF8 // - LAYOUT: RRR, RRC, RCR, RCC, CRR, CRC, CCR, CCC // - PIPELINE_TYPE: Mem, CompV3, CompV4 // - M_MACRO_TILE: 128, 256, etc @@ -121,3 +123,5 @@ struct Layouts #include "test_gemm_streamk_types_fp16.hpp" #include "test_gemm_streamk_types_bf16.hpp" +#include "test_gemm_streamk_types_fp8.hpp" +#include "test_gemm_streamk_types_bf8.hpp" diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types_bf8.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types_bf8.hpp new file mode 100644 index 0000000000..47f64e35ad --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types_bf8.hpp @@ -0,0 +1,77 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "test_gemm_streamk_types.hpp" + +template +struct BF8Layouts +{ + // clang-format off + // For CDNA, we support [A, B, Acc, C] = [bf8, bf8, f32, f16] and [bf8, bf8, f32, f32]: + using BF8_BF8_F32_F16 = Layouts; + using BF8_BF8_F32_F32 = Layouts; + using RRR = detail::combine_t; + using RRC = detail::combine_t; + using RCR = detail::combine_t; + using RCC = detail::combine_t; + using CRR = detail::combine_t; + using CRC = detail::combine_t; + using CCR = detail::combine_t; + using CCC = detail::combine_t; + // clang-format on +}; + +// clang-format off + +// Macro to declare all layout combinations for BF8 data type +#define DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, M_MACRO_TILE, N_MACRO_TILE, K_MACRO_TILE, M_WARPS, N_WARPS, K_WARPS, M_MMA_TILE, N_MMA_TILE, K_MMA_TILE, PERSISTENT) \ + DECLARE_PARAMS_ALL_LAYOUTS(BF8Layouts, BF8, PIPELINE_TYPE, M_MACRO_TILE, N_MACRO_TILE, K_MACRO_TILE, M_WARPS, N_WARPS, K_WARPS, M_MMA_TILE, N_MMA_TILE, K_MMA_TILE, PERSISTENT) + +// Macro to declare all layout combinations for BF8 data type and a variety of sizes +#define DECLARE_BF8_PARAMS_ALL_LAYOUTS_ALL_SIZES(PIPELINE_TYPE, PERSISTENT) \ + DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 128, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 128, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 128, 128, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 128, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 128, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 256, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 256, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 256, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 256, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) + +// Declare all BF8 parameter sets for different pipeline types and persistence options +DECLARE_BF8_PARAMS_ALL_LAYOUTS_ALL_SIZES(Mem, NonPersistent) +DECLARE_BF8_PARAMS_ALL_LAYOUTS_ALL_SIZES(CompV3, NonPersistent) +DECLARE_BF8_PARAMS_ALL_LAYOUTS_ALL_SIZES(CompV4, NonPersistent) + +// Here, we have a combination of parameter set symbols that we can use to compile into test cases +// __________________________________________________ +// | Parameter Name | +// using BF8_RRR_Mem_128x128x32_2x2x1_32x32x16_NonPersistent = ... +// / | \ \ \ \ \ +// DATA LAYOUT PIPELINE MACRO WARPS MMA PERSISTENT +// TYPE TYPE TILE MxNxK TILE TYPE +// MxNxK MxNxK +// +// The options for each field are: +// - DATA TYPE: BF8 +// - LAYOUT: RRR, RRC, RCR, RCC, CRR, CRC, CCR, CCC +// - PIPELINE_TYPE: Mem, CompV3, CompV4 +// - Macro Tile: 128x128x32, 128x128x64, 128x128x128, 256x128x32, 256x128x64, 128x256x32, 128x256x64, 256x256x32, 256x256x64 +// - Warps: 2x2x1 +// - MMA Tile: 32x32x16 +// - PERSISTENT_TYPE: NonPersistent + +// clang-format on diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types_fp8.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types_fp8.hpp new file mode 100644 index 0000000000..30132e6b6d --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types_fp8.hpp @@ -0,0 +1,77 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "test_gemm_streamk_types.hpp" + +template +struct F8Layouts +{ + // clang-format off + // For CDNA, we support [A, B, Acc, C] = [f8, f8, f32, f16] and [f8, f8, f32, f32]: + using F8_F8_F32_F16 = Layouts; + using F8_F8_F32_F32 = Layouts; + using RRR = detail::combine_t; + using RRC = detail::combine_t; + using RCR = detail::combine_t; + using RCC = detail::combine_t; + using CRR = detail::combine_t; + using CRC = detail::combine_t; + using CCR = detail::combine_t; + using CCC = detail::combine_t; + // clang-format on +}; + +// clang-format off + +// Macro to declare all layout combinations for FP8 data type +#define DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, M_MACRO_TILE, N_MACRO_TILE, K_MACRO_TILE, M_WARPS, N_WARPS, K_WARPS, M_MMA_TILE, N_MMA_TILE, K_MMA_TILE, PERSISTENT) \ + DECLARE_PARAMS_ALL_LAYOUTS(F8Layouts, F8, PIPELINE_TYPE, M_MACRO_TILE, N_MACRO_TILE, K_MACRO_TILE, M_WARPS, N_WARPS, K_WARPS, M_MMA_TILE, N_MMA_TILE, K_MMA_TILE, PERSISTENT) + +// Macro to declare all layout combinations for FP8 data type and a variety of sizes +#define DECLARE_F8_PARAMS_ALL_LAYOUTS_ALL_SIZES(PIPELINE_TYPE, PERSISTENT) \ + DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 128, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 128, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 128, 128, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 128, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 128, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 256, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 256, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 256, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \ + DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 256, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) + +// Declare all FP8 parameter sets for different pipeline types and persistence options +DECLARE_F8_PARAMS_ALL_LAYOUTS_ALL_SIZES(Mem, NonPersistent) +DECLARE_F8_PARAMS_ALL_LAYOUTS_ALL_SIZES(CompV3, NonPersistent) +DECLARE_F8_PARAMS_ALL_LAYOUTS_ALL_SIZES(CompV4, NonPersistent) + +// Here, we have a combination of parameter set symbols that we can use to compile into test cases +// __________________________________________________ +// | Parameter Name | +// using F8_RRR_Mem_128x128x32_2x2x1_32x32x16_NonPersistent = ... +// / | \ \ \ \ \ +// DATA LAYOUT PIPELINE MACRO WARPS MMA PERSISTENT +// TYPE TYPE TILE MxNxK TILE TYPE +// MxNxK MxNxK +// +// The options for each field are: +// - DATA TYPE: F8 +// - LAYOUT: RRR, RRC, RCR, RCC, CRR, CRC, CCR, CCC +// - PIPELINE_TYPE: Mem, CompV3, CompV4 +// - Macro Tile: 128x128x32, 128x128x64, 128x128x128, 256x128x32, 256x128x64, 128x256x32, 128x256x64, 256x256x32, 256x256x64 +// - Warps: 2x2x1 +// - MMA Tile: 32x32x16 +// - PERSISTENT_TYPE: NonPersistent + +// clang-format on