mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[CK_TILE] Stream-K Gemm Example for fp8 and bf8 (#3041)
* Add a Stream-K fp8 example for CK Tile.
* Add a Stream-K bf8 example for CK Tile.
* Refactor the fp8/bf8 unit tests to utilize the test harness. Implemented smoke tests with the layouts CCR, CRR, RCR, and RRR for fp8/bf8. The tests use 128x128x32 for the tile configuration, as other configurations revealed implementation gaps that are currently being documented.
This commit is contained in:
@@ -1,5 +1,10 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(tile_example_streamk_gemm_basic EXCLUDE_FROM_ALL streamk_gemm_basic.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
target_compile_options(tile_example_streamk_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile streamk gemm tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -28,10 +28,10 @@ args:
|
||||
-stride_b tensor B stride (default:0)
|
||||
-stride_c tensor C stride (default:0)
|
||||
-v validation strategy. 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-prec data type. fp16/bf16 (default:fp16)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-warmup number of iterations before benchmarking the kernel (default:50)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer timing mode. gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-init data initialization strategy. 0:random, 1:linear, 2:constant(1) (default:0)
|
||||
-flush_cache flush the cache before running the kernel (default:true)
|
||||
```
|
||||
```
|
||||
|
||||
@@ -75,6 +75,18 @@ struct DataTypeTraits<ck_tile::bf16_t>
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -94,7 +106,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
|
||||
@@ -56,7 +56,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>;
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
@@ -187,6 +187,18 @@ int run_gemm_example(int argc, char* argv[])
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
|
||||
@@ -28,3 +28,4 @@ add_subdirectory(38_block_scale_gemm)
|
||||
add_subdirectory(39_copy)
|
||||
add_subdirectory(40_streamk_gemm)
|
||||
add_subdirectory(41_batched_contraction)
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
|
||||
-mllvm
|
||||
-enable-noalias-to-md-conversion=0
|
||||
)
|
||||
set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
|
||||
# Currently test_ck_tile_streamk is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
|
||||
@@ -6,23 +20,33 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
#TODO: support all arches
|
||||
#TODO: current c-shuffle only supports C layout as R
|
||||
add_gtest_executable(test_ck_tile_streamk_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
)
|
||||
# TODO: enable extended tests after tolerances for atomic reductions are addressed.
|
||||
# add_gtest_executable(test_ck_tile_streamk_extended
|
||||
@@ -129,6 +153,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
|
||||
test_gemm_streamk_reboot_util.cpp)
|
||||
target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF8_CCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -0,0 +1,11 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF8_CRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -0,0 +1,11 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF8_RCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -0,0 +1,11 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF8_RRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -0,0 +1,11 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F8_CCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -0,0 +1,11 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F8_CRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -0,0 +1,11 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F8_RCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -0,0 +1,11 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F8_RRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -14,6 +14,8 @@
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -79,7 +81,7 @@ struct Layouts
|
||||
// MxNxK MxNxK M N K M N K
|
||||
//
|
||||
// The example options for each field are:
|
||||
// - DATA_TYPE: F16, BF16
|
||||
// - DATA_TYPE: F16, BF16, F8, BF8
|
||||
// - LAYOUT: RRR, RRC, RCR, RCC, CRR, CRC, CCR, CCC
|
||||
// - PIPELINE_TYPE: Mem, CompV3, CompV4
|
||||
// - M_MACRO_TILE: 128, 256, etc
|
||||
@@ -121,3 +123,5 @@ struct Layouts
|
||||
|
||||
#include "test_gemm_streamk_types_fp16.hpp"
|
||||
#include "test_gemm_streamk_types_bf16.hpp"
|
||||
#include "test_gemm_streamk_types_fp8.hpp"
|
||||
#include "test_gemm_streamk_types_bf8.hpp"
|
||||
|
||||
77
test/ck_tile/gemm_streamk/test_gemm_streamk_types_bf8.hpp
Normal file
77
test/ck_tile/gemm_streamk/test_gemm_streamk_types_bf8.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "test_gemm_streamk_types.hpp"
|
||||
|
||||
template <typename M_MacroTile,
|
||||
typename N_MacroTile,
|
||||
typename K_MacroTile,
|
||||
typename M_Warps,
|
||||
typename N_Warps,
|
||||
typename K_Warps,
|
||||
typename M_MmaTile,
|
||||
typename N_MmaTile,
|
||||
typename K_MmaTile,
|
||||
typename PipelineType,
|
||||
typename Persistent>
|
||||
struct BF8Layouts
|
||||
{
|
||||
// clang-format off
|
||||
// For CDNA, we support [A, B, Acc, C] = [bf8, bf8, f32, f16] and [bf8, bf8, f32, f32]:
|
||||
using BF8_BF8_F32_F16 = Layouts<BF8, BF8, F32, F16, M_MacroTile, N_MacroTile, K_MacroTile, M_Warps, N_Warps, K_Warps, M_MmaTile, N_MmaTile, K_MmaTile, PipelineType, Persistent>;
|
||||
using BF8_BF8_F32_F32 = Layouts<BF8, BF8, F32, F32, M_MacroTile, N_MacroTile, K_MacroTile, M_Warps, N_Warps, K_Warps, M_MmaTile, N_MmaTile, K_MmaTile, PipelineType, Persistent>;
|
||||
using RRR = detail::combine_t<typename BF8_BF8_F32_F16::RRR, typename BF8_BF8_F32_F32::RRR>;
|
||||
using RRC = detail::combine_t<typename BF8_BF8_F32_F16::RRC, typename BF8_BF8_F32_F32::RRC>;
|
||||
using RCR = detail::combine_t<typename BF8_BF8_F32_F16::RCR, typename BF8_BF8_F32_F32::RCR>;
|
||||
using RCC = detail::combine_t<typename BF8_BF8_F32_F16::RCC, typename BF8_BF8_F32_F32::RCC>;
|
||||
using CRR = detail::combine_t<typename BF8_BF8_F32_F16::CRR, typename BF8_BF8_F32_F32::CRR>;
|
||||
using CRC = detail::combine_t<typename BF8_BF8_F32_F16::CRC, typename BF8_BF8_F32_F32::CRC>;
|
||||
using CCR = detail::combine_t<typename BF8_BF8_F32_F16::CCR, typename BF8_BF8_F32_F32::CCR>;
|
||||
using CCC = detail::combine_t<typename BF8_BF8_F32_F16::CCC, typename BF8_BF8_F32_F32::CCC>;
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
|
||||
// Macro to declare all layout combinations for BF8 data type
|
||||
#define DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, M_MACRO_TILE, N_MACRO_TILE, K_MACRO_TILE, M_WARPS, N_WARPS, K_WARPS, M_MMA_TILE, N_MMA_TILE, K_MMA_TILE, PERSISTENT) \
|
||||
DECLARE_PARAMS_ALL_LAYOUTS(BF8Layouts, BF8, PIPELINE_TYPE, M_MACRO_TILE, N_MACRO_TILE, K_MACRO_TILE, M_WARPS, N_WARPS, K_WARPS, M_MMA_TILE, N_MMA_TILE, K_MMA_TILE, PERSISTENT)
|
||||
|
||||
// Macro to declare all layout combinations for BF8 data type and a variety of sizes
|
||||
#define DECLARE_BF8_PARAMS_ALL_LAYOUTS_ALL_SIZES(PIPELINE_TYPE, PERSISTENT) \
|
||||
DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 128, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 128, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 128, 128, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 128, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 128, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 256, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 256, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 256, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_BF8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 256, 64, 2, 2, 1, 32, 32, 16, PERSISTENT)
|
||||
|
||||
// Declare all BF8 parameter sets for different pipeline types and persistence options
|
||||
DECLARE_BF8_PARAMS_ALL_LAYOUTS_ALL_SIZES(Mem, NonPersistent)
|
||||
DECLARE_BF8_PARAMS_ALL_LAYOUTS_ALL_SIZES(CompV3, NonPersistent)
|
||||
DECLARE_BF8_PARAMS_ALL_LAYOUTS_ALL_SIZES(CompV4, NonPersistent)
|
||||
|
||||
// Here, we have a combination of parameter set symbols that we can use to compile into test cases
|
||||
// __________________________________________________
|
||||
// | Parameter Name |
|
||||
// using BF8_RRR_Mem_128x128x32_2x2x1_32x32x16_NonPersistent = ...
|
||||
// / | \ \ \ \ \
|
||||
// DATA LAYOUT PIPELINE MACRO WARPS MMA PERSISTENT
|
||||
// TYPE TYPE TILE MxNxK TILE TYPE
|
||||
// MxNxK MxNxK
|
||||
//
|
||||
// The options for each field are:
|
||||
// - DATA TYPE: BF8
|
||||
// - LAYOUT: RRR, RRC, RCR, RCC, CRR, CRC, CCR, CCC
|
||||
// - PIPELINE_TYPE: Mem, CompV3, CompV4
|
||||
// - Macro Tile: 128x128x32, 128x128x64, 128x128x128, 256x128x32, 256x128x64, 128x256x32, 128x256x64, 256x256x32, 256x256x64
|
||||
// - Warps: 2x2x1
|
||||
// - MMA Tile: 32x32x16
|
||||
// - PERSISTENT_TYPE: NonPersistent
|
||||
|
||||
// clang-format on
|
||||
77
test/ck_tile/gemm_streamk/test_gemm_streamk_types_fp8.hpp
Normal file
77
test/ck_tile/gemm_streamk/test_gemm_streamk_types_fp8.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "test_gemm_streamk_types.hpp"
|
||||
|
||||
template <typename M_MacroTile,
|
||||
typename N_MacroTile,
|
||||
typename K_MacroTile,
|
||||
typename M_Warps,
|
||||
typename N_Warps,
|
||||
typename K_Warps,
|
||||
typename M_MmaTile,
|
||||
typename N_MmaTile,
|
||||
typename K_MmaTile,
|
||||
typename PipelineType,
|
||||
typename Persistent>
|
||||
struct F8Layouts
|
||||
{
|
||||
// clang-format off
|
||||
// For CDNA, we support [A, B, Acc, C] = [f8, f8, f32, f16] and [f8, f8, f32, f32]:
|
||||
using F8_F8_F32_F16 = Layouts<F8, F8, F32, F16, M_MacroTile, N_MacroTile, K_MacroTile, M_Warps, N_Warps, K_Warps, M_MmaTile, N_MmaTile, K_MmaTile, PipelineType, Persistent>;
|
||||
using F8_F8_F32_F32 = Layouts<F8, F8, F32, F32, M_MacroTile, N_MacroTile, K_MacroTile, M_Warps, N_Warps, K_Warps, M_MmaTile, N_MmaTile, K_MmaTile, PipelineType, Persistent>;
|
||||
using RRR = detail::combine_t<typename F8_F8_F32_F16::RRR, typename F8_F8_F32_F32::RRR>;
|
||||
using RRC = detail::combine_t<typename F8_F8_F32_F16::RRC, typename F8_F8_F32_F32::RRC>;
|
||||
using RCR = detail::combine_t<typename F8_F8_F32_F16::RCR, typename F8_F8_F32_F32::RCR>;
|
||||
using RCC = detail::combine_t<typename F8_F8_F32_F16::RCC, typename F8_F8_F32_F32::RCC>;
|
||||
using CRR = detail::combine_t<typename F8_F8_F32_F16::CRR, typename F8_F8_F32_F32::CRR>;
|
||||
using CRC = detail::combine_t<typename F8_F8_F32_F16::CRC, typename F8_F8_F32_F32::CRC>;
|
||||
using CCR = detail::combine_t<typename F8_F8_F32_F16::CCR, typename F8_F8_F32_F32::CCR>;
|
||||
using CCC = detail::combine_t<typename F8_F8_F32_F16::CCC, typename F8_F8_F32_F32::CCC>;
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
|
||||
// Macro to declare all layout combinations for FP8 data type
|
||||
#define DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, M_MACRO_TILE, N_MACRO_TILE, K_MACRO_TILE, M_WARPS, N_WARPS, K_WARPS, M_MMA_TILE, N_MMA_TILE, K_MMA_TILE, PERSISTENT) \
|
||||
DECLARE_PARAMS_ALL_LAYOUTS(F8Layouts, F8, PIPELINE_TYPE, M_MACRO_TILE, N_MACRO_TILE, K_MACRO_TILE, M_WARPS, N_WARPS, K_WARPS, M_MMA_TILE, N_MMA_TILE, K_MMA_TILE, PERSISTENT)
|
||||
|
||||
// Macro to declare all layout combinations for FP8 data type and a variety of sizes
|
||||
#define DECLARE_F8_PARAMS_ALL_LAYOUTS_ALL_SIZES(PIPELINE_TYPE, PERSISTENT) \
|
||||
DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 128, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 128, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 128, 128, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 128, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 128, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 256, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 128, 256, 64, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 256, 32, 2, 2, 1, 32, 32, 16, PERSISTENT) \
|
||||
DECLARE_F8_PARAMS_ALL_LAYOUTS(PIPELINE_TYPE, 256, 256, 64, 2, 2, 1, 32, 32, 16, PERSISTENT)
|
||||
|
||||
// Declare all FP8 parameter sets for different pipeline types and persistence options
|
||||
DECLARE_F8_PARAMS_ALL_LAYOUTS_ALL_SIZES(Mem, NonPersistent)
|
||||
DECLARE_F8_PARAMS_ALL_LAYOUTS_ALL_SIZES(CompV3, NonPersistent)
|
||||
DECLARE_F8_PARAMS_ALL_LAYOUTS_ALL_SIZES(CompV4, NonPersistent)
|
||||
|
||||
// Here, we have a combination of parameter set symbols that we can use to compile into test cases
|
||||
// __________________________________________________
|
||||
// | Parameter Name |
|
||||
// using F8_RRR_Mem_128x128x32_2x2x1_32x32x16_NonPersistent = ...
|
||||
// / | \ \ \ \ \
|
||||
// DATA LAYOUT PIPELINE MACRO WARPS MMA PERSISTENT
|
||||
// TYPE TYPE TILE MxNxK TILE TYPE
|
||||
// MxNxK MxNxK
|
||||
//
|
||||
// The options for each field are:
|
||||
// - DATA TYPE: F8
|
||||
// - LAYOUT: RRR, RRC, RCR, RCC, CRR, CRC, CCR, CCC
|
||||
// - PIPELINE_TYPE: Mem, CompV3, CompV4
|
||||
// - Macro Tile: 128x128x32, 128x128x64, 128x128x128, 256x128x32, 256x128x64, 128x256x32, 128x256x64, 256x256x32, 256x256x64
|
||||
// - Warps: 2x2x1
|
||||
// - MMA Tile: 32x32x16
|
||||
// - PERSISTENT_TYPE: NonPersistent
|
||||
|
||||
// clang-format on
|
||||
Reference in New Issue
Block a user