[CK_TILE] Stream-K Tile Engine Test Config File Generation (#3662)

* Stream-K smoke test config file generation

This change converts the stream-k smoke tests to use tile engine. Since
the m, n, and k values depend on the CU count of a device, the
configs are generated during the Configuration Phase.

* Compute GEMM reference on GPU

* Remove redundant Stream-K tests

Removing redundant tests that are now run via tile engine.

* Fix relative and absolute tolerance calculation

This change updates the Stream-K tile engine interface to ensure that
num_wgs_per_tile is propagated and passed into the compare_results
function to calculate the rel and abs tolerances. Before, split-k was
used, which is incorrect for Stream-K since the split-k value is
always 1.

* Cleanup imports, types, and other misc items

This commit makes the following changes:
- Uses Typing module for nested type hints
- Uses quotes around cu_count_arg argument in generate_configs.cmake in
  if statements
- Adds explicit include for tuple in test_gemm_streamk_simple.cpp
- Adds a type for the tiles argument in argparser to check argument
  validity

* Use CU count as return value for better parsing

* Add reduction tests for bf16, fp8, and bf8
This commit is contained in:
Emily Martins
2026-02-03 09:12:15 -07:00
committed by GitHub
parent 3f04d27b68
commit 8cbd09c84a
22 changed files with 522 additions and 406 deletions

View File

@@ -23,19 +23,6 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
#TODO: support all arches
#TODO: current c-shuffle only supports C layout as R
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
add_gtest_executable(test_ck_tile_streamk_reduction
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
test_gemm_streamk_util.cpp)
add_gtest_executable(test_ck_tile_streamk_smoke
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp8_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf8_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp8_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf8_nonpersistent.cpp
test_gemm_streamk_util.cpp)
add_gtest_executable(test_ck_tile_streamk_extended
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent.cpp
@@ -46,7 +33,6 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp
test_gemm_streamk_util.cpp)
target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping test_ck_tile_streamk unit tests for current target")

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
// Empty fixture subclass: adds no members to TestCkTileStreamK<Tuple>; it only provides a
// distinct suite name for the typed tests instantiated below.
template <typename Tuple>
class TestCkTileStreamKBf16NonPersistent : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is defined only around the .inc include and removed again afterwards;
// presumably the included cases expand it as their suite name (confirm against the .inc file).
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistent
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistent, KernelTypesStreamKBf16NonPersistent);
#include "test_gemm_streamk_smoke_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
// Empty fixture subclass: adds no members to TestCkTileStreamK<Tuple>; it only provides a
// distinct suite name for the typed tests instantiated below.
template <typename Tuple>
class TestCkTileStreamKBf16Persistent : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is defined only around the .inc include and removed again afterwards;
// presumably the included cases expand it as their suite name (confirm against the .inc file).
#define TEST_SUITE_NAME TestCkTileStreamKBf16Persistent
TYPED_TEST_SUITE(TestCkTileStreamKBf16Persistent, KernelTypesStreamKBf16Persistent);
#include "test_gemm_streamk_smoke_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
// Empty fixture subclass: adds no members to TestCkTileStreamK<Tuple>; it only provides a
// distinct suite name for the typed tests instantiated below.
template <typename Tuple>
class TestCkTileStreamKBf8NonPersistent : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is defined only around the .inc include and removed again afterwards;
// presumably the included cases expand it as their suite name (confirm against the .inc file).
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistent
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistent, KernelTypesStreamKBf8NonPersistent);
#include "test_gemm_streamk_smoke_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
// Empty fixture subclass: adds no members to TestCkTileStreamK<Tuple>; it only provides a
// distinct suite name for the typed tests instantiated below.
template <typename Tuple>
class TestCkTileStreamKBf8Persistent : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is defined only around the .inc include and removed again afterwards;
// presumably the included cases expand it as their suite name (confirm against the .inc file).
#define TEST_SUITE_NAME TestCkTileStreamKBf8Persistent
TYPED_TEST_SUITE(TestCkTileStreamKBf8Persistent, KernelTypesStreamKBf8Persistent);
#include "test_gemm_streamk_smoke_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
// Empty fixture subclass: adds no members to TestCkTileStreamK<Tuple>; it only provides a
// distinct suite name for the typed tests instantiated below.
template <typename Tuple>
class TestCkTileStreamKFp16NonPersistent : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is defined only around the .inc include and removed again afterwards;
// presumably the included cases expand it as their suite name (confirm against the .inc file).
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistent
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistent, KernelTypesStreamKFp16NonPersistent);
#include "test_gemm_streamk_smoke_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
// Empty fixture subclass: adds no members to TestCkTileStreamK<Tuple>; it only provides a
// distinct suite name for the typed tests instantiated below.
template <typename Tuple>
class TestCkTileStreamKFp16Persistent : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is defined only around the .inc include and removed again afterwards;
// presumably the included cases expand it as their suite name (confirm against the .inc file).
#define TEST_SUITE_NAME TestCkTileStreamKFp16Persistent
TYPED_TEST_SUITE(TestCkTileStreamKFp16Persistent, KernelTypesStreamKFp16Persistent);
#include "test_gemm_streamk_smoke_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
// Empty fixture subclass: adds no members to TestCkTileStreamK<Tuple>; it only provides a
// distinct suite name for the typed tests instantiated below.
template <typename Tuple>
class TestCkTileStreamKFp16Reduction : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is defined only around the .inc include and removed again afterwards;
// presumably the included cases expand it as their suite name (confirm against the .inc file).
// Unlike the persistent/non-persistent suites, this one includes the reduction cases.
#define TEST_SUITE_NAME TestCkTileStreamKFp16Reduction
TYPED_TEST_SUITE(TestCkTileStreamKFp16Reduction, KernelTypesStreamKFp16Reduction);
#include "test_gemm_streamk_reduction_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
// Empty fixture subclass: adds no members to TestCkTileStreamK<Tuple>; it only provides a
// distinct suite name for the typed tests instantiated below.
template <typename Tuple>
class TestCkTileStreamKFp8NonPersistent : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is defined only around the .inc include and removed again afterwards;
// presumably the included cases expand it as their suite name (confirm against the .inc file).
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistent
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistent, KernelTypesStreamKFp8NonPersistent);
#include "test_gemm_streamk_smoke_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,17 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
// Empty fixture subclass: adds no members to TestCkTileStreamK<Tuple>; it only provides a
// distinct suite name for the typed tests instantiated below.
template <typename Tuple>
class TestCkTileStreamKFp8Persistent : public TestCkTileStreamK<Tuple>
{
};
// TEST_SUITE_NAME is defined only around the .inc include and removed again afterwards;
// presumably the included cases expand it as their suite name (confirm against the .inc file).
#define TEST_SUITE_NAME TestCkTileStreamKFp8Persistent
TYPED_TEST_SUITE(TestCkTileStreamKFp8Persistent, KernelTypesStreamKFp8Persistent);
#include "test_gemm_streamk_smoke_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -1,88 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile_Tree)
{
    // Problem: M x N is exactly one tile (tuple elements 7 and 8), and K is the K tile size
    // (tuple element 9) multiplied by the queried CU count. Run with TreeReduction.
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;
    const ck_tile::index_t cu_count   = get_cu_count();

    ck_tile::index_t M = tile_m;
    ck_tile::index_t N = tile_n;
    ck_tile::index_t K = tile_k * cu_count;

    this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile)
{
    // Problem: M x N is exactly one tile (tuple elements 7 and 8), and K is the K tile size
    // (tuple element 9) multiplied by the queried CU count. Run with the Reduction strategy.
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;
    const ck_tile::index_t cu_count   = get_cu_count();

    ck_tile::index_t M = tile_m;
    ck_tile::index_t N = tile_n;
    ck_tile::index_t K = tile_k * cu_count;

    this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Tree)
{
    // Problem: four tiles along M, one along N; K is (CU count + 25) K-tiles, written as the
    // same sum of products as before. Run with TreeReduction.
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;
    const ck_tile::index_t cu_count   = get_cu_count();

    ck_tile::index_t M = tile_m * 4;
    ck_tile::index_t N = tile_n;
    ck_tile::index_t K = tile_k * cu_count + (25 * tile_k);

    this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction)
{
    // Problem: four tiles along M, one along N; K is (CU count + 25) K-tiles, written as the
    // same sum of products as before. Run with the Reduction strategy.
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;
    const ck_tile::index_t cu_count   = get_cu_count();

    ck_tile::index_t M = tile_m * 4;
    ck_tile::index_t N = tile_n;
    ck_tile::index_t K = tile_k * cu_count + (25 * tile_k);

    this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles_Tree)
{
    // Problem: 3 tiles along M and 7 along N (21 tiles in total); K is (CU count + 30)
    // K-tiles, written as the same sum of products as before. Run with TreeReduction.
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;
    const ck_tile::index_t cu_count   = get_cu_count();

    ck_tile::index_t M = tile_m * 3;
    ck_tile::index_t N = tile_n * 7;
    ck_tile::index_t K = tile_k * cu_count + (30 * tile_k);

    this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles)
{
    // Problem: 3 tiles along M and 7 along N (21 tiles in total); K is (CU count + 30)
    // K-tiles, written as the same sum of products as before. Run with the Reduction strategy.
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;
    const ck_tile::index_t cu_count   = get_cu_count();

    ck_tile::index_t M = tile_m * 3;
    ck_tile::index_t N = tile_n * 7;
    ck_tile::index_t K = tile_k * cu_count + (30 * tile_k);

    this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
}

View File

@@ -1,47 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
TYPED_TEST(TEST_SUITE_NAME, StreamK_EdgeCase)
{
    // Fixed 256 x 256 x 256 problem, independent of the tile sizes and of the CU count.
    constexpr ck_tile::index_t extent = 256;

    ck_tile::index_t M = extent;
    ck_tile::index_t N = extent;
    ck_tile::index_t K = extent;

    this->Run(M, N, K);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_DPOnly)
{
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;
    const ck_tile::index_t cu_count   = get_cu_count();

    // DP only: the tile count (cu_count tiles along M, one along N) is a multiple of the
    // number of CUs, and K is a single K tile. This assumes tile sizes are large enough
    // such that occupancy is 1.
    ck_tile::index_t M = tile_m * cu_count;
    ck_tile::index_t N = tile_n;
    ck_tile::index_t K = tile_k;

    this->Run(M, N, K);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly)
{
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;
    const ck_tile::index_t cu_count   = get_cu_count();

    // SK only: C has 2 x 2 = 4 macro tiles. K always spans at least cu_count macro tiles so
    // that there is enough work along K to avoid falling into the edge case. This assumes
    // tile sizes are large enough such that occupancy is 1.
    ck_tile::index_t M = tile_m * 2;
    ck_tile::index_t N = tile_n * 2;
    ck_tile::index_t K = tile_k * cu_count;

    this->Run(M, N, K);
}

View File

@@ -33,14 +33,6 @@ using KernelTypesStreamKFp16Persistent = ::testing::Types<
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
>;
using KernelTypesStreamKFp16Reduction = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>>;
using KernelTypesStreamKBf16Persistent = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,

View File

@@ -1,6 +1,8 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
include(generate_configs.cmake)
# ============================================================================
# GEMM Tile Engine Unit Tests
#
@@ -87,7 +89,7 @@ function(create_individual_gemm_test_target datatype layout config_name trait ti
target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8)
endif()
message(STATUS " Created test target: ${target_name}")
message(DEBUG " Created test target: ${target_name}")
endfunction()
# ============================================================================
@@ -101,12 +103,12 @@ endfunction()
# layout - Matrix layout (rcr, rrr, ccr, crr)
# config_name - Configuration file name without .json extension
# ============================================================================
function(build_gemm_test_targets datatype layout config_name)
function(build_gemm_test_targets datatype layout config_name configs_dir_path)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
# Locate and validate configuration file
set(config_filename "${config_name}.json")
set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}")
set(json_blob "${configs_dir_path}/${config_filename}")
if(NOT EXISTS ${json_blob})
message(WARNING "Test config file not found: ${json_blob}")
@@ -137,11 +139,11 @@ function(build_gemm_test_targets datatype layout config_name)
# Verify kernel list file was generated
if(NOT EXISTS ${working_path}/gemm_kernel_list.txt)
message(STATUS "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)")
message(DEBUG "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)")
return()
endif()
message(STATUS "Building tests for ${datatype}_${layout}_${config_name}")
message(DEBUG "Building tests for ${datatype}_${layout}_${config_name}")
# STEP 2a: Extract test parameters from config
set(test_params_file "${working_path}/test_params.hpp")
@@ -230,7 +232,7 @@ message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# GPU architecture filtering - only build tests for supported architectures
set(GEMM_TEST_GPU_TARGETS "")
set(DESIRED_TARGETS "gfx90a;gfx942")
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
@@ -241,7 +243,7 @@ endforeach()
# Early exit if no compatible GPU architectures are available
if(NOT GEMM_TEST_GPU_TARGETS)
message(WARNING "Skipping StreamK GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
message(WARNING "Skipping StreamK GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()
@@ -282,25 +284,35 @@ set(TEST_LAYOUTS "rcr;rrr;ccr;crr")
# Test Target Generation - Datatype-Specific Categories
# ============================================================================
# 1. SIMPLE TEST: Test for basic functionality with data types (fp16, bf16)
# These data types can use larger warp tiles due to smaller memory footprint
set(SIMPLE_TEST_CONFIG "simple_test_config")
set(SIMPLE_TEST_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${SIMPLE_TEST_CONFIG}.json")
set(SIMPLE_DATATYPES "fp16;bf16")
# 1. SMOKE TESTS: Test for basic functionality with data types (fp8, bf8, fp16, bf16)
set(SMALL_DATATYPES "fp16;bf16;fp8;bf8")
set(SIXTEEN_BIT_DATATYPES "fp16;bf16")
set(EIGHT_BIT_DATATYPES "fp8;bf8")
set(LARGE_TILES "256,256,32")
set(SMALL_TILES "128,128,32")
set(CONFIG_LIST "")
set(GENERATED_CONFIG_PATH ${CMAKE_CURRENT_BINARY_DIR}/configs)
get_cu_count(CU_COUNT)
if(EXISTS ${SIMPLE_TEST_CONFIG_FILE})
message(STATUS "Processing simple test config: ${SIMPLE_TEST_CONFIG} (fp16, bf16)")
foreach(datatype IN LISTS SIMPLE_DATATYPES)
# fp16, bf16: testing all layouts (rcr, rrr, ccr, crr)
message(STATUS "Generating and processing configs for Stream-K tests")
foreach(datatype IN LISTS SMALL_DATATYPES)
if(datatype IN_LIST SIXTEEN_BIT_DATATYPES)
generate_test_configs(${CU_COUNT} ${LARGE_TILES} ${datatype} CONFIG_LIST ${GENERATED_CONFIG_PATH})
else()
generate_test_configs(${CU_COUNT} ${SMALL_TILES} ${datatype} CONFIG_LIST ${GENERATED_CONFIG_PATH})
endif()
foreach(config IN LISTS CONFIG_LIST)
# testing all layouts (rcr, rrr, ccr, crr)
foreach(layout IN LISTS TEST_LAYOUTS)
build_gemm_test_targets("${datatype}" "${layout}" "${SIMPLE_TEST_CONFIG}")
build_gemm_test_targets("${datatype}" "${layout}" "${config}" "${GENERATED_CONFIG_PATH}")
endforeach()
endforeach()
else()
message(WARNING "Simple test config file not found: ${SIMPLE_TEST_CONFIG_FILE}")
endif()
endforeach()
# ============================================================================
message(STATUS "StreamK GEMM tile engine tests configured with datatype-specific design:")
message(STATUS " - Simple test: fp16/bf16 (all layouts)")
message(STATUS " - Smoke tests: fp16/bf16/fp8/bf8 (all layouts)")

View File

@@ -34,17 +34,25 @@ Each test configuration can specify optimized problem sizes in its JSON file:
The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure.
## Test Configurations
Test configs are generated during the Generation Phase. They are stored under the build directory at test/ck_tile/gemm_streamk_tile_engine/configs. The Compute Unit (CU) count of the device is required to generate the configs. If the Generation Phase occurs on a machine without a GPU, or on a machine whose GPU architecture differs from the one on which you will run the tests, you can manually set the CU count using the `CU_COUNT` option:
```bash
# Assuming you are at the root of the repo
cd build
../script/cmake-ck-dev.sh .. gfx90a -G Ninja -DCU_COUNT=100
```
You can reference the public whitepaper for your specific GPU to get the appropriate CU count.
If no `CU_COUNT` option is given and no HIP device is found, then the default value of 100 CUs will be used to determine the problem sizes tested.
### 1. **Simple Test** (`simple_test_config.json`)
- **Purpose**: Basic functionality validation for fp16/bf16 data types
- **Config**: 128x128x32, warp 2x2x1, warp_tile 32x32x16
### 1. **Smoke Tests**
- **Purpose**: Basic functionality validation for fp16/bf16/fp8/bf8 data types
- **Config**: 256x256x32 (for bf16/fp16) or 128x128x32 (for bf8/fp8), warp 2x2x1, warp_tile 32x32x16
- **Traits**: compv3 pipeline only
- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr) for fp16, bf16
- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr)
## Data Type Support
-**fp16, bf16**: Fully supported - all layouts (rcr, rrr, ccr, crr)
-**fp16, bf16, fp8, bf8**: Fully supported - all layouts (rcr, rrr, ccr, crr)
-**fp64**: Not supported (hardware MFMA limitation)
-**fp32, bf8, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later)
-**fp32, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later)
## Test Result Behavior

View File

@@ -1,35 +0,0 @@
{
"problem": {
"description": "Basic functionality validation with moderate problem sizes"
},
"test_params": {
"problem_sizes": [
{"m": 256, "n": 256, "k": 128, "split_k": 1},
{"m": 512, "n": 256, "k": 256, "split_k": 1},
{"m": 256, "n": 512, "k": 256, "split_k": 1}
]
},
"tile_config": {
"tile_m": {"values": [128]},
"tile_n": {"values": [128]},
"tile_k": {"values": [64]},
"warp_m": {"values": [2]},
"warp_n": {"values": [2]},
"warp_k": {"values": [1]},
"warp_tile_m": {"values": [16]},
"warp_tile_n": {"values": [16]},
"warp_tile_k": {"values": [16]}
},
"trait_config": {
"pipeline": {"values": ["compv3"]},
"epilogue": {"values": ["default"]},
"scheduler": {"values": ["intrawave"]},
"pad_m": {"values": [false]},
"pad_n": {"values": [false]},
"pad_k": {"values": [false]},
"persistent": {"values": [false, true]},
"reduction_strategy": {"values": ["atomic"]}
},
"k_block_per_cu": 1,
"permute_n": false
}

View File

@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <iostream>
/**
 * @brief Checks a HIP status code and logs its description when it signals a failure.
 * @param error_status Status returned by a HIP runtime call.
 * @return true if `error_status` is anything other than `hipSuccess` (its error string is
 * written to `std::cerr` first), otherwise false.
 */
bool has_error(const hipError_t& error_status)
{
    const bool failed = (error_status != hipSuccess);
    if(failed)
    {
        std::cerr << hipGetErrorString(error_status);
    }
    return failed;
}
/**
* @brief Returns the number of Compute Units (CUs) on the given device.
* @return The number of CUs on the device. If an error occurs while querying the device, zero is
* returned.
*/
int get_cu_count()
{
hipDevice_t dev;
hipDeviceProp_t dev_prop;
const hipError_t device_status = hipGetDevice(&dev);
if(has_error(device_status))
return 0;
const hipError_t prop_status = hipGetDeviceProperties(&dev_prop, dev);
if(has_error(prop_status))
return 0;
return dev_prop.multiProcessorCount;
}
int main() { return get_cu_count(); }

View File

@@ -0,0 +1,103 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
set(CU_COUNT 0 CACHE STRING "Number of Compute Units on the device")
# ============================================================================
# get_cu_count
#
# Resolves the CU count used for Stream-K test config generation. If the
# variable named by cu_count_arg already holds a positive integer, nothing
# happens. If it holds 0, cu_count.cpp is compiled and run to query the CU
# count from the device. If the query is unsuccessful, the default value of
# 100 is used. The result is written to the variable named by cu_count_arg in
# the caller's scope.
#
# Parameters:
#   cu_count_arg - Name of the variable holding the starting CU count
# ============================================================================
function(get_cu_count cu_count_arg)
    message(STATUS "Starting query for CU count needed for Stream-K test config generation")

    set(requested_cu_count "${${cu_count_arg}}")
    if(NOT "${requested_cu_count}" MATCHES "^[0-9]+$")
        message(FATAL_ERROR "The CU count must be a non-negative integer. "
                            "The given value of ${requested_cu_count} is invalid.")
    endif()

    # Zero (in any spelling, e.g. "00") means "not provided": query the device.
    if(requested_cu_count EQUAL 0)
        set(CPP_FILE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cu_count.cpp)
        set(CPP_EXE_PATH ${CMAKE_CURRENT_BINARY_DIR}/cu_count)

        execute_process(
            COMMAND ${CMAKE_HIP_COMPILER} -x hip ${CPP_FILE_PATH} -o ${CPP_EXE_PATH}
            RESULT_VARIABLE compile_result
        )

        if(NOT compile_result EQUAL 0)
            message(FATAL_ERROR "Compilation of ${CPP_FILE_PATH} failed.\n")
        endif()

        execute_process(
            COMMAND ${CPP_EXE_PATH}
            OUTPUT_VARIABLE standard_output
            OUTPUT_STRIP_TRAILING_WHITESPACE
            ERROR_VARIABLE standard_error
            RESULT_VARIABLE exit_status
        )

        if(NOT "${standard_error}" STREQUAL "")
            message(STATUS "Error information from attempting to query HIP device and properties:\n"
                           "${standard_error}")
        endif()

        # Delete the generated cu_count executable
        file(REMOVE "${CPP_EXE_PATH}")

        # Prefer a count printed on stdout, if the helper emits one: a process exit
        # status only keeps its low 8 bits on POSIX systems, so counts of 256 or more
        # cannot be reported reliably that way. Otherwise fall back to the exit status.
        set(queried_cu_count "${standard_output}")
        if(NOT "${queried_cu_count}" MATCHES "^[1-9][0-9]*$")
            set(queried_cu_count "${exit_status}")
        endif()

        # RESULT_VARIABLE may also hold a non-numeric failure description (e.g. when the
        # executable could not be started); only accept a positive integer.
        if("${queried_cu_count}" MATCHES "^[1-9][0-9]*$")
            set(${cu_count_arg} ${queried_cu_count} PARENT_SCOPE)
        else()
            message(WARNING "Unable to query the number of Compute Units. "
                            "Please use the CU_COUNT CLI option to pass in the "
                            "number of Compute Units for your target device; otherwise, "
                            "the default value of 100 will be used.")
            set(${cu_count_arg} 100 PARENT_SCOPE)
        endif()
    endif()
endfunction()
# ============================================================================
# generate_test_configs
#
# Generate config json files for Stream-K tests by running generate_configs.py.
# The script's standard output (stripped of trailing whitespace) is stored in
# the variable named by config_list in the caller's scope. Configuration stops
# with a fatal error if the script returns a non-zero result.
#
# Parameters:
#   cu_count_arg - The number of CUs on the device
#   tile_sizes   - A list of block tile sizes: tile_m,tile_n,tile_k
#   datatype     - The datatype for which the config is being generated
#   config_list  - The variable to which the list of config file names are written
#   configs_path - Path to the configs directory to which config files are written
# ============================================================================
function(generate_test_configs cu_count_arg tile_sizes datatype config_list configs_path)
    message(STATUS "Generating Stream-K test config files for ${datatype}")

    file(MAKE_DIRECTORY ${configs_path})

    set(generator_script ${CMAKE_CURRENT_SOURCE_DIR}/generate_configs.py)
    # The output is captured in a function-local name that cannot be confused with the
    # caller's result variable.
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -u ${generator_script}
                --cu_count ${cu_count_arg}
                --configs_dir_path ${configs_path}
                --tiles ${tile_sizes}
                --datatype ${datatype}
        OUTPUT_VARIABLE generated_config_names
        OUTPUT_STRIP_TRAILING_WHITESPACE
        RESULT_VARIABLE script_ret_val
    )

    if(NOT script_ret_val EQUAL 0)
        message(FATAL_ERROR "Error occurred during execution of ${generator_script} (result: ${script_ret_val})")
    endif()

    set(${config_list} "${generated_config_names}" PARENT_SCOPE)
endfunction()

View File

@@ -0,0 +1,277 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
from enum import Enum
from typing import Dict, Tuple, List
import argparse
import json
import os
import sys
from dataclasses import dataclass, field, asdict
@dataclass
class TileConfig:
    """Tile Config section of a Tile Engine config.

    Every field is the list of values for that parameter. The block tile sizes
    default to empty lists, the warp counts to 2, 2 and 1, and the warp tile
    sizes to 32, 32 and 16.
    """

    tile_m: List[int] = field(default_factory=list)
    tile_n: List[int] = field(default_factory=list)
    tile_k: List[int] = field(default_factory=list)
    warp_m: List[int] = field(default_factory=lambda: [2])
    warp_n: List[int] = field(default_factory=lambda: [2])
    warp_k: List[int] = field(default_factory=lambda: [1])
    warp_tile_m: List[int] = field(default_factory=lambda: [32])
    warp_tile_n: List[int] = field(default_factory=lambda: [32])
    warp_tile_k: List[int] = field(default_factory=lambda: [16])

    def to_dict(self) -> Dict:
        """Returns a dict mapping each field name to ``{"values": <field value>}``."""
        wrapped = {}
        for name, values in asdict(self).items():
            wrapped[name] = {"values": values}
        return wrapped
@dataclass
class TraitConfig:
    """Trait Config section of a Tile Engine config.

    Every field is the list of values for that trait. ``reduction_strategy``
    defaults to an empty list; all other fields carry the defaults shown below.
    """

    pipeline: List[str] = field(default_factory=lambda: ["compv3"])
    epilogue: List[str] = field(default_factory=lambda: ["cshuffle"])
    scheduler: List[str] = field(default_factory=lambda: ["intrawave"])
    pad_m: List[bool] = field(default_factory=lambda: [False])
    pad_n: List[bool] = field(default_factory=lambda: [False])
    pad_k: List[bool] = field(default_factory=lambda: [False])
    persistent: List[bool] = field(default_factory=lambda: [True, False])
    reduction_strategy: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict:
        """Returns a dict mapping each field name to ``{"values": <field value>}``."""
        wrapped = {}
        for name, values in asdict(self).items():
            wrapped[name] = {"values": values}
        return wrapped
class TestVariant(Enum):
    """A Stream-K test variant.

    Each member is declared as the tuple
    ``(value, reduction_strategy, persistent, datatypes, description)``.
    ``__init__`` unpacks that tuple, keeps only the first element as the enum
    value and stores the remaining elements as attributes of the member.
    """

    def __init__(
        self,
        val: int,
        reduction_strategy: List[str],
        persistent: List[bool],
        datatypes: List[str],
        description: str,
    ):
        # Replace the default tuple value with the integer id.
        self._value_ = val
        self.description = description
        self.datatypes = datatypes
        self.persistent = persistent
        self.reduction_strategy = reduction_strategy

    ATOMIC_SMOKE = (
        0,
        ["atomic"],
        [True, False],
        ["fp16", "bf16", "fp8", "bf8"],
        "Stream-K atomic smoke tests",
    )

    REDUCTION_SMOKE = (
        2,
        ["reduction", "tree"],
        [True, False],
        ["fp16", "bf16", "fp8", "bf8"],
        "Stream-K reduction smoke tests",
    )

    EXTENDED = (
        3,
        ["atomic"],
        [True, False],
        ["fp16", "bf16", "fp8", "bf8"],
        "Stream-K extended smoke tests",
    )

    def apply(self, trait_config: TraitConfig) -> None:
        """Overwrites trait_config's reduction strategy and persistent settings with this variant's"""
        trait_config.reduction_strategy = self.reduction_strategy
        trait_config.persistent = self.persistent
@dataclass
class ProblemSize:
    """A single problem size entry of a Tile Engine config.

    ``variant`` records which test variant the problem belongs to; it is not
    part of the serialized form produced by ``to_dict``.
    """

    m: int
    n: int
    k: int
    variant: TestVariant
    split_k: int = 1

    def to_dict(self) -> Dict:
        """Returns the ``m``, ``n``, ``k`` and ``split_k`` values keyed by name."""
        serialized_fields = ("m", "n", "k", "split_k")
        return {name: getattr(self, name) for name in serialized_fields}
@dataclass
class Config:
    """Represents a Tile Engine config.

    ``to_dict`` serializes the config into the ``problem``, ``test_params``,
    ``tile_config``, ``trait_config``, ``k_block_per_cu`` and ``permute_n``
    keys; ``write_to_file`` dumps that dict as JSON.
    """

    description: str
    # typing.List rather than the builtin generic: consistent with the other
    # annotations in this file and valid on interpreters older than Python 3.9.
    problem_sizes: List[ProblemSize] = field(default_factory=list)
    tile_config: TileConfig = field(default_factory=TileConfig)
    trait_config: TraitConfig = field(default_factory=TraitConfig)
    k_block_per_cu: int = 1
    permute_n: bool = False

    def add_problem_size(self, problem: ProblemSize) -> None:
        """Adds the given problem to this config's problem_sizes"""
        self.problem_sizes.append(problem)

    def to_dict(self) -> Dict:
        """Returns this config as a JSON-serializable dict"""
        config_dict = {
            "problem": {"description": f"{self.description}"},
            "test_params": {
                "problem_sizes": [ps.to_dict() for ps in self.problem_sizes]
            },
            "tile_config": self.tile_config.to_dict(),
            "trait_config": self.trait_config.to_dict(),
            "k_block_per_cu": self.k_block_per_cu,
            "permute_n": self.permute_n,
        }
        return config_dict

    def write_to_file(self, output_file: str) -> None:
        """Writes this config to the given output_file as indented JSON followed by a newline"""
        with open(output_file, "w") as config_file:
            json.dump(self.to_dict(), config_file, indent=4)
            config_file.write("\n")
def create_problem_sizes(
    tile_m: int, tile_n: int, tile_k: int, cu_count: int
) -> List[ProblemSize]:
    """Builds the list of Stream-K problem sizes for the given tile sizes and CU count.

    Three problems are tagged ATOMIC_SMOKE and three REDUCTION_SMOKE; apart from
    the fixed 256x256x256 problem, every dimension is derived from the tile
    sizes and ``cu_count``.
    """
    atomic = TestVariant.ATOMIC_SMOKE
    reduction = TestVariant.REDUCTION_SMOKE
    # K extent covering one K tile per CU; shared by several problems below.
    k_all_cus = cu_count * tile_k

    atomic_problems = [
        # Fixed-size problem
        ProblemSize(256, 256, 256, atomic),
        # cu_count tiles along M, a single tile along N and K
        ProblemSize(tile_m * cu_count, tile_n, tile_k, atomic),
        # 2 x 2 tiles in C, cu_count tiles along K
        ProblemSize(tile_m * 2, tile_n * 2, k_all_cus, atomic),
    ]
    reduction_problems = [
        # One tile in C, cu_count tiles along K
        ProblemSize(tile_m, tile_n, k_all_cus, reduction),
        # 4 tiles in C, cu_count + 25 tiles along K
        ProblemSize(tile_m * 4, tile_n, k_all_cus + (25 * tile_k), reduction),
        # 21 tiles in C, cu_count + 30 tiles along K
        ProblemSize(tile_m * 3, tile_n * 7, k_all_cus + (30 * tile_k), reduction),
    ]
    # TODO: Add this test once we determine how to label tests as regression with tile engine
    # ProblemSize((tile_m * cu_count * 2) + (tile_m * 2), tile_n, 2048, TestVariant.EXTENDED)
    return atomic_problems + reduction_problems
def write_config_files(
    problem_sizes: List[ProblemSize],
    configs_dir_path: str,
    datatype: str,
    tile_sizes: Tuple[int, int, int],
) -> List[str]:
    """Writes one config file per applicable test variant and returns the config names.

    A variant is skipped when ``datatype`` is not among its datatypes or when
    none of ``problem_sizes`` belongs to it. Each remaining variant produces
    ``<configs_dir_path>/<config name>.json``; the returned names carry no
    ``.json`` extension.
    """
    config_names = []
    tile_m, tile_n, tile_k = tile_sizes
    tile_config = TileConfig([tile_m], [tile_n], [tile_k])

    # Create a config for each test variant
    for variant in TestVariant:
        problem_sizes_filtered = [ps for ps in problem_sizes if ps.variant == variant]

        if (datatype not in variant.datatypes) or len(problem_sizes_filtered) == 0:
            continue

        trait_config = TraitConfig()
        variant.apply(trait_config)

        config_name = f"streamk_{variant.name.lower()}_tests_config_{datatype}"
        config_names.append(config_name)
        file_path = os.path.join(configs_dir_path, config_name + ".json")

        config = Config(
            variant.description, problem_sizes_filtered, tile_config, trait_config
        )
        config.write_to_file(file_path)

    return config_names
def print_config_names(config_file_names: List[str]) -> None:
    """Writes the given names to stdout as one line, joined by semi-colons"""
    joined_names = ";".join(config_file_names)
    print(joined_names)
def create_config_files(
    cu_count: int,
    configs_dir_path: str,
    tile_sizes: Tuple[int, int, int],
    datatype: str,
) -> None:
    """Create the Stream-K test config files and print their names.

    Args:
        cu_count: Number of compute units, used to derive the problem sizes.
        configs_dir_path: Directory the config files are written into.
        tile_sizes: Block tile sizes as ``(m, n, k)``.
        datatype: Datatype the configs are generated for.

    The names of the written configs are printed to stdout as a single
    semicolon-separated line.
    """
    tile_m, tile_n, tile_k = tile_sizes
    problem_sizes = create_problem_sizes(tile_m, tile_n, tile_k, cu_count)
    config_names = write_config_files(
        problem_sizes, configs_dir_path, datatype, tile_sizes
    )
    print_config_names(config_names)
def get_args() -> Tuple[int, str, Tuple[int, int, int], str]:
    """Parse the command line and return the user-provided arguments.

    Returns:
        A tuple ``(cu_count, configs_dir_path, tile_sizes, datatype)`` where
        ``tile_sizes`` is the ``(m, n, k)`` tuple parsed from ``--tiles``.

    Invalid values are reported through argparse, which prints a usage error
    and exits. This covers a non-integer ``--cu_count``, a malformed
    ``--tiles`` value, and an unknown ``--datatype``.
    """

    def tile_sizes_type(val: str) -> Tuple[int, int, int]:
        """Convert a ``"m,n,k"`` string into a tuple of three integers."""
        parts = val.split(",")
        if len(parts) != 3:
            raise argparse.ArgumentTypeError(
                "--tiles must contain exactly three comma-separated values (m,n,k), e.g. --tiles 256,256,32"
            )
        try:
            tile_m, tile_n, tile_k = (int(size) for size in parts)
        except ValueError:
            # The argparse message is self-contained; drop the chained ValueError
            raise argparse.ArgumentTypeError(
                "--tiles must contain exactly three comma-separated integers (m,n,k), e.g. --tiles 256,256,32"
            ) from None
        return (tile_m, tile_n, tile_k)

    parser = argparse.ArgumentParser(description="Create Stream-K test configs")
    parser.add_argument(
        "--cu_count",
        required=True,
        # Let argparse validate the value so bad input yields a usage error
        # rather than an uncaught ValueError traceback
        type=int,
        help="Number of Compute Units on the device",
    )
    parser.add_argument(
        "--configs_dir_path",
        required=True,
        help="Full path configs directory where config files will be written to",
    )
    parser.add_argument(
        "--tiles",
        required=True,
        type=tile_sizes_type,
        help="Block tile sizes for m, n, and k, respectively. Ex: --tiles 256,256,32",
    )
    parser.add_argument(
        "--datatype",
        choices=["fp16", "bf16", "fp8", "bf8"],
        required=True,
        help="The datatype for which the config is generated.",
    )

    args = parser.parse_args()
    return (args.cu_count, args.configs_dir_path, args.tiles, args.datatype)
def main():
    """Script entry point: parse the CLI arguments, generate the configs, exit with status 0."""
    # get_args() returns its values in the same order create_config_files() takes them
    create_config_files(*get_args())
    sys.exit(0)
# Run the generator only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

View File

@@ -12,6 +12,7 @@
#include <gtest/gtest.h>
#include <iostream>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
@@ -126,13 +127,18 @@ class StreamKGemmTileEngineTest : public ::testing::TestWithParam<GemmTestParams
TEST_P(StreamKGemmTileEngineTest, BasicFunctionality)
{
// Check that kernel information is available
EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty";
std::cout << "Testing kernel: " << KERNEL_NAME << std::endl;
std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << std::endl;
// Get tensor layouts from generated kernel
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
const CLayout layout_c = CLayout{};
// Use split_k from test parameters
int split_k = split_k_;
// Calculate tensor strides
int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a));
int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b));
int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c));
@@ -144,27 +150,42 @@ TEST_P(StreamKGemmTileEngineTest, BasicFunctionality)
ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
ck_tile::HostTensor<CDataType> c_m_n_host_result(
ck_tile::HostTensor<CDataType> c_m_n_dev_ref(
ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c)));
// Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine)
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
c_m_n_dev_ref.SetZero();
// Allocate GPU device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
ck_tile::DeviceMem ref_c_m_n_dev_buf(c_m_n_dev_ref.get_element_space_size_in_bytes());
// Copy data to device and zero output buffer
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ref_c_m_n_dev_buf.SetZero();
// Calculate reference result on host for verification
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_result);
// Calculate reference result on device for verification
ADataType* a_m_k_dev_ref_ptr = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* b_k_n_dev_ref_ptr = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* c_m_n_dev_ref_ptr = static_cast<CDataType*>(ref_c_m_n_dev_buf.GetDeviceBuffer());
ck_tile::
reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
a_m_k_dev_ref_ptr,
b_k_n_dev_ref_ptr,
c_m_n_dev_ref_ptr,
m_,
n_,
k_,
stride_a_calc,
stride_b_calc,
stride_c_calc);
ref_c_m_n_dev_buf.FromDevice(c_m_n_dev_ref.data());
// Create GEMM kernel arguments
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
@@ -188,9 +209,10 @@ TEST_P(StreamKGemmTileEngineTest, BasicFunctionality)
1}; // rotating_count
// Launch the generated kernel (no timing overhead for fastest execution)
std::tuple<float, ck_tile::index_t> launch_result;
try
{
SelectedKernel::launch(args, stream_config);
launch_result = SelectedKernel::launch(args, stream_config);
// Kernel launched successfully if no exception thrown
}
catch(const std::exception& e)
@@ -211,22 +233,13 @@ TEST_P(StreamKGemmTileEngineTest, BasicFunctionality)
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
// Verify results using tile_engine's adaptive error thresholds
const ck_tile::index_t num_wgs_per_tile = get<1>(launch_result);
bool verification_passed = compare_results<ADataType, BDataType, AccDataType, CDataType>(
KERNEL_NAME, k_, split_k, c_m_n_dev_result, c_m_n_host_result);
KERNEL_NAME, k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_dev_ref);
EXPECT_TRUE(verification_passed) << "GEMM result verification failed";
}
TEST_P(StreamKGemmTileEngineTest, KernelInfo)
{
// Simple test to verify kernel information is available
EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty";
std::cout << "Testing kernel: " << KERNEL_NAME << std::endl;
std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << " with split_k=" << split_k_
<< std::endl;
}
// Use config-specific test parameters (included via compile flags)
// CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file
INSTANTIATE_TEST_SUITE_P(GemmVerification,

View File

@@ -481,8 +481,9 @@ struct SelectedKernel {{
AccDataType,
TileShape,
GemmUniversalTraits>;
static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{
static std::tuple<float, ck_tile::index_t> launch(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& stream) {{
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
@@ -562,12 +563,16 @@ struct SelectedKernel {{
workspace_data.SetZero();
}}
}};
const ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
// Launch kernel
return ck_tile::launch_kernel_time_mask(
const float time = ck_tile::launch_kernel_time_mask(
stream,
reset_data_buffers,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return std::tuple<float, ck_tile::index_t>{{time, num_wgs_per_tile}};
}}
}};
"""

View File

@@ -22,25 +22,25 @@ class GemmProfiler
// Overload for single kernel benchmarking
void benchmark(GemmProblem& gemm_problem,
std::function<float(const ck_tile::StreamKHostArgs&,
const ck_tile::stream_config&)> kernel_func)
std::function<std::tuple<float, ck_tile::index_t>(
const ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)> kernel_func)
{
// Create a vector with a single callable that returns both name and time
std::vector<std::function<std::tuple<std::string, float>(ck_tile::StreamKHostArgs&,
const ck_tile::stream_config&)>>
// Create a vector with a single callable that returns name, time, and num_wgs_per_tile
std::vector<std::function<std::tuple<std::string, float, ck_tile::index_t>(
ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>
callables;
callables.push_back(
[kernel_func](ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {
float time = kernel_func(args, stream);
return std::make_tuple(std::string(KERNEL_NAME), time);
auto [time, num_wgs_per_tile] = kernel_func(args, stream);
return std::make_tuple(std::string(KERNEL_NAME), time, num_wgs_per_tile);
});
benchmark(gemm_problem, callables);
}
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
std::vector<std::function<std::tuple<std::string, float, ck_tile::index_t>(
ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>& callables)
{
const ALayout layout_a = ALayout{};
@@ -160,9 +160,9 @@ class GemmProfiler
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
const std::tuple<std::string, float>& kernel_run_result)
const std::tuple<std::string, float, ck_tile::index_t>& kernel_run_result)
{
auto [name, avg_time] = kernel_run_result;
auto [name, avg_time, num_wgs_per_tile] = kernel_run_result;
auto dp_persistent =
SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel";
@@ -196,8 +196,7 @@ class GemmProfiler
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool verified_correct =
!setting_.verify_ ||
compare(
name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result);
compare(name, gemm_problem.k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_host_result);
if(verified_correct)
{