mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5082 (commit 9313659)
ck_tile: add gtest unit tests for MX flatmm (gfx950)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## Summary
- Add correctness unit tests for the MX-format flatmm kernel
(`example/ck_tile/18_flatmm/mxgemm`) under `test/ck_tile/flatmm/`
- Tests cover all five dtype combinations: FP4×FP4, FP8×FP8, FP6×FP6,
FP8×FP4, FP4×FP8
- Tests cover all four kernel dispatch paths (the `has_hot_loop` ×
`tail_num` product):
- `has_hot_loop=false, tail=ODD` (K=256, num_loop=1)
- `has_hot_loop=false, tail=EVEN` (K=512, num_loop=2)
- `has_hot_loop=true, tail=ODD` (K=768, num_loop=3)
- `has_hot_loop=true, tail=EVEN` (K=1024, num_loop=4)
- Remove unsupported `-split_k` CLI option from
`tile_example_mx_flatmm`; the pre-shuffled B layout is incompatible with
K-splitting and the option silently produced wrong results
## Changes
**New files (`test/ck_tile/flatmm/`):**
- `CMakeLists.txt` — builds 40 kernel instances as a shared OBJECT
library, links into 5 per-dtype test executables; forwards
`-DCK_TILE_USE_OCP_FP8` when `CK_USE_OCP_FP8` is ON
- `test_mx_flatmm_base.hpp` — base test fixture with
`run_test_with_validation(M, N, K, kbatch=1)`
- `test_mx_flatmm_fixtures.hpp` — concrete `TestMXFlatmm` typed test
class and type aliases
- `test_mx_flatmm_fp{4fp4,8fp8,6fp6,8fp4,4fp8}.cpp` — per-dtype
`TYPED_TEST_SUITE` files
**Modified files:**
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp` — moved
`preShuffleWeight` here (was in `mx_flatmm.cpp`) so it is includeable by
both the example and the tests
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp` / `run_mx_flatmm.inc`
— removed `-split_k` CLI arg, hardcoded `k_batch=1`, fixed `k_split`
formula, updated call sites after `preShuffleWeight` move
- `test/ck_tile/CMakeLists.txt` — added `add_subdirectory(flatmm)`
This commit is contained in:
committed by
assistant-librarian[bot]
parent
2169367735
commit
1a4aa7fd89
@@ -43,7 +43,6 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
ScaleA scale_a,
|
||||
ScaleB scale_b,
|
||||
int n_warmup,
|
||||
@@ -55,7 +54,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
b_shuffle_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
1,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@@ -90,8 +89,8 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
|
||||
using BaseFlatmmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t k_grain = FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * k_grain;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split);
|
||||
const bool has_hot_loop = BaseFlatmmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseFlatmmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
@@ -100,29 +99,24 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
[&](auto has_hot_loop_, auto tail_num_) {
|
||||
constexpr auto has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_num_v = tail_num_.value;
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
return mx_flatmm_calc<MXFlatmmArchTraits,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ScaleA,
|
||||
ScaleB,
|
||||
UsePersistentKernel,
|
||||
CDEElementWise,
|
||||
split_k_.value,
|
||||
has_hot_loop_v,
|
||||
tail_num_v>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
};
|
||||
return (args.k_batch == 1) ? invoke_splitk_path(std::false_type{})
|
||||
: invoke_splitk_path(std::true_type{});
|
||||
return mx_flatmm_calc<MXFlatmmArchTraits,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ScaleA,
|
||||
ScaleB,
|
||||
UsePersistentKernel,
|
||||
CDEElementWise,
|
||||
false,
|
||||
has_hot_loop_v,
|
||||
tail_num_v>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
},
|
||||
has_hot_loop,
|
||||
tail_num);
|
||||
@@ -166,7 +160,6 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
|
||||
.insert("warp_tile", "0", "0: 16x16x128 on gfx950.");
|
||||
@@ -174,45 +167,6 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NLane, typename dtype>
|
||||
auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
auto src_lengths = src.get_lengths();
|
||||
const int K = src_lengths[0];
|
||||
const int N = src_lengths[1];
|
||||
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
|
||||
int KPack =
|
||||
std::is_same_v<dtype, ck_tile::pk_fp6x16_t> ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16
|
||||
|
||||
int KLane = ck_tile::get_warp_size() / NLane;
|
||||
int K0 = K / (KLane * KPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
|
||||
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; k += packed_size)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
int tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
shuffled(outputIndex) = src(k, n);
|
||||
}
|
||||
}
|
||||
return shuffled;
|
||||
}
|
||||
|
||||
#include "run_mx_flatmm.inc"
|
||||
|
||||
int run_mx_flatmm_example(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
@@ -70,6 +70,47 @@ struct MXFlatmmArchTraits
|
||||
|
||||
static constexpr int GetNLane() { return Config::N_Warp_Tile; }
|
||||
|
||||
template <typename dtype>
|
||||
static auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
constexpr ck_tile::index_t NLane = Config::N_Warp_Tile;
|
||||
auto src_lengths = src.get_lengths();
|
||||
const int K = src_lengths[0];
|
||||
const int N = src_lengths[1];
|
||||
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
|
||||
int KPack = std::is_same_v<dtype, ck_tile::pk_fp6x16_t>
|
||||
? 32
|
||||
: 16 * packed_size; // fp4/fp6:32 or fp8:16
|
||||
|
||||
int KLane = ck_tile::get_warp_size() / NLane;
|
||||
int K0 = K / (KLane * KPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
|
||||
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; k += packed_size)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
int tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
shuffled(outputIndex) = src(k, n);
|
||||
}
|
||||
}
|
||||
return shuffled;
|
||||
}
|
||||
|
||||
template <bool KLast, typename dtype>
|
||||
static auto preShuffleScale(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
|
||||
@@ -32,7 +32,6 @@ int run_mx_flatmm_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
ck_tile::index_t n_warmup = arg_parser.get_int("warmup");
|
||||
ck_tile::index_t n_repeat = arg_parser.get_int("repeat");
|
||||
@@ -106,7 +105,7 @@ int run_mx_flatmm_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
}
|
||||
|
||||
const auto b_shuffled_host = preShuffleWeight<MXFlatmmArchTraits::GetNLane()>(b_origin_host);
|
||||
const auto b_shuffled_host = MXFlatmmArchTraits::preShuffleWeight(b_origin_host);
|
||||
const auto scale_a_shuffled = MXFlatmmArchTraits::template preShuffleScale<true>(scale_a);
|
||||
const auto scale_b_shuffled = MXFlatmmArchTraits::template preShuffleScale<false>(scale_b);
|
||||
|
||||
@@ -151,7 +150,6 @@ int run_mx_flatmm_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
scale_a_dev_ptr,
|
||||
scale_b_dev_ptr,
|
||||
n_warmup,
|
||||
|
||||
@@ -57,6 +57,7 @@ add_subdirectory(add_rmsnorm2d_rdquant)
|
||||
# add_subdirectory(layernorm2d)
|
||||
# add_subdirectory(rmsnorm2d)
|
||||
add_subdirectory(gemm_block_scale)
|
||||
add_subdirectory(flatmm)
|
||||
add_subdirectory(gemm_mx)
|
||||
add_subdirectory(utility)
|
||||
add_subdirectory(warp_gemm)
|
||||
|
||||
79
test/ck_tile/flatmm/CMakeLists.txt
Normal file
79
test/ck_tile/flatmm/CMakeLists.txt
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(TEST_FLATMM_COMPILE_OPTIONS)
|
||||
list(APPEND TEST_FLATMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND TEST_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
set(MXGEMM_EXAMPLE_DIR ${CMAKE_SOURCE_DIR}/example/ck_tile/18_flatmm/mxgemm)
|
||||
|
||||
# Generate the 40 kernel instance .cpp files.
|
||||
# We inline the generation here (rather than calling mx_flatmm_instance_generate)
|
||||
# so that configure_file paths resolve correctly from this directory.
|
||||
set(C_DATA_TYPE FP16)
|
||||
set(A_LAYOUT ROW)
|
||||
set(B_LAYOUT COL)
|
||||
set(C_LAYOUT ROW)
|
||||
|
||||
set(FLATMM_INSTANCE_FILES)
|
||||
foreach(PERSISTENT false)
|
||||
foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8)
|
||||
string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE})
|
||||
list(GET DATA_TYPE_AB 0 A_DATA_TYPE)
|
||||
list(GET DATA_TYPE_AB 1 B_DATA_TYPE)
|
||||
set(ARCH MXFlatmm_GFX950_)
|
||||
set(MXFLATMM_ARCH_TRAITS "${ARCH}${A_DATA_TYPE}${B_DATA_TYPE}_Traits")
|
||||
foreach(SPLIT_K false)
|
||||
foreach(HAS_HOT_LOOP false true)
|
||||
foreach(TAIL_NUMBER ODD EVEN)
|
||||
set(KERNEL_FILE instance_${ARCH}${DATA_TYPE}_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
|
||||
string(TOLOWER ${KERNEL_FILE} KERNEL_FILE)
|
||||
configure_file(
|
||||
${MXGEMM_EXAMPLE_DIR}/mx_flatmm_instance.cpp.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
|
||||
@ONLY)
|
||||
list(APPEND FLATMM_INSTANCE_FILES ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
# Compile the 20 kernel instances once into an object library,
|
||||
# shared across all 5 test executables to avoid redundant GPU compilation.
|
||||
# SPLIT_K=true instances are omitted: split-K is confirmed broken at the
|
||||
# kernel level for all dtype combinations and is not tested.
|
||||
add_library(mx_flatmm_test_instances OBJECT ${FLATMM_INSTANCE_FILES})
|
||||
target_include_directories(mx_flatmm_test_instances PRIVATE
|
||||
${MXGEMM_EXAMPLE_DIR}
|
||||
)
|
||||
target_compile_options(mx_flatmm_test_instances PRIVATE ${TEST_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
foreach(DTYPE fp4fp4 fp8fp8 fp6fp6 fp8fp4 fp4fp8)
|
||||
add_gtest_executable(test_tile_mx_flatmm_${DTYPE}
|
||||
test_mx_flatmm_${DTYPE}.cpp
|
||||
)
|
||||
target_include_directories(test_tile_mx_flatmm_${DTYPE} PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${MXGEMM_EXAMPLE_DIR}
|
||||
)
|
||||
target_compile_options(test_tile_mx_flatmm_${DTYPE} PRIVATE ${TEST_FLATMM_COMPILE_OPTIONS})
|
||||
target_link_libraries(test_tile_mx_flatmm_${DTYPE} PRIVATE mx_flatmm_test_instances)
|
||||
endforeach()
|
||||
|
||||
# Umbrella target to build all flatmm tests at once
|
||||
add_custom_target(test_tile_mx_flatmm_all)
|
||||
add_dependencies(test_tile_mx_flatmm_all
|
||||
test_tile_mx_flatmm_fp4fp4
|
||||
test_tile_mx_flatmm_fp8fp8
|
||||
test_tile_mx_flatmm_fp6fp6
|
||||
test_tile_mx_flatmm_fp8fp4
|
||||
test_tile_mx_flatmm_fp4fp8
|
||||
)
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile flatmm tests for current target")
|
||||
endif()
|
||||
251
test/ck_tile/flatmm/test_mx_flatmm_base.hpp
Normal file
251
test/ck_tile/flatmm/test_mx_flatmm_base.hpp
Normal file
@@ -0,0 +1,251 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <cstring>
|
||||
#include <optional>
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/check_err.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include "mx_flatmm.hpp"
|
||||
|
||||
// Base class for MX Flatmm unit tests.
|
||||
//
|
||||
// Tuple layout: <ADataType, BDataType, CDataType, MXFlatmmArchTraits>
|
||||
template <typename Tuple>
|
||||
class TestMXFlatmmBase : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using CDataType = std::tuple_element_t<2, Tuple>;
|
||||
using MXFlatmmArchTraits = std::tuple_element_t<3, Tuple>;
|
||||
|
||||
using FlatmmConfig = typename MXFlatmmArchTraits::Config;
|
||||
using AccDataType = float;
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
|
||||
static constexpr int ScaleGranularityM = 1;
|
||||
static constexpr int ScaleGranularityN = 1;
|
||||
static constexpr int ScaleGranularityK = 32;
|
||||
|
||||
using ScaleA = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>;
|
||||
using ScaleB = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>;
|
||||
|
||||
void
|
||||
run_test_with_validation(ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t kbatch = 1,
|
||||
std::optional<bool> expected_has_hot_loop = std::nullopt,
|
||||
std::optional<ck_tile::TailNumber> expected_tail_num = std::nullopt)
|
||||
{
|
||||
constexpr int APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr int BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
ASSERT_EQ(K % ScaleGranularityK, 0) << "K must be a multiple of ScaleGranularityK (32)";
|
||||
ASSERT_EQ(K % APackedSize, 0) << "K must be a multiple of A PackedSize";
|
||||
ASSERT_EQ(K % BPackedSize, 0) << "K must be a multiple of B PackedSize";
|
||||
|
||||
constexpr bool a_row_major = true;
|
||||
constexpr bool b_row_major = false;
|
||||
constexpr bool c_row_major = true;
|
||||
|
||||
const ck_tile::index_t stride_A =
|
||||
ck_tile::get_default_stride(M, K, 0, ck_tile::bool_constant<a_row_major>{});
|
||||
const ck_tile::index_t stride_B =
|
||||
ck_tile::get_default_stride(K, N, 0, ck_tile::bool_constant<b_row_major>{});
|
||||
const ck_tile::index_t stride_C =
|
||||
ck_tile::get_default_stride(M, N, 0, ck_tile::bool_constant<c_row_major>{});
|
||||
|
||||
const auto scale_stride_A = ck_tile::get_default_stride(
|
||||
M / ScaleGranularityM, K / ScaleGranularityK, 0, ck_tile::bool_constant<a_row_major>{});
|
||||
const auto scale_stride_B = ck_tile::get_default_stride(
|
||||
K / ScaleGranularityK, N / ScaleGranularityN, 0, ck_tile::bool_constant<b_row_major>{});
|
||||
|
||||
// Host tensors
|
||||
ck_tile::HostTensor<ADataType> a_host(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, ck_tile::bool_constant<a_row_major>{}));
|
||||
ck_tile::HostTensor<BDataType> b_origin_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, ck_tile::bool_constant<b_row_major>{}));
|
||||
ck_tile::HostTensor<CDataType> c_rslt_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, ck_tile::bool_constant<c_row_major>{}));
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_a(
|
||||
ck_tile::host_tensor_descriptor(M / ScaleGranularityM,
|
||||
K / ScaleGranularityK,
|
||||
scale_stride_A,
|
||||
ck_tile::bool_constant<a_row_major>{}));
|
||||
ck_tile::HostTensor<ScaleType> scale_b(
|
||||
ck_tile::host_tensor_descriptor(K / ScaleGranularityK,
|
||||
N / ScaleGranularityN,
|
||||
scale_stride_B,
|
||||
ck_tile::bool_constant<b_row_major>{}));
|
||||
|
||||
// Initialize data
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_fp6x16_t>)
|
||||
{
|
||||
// FP6: fill raw bytes with values 1..4 (avoids denormals)
|
||||
auto a_bytes = a_host.get_element_space_size_in_bytes();
|
||||
auto b_bytes = b_origin_host.get_element_space_size_in_bytes();
|
||||
std::vector<int8_t> buf_a(a_bytes), buf_b(b_bytes);
|
||||
std::mt19937 gen(42);
|
||||
std::uniform_int_distribution<int> dis(1, 4);
|
||||
for(auto& v : buf_a)
|
||||
v = static_cast<int8_t>(dis(gen));
|
||||
for(auto& v : buf_b)
|
||||
v = static_cast<int8_t>(dis(gen));
|
||||
memcpy(a_host.data(), buf_a.data(), a_bytes);
|
||||
memcpy(b_origin_host.data(), buf_b.data(), b_bytes);
|
||||
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
|
||||
// Preshuffle B and scales
|
||||
const auto b_shuffled_host = MXFlatmmArchTraits::preShuffleWeight(b_origin_host);
|
||||
const auto scale_a_shuffled = MXFlatmmArchTraits::template preShuffleScale<true>(scale_a);
|
||||
const auto scale_b_shuffled = MXFlatmmArchTraits::template preShuffleScale<false>(scale_b);
|
||||
|
||||
// Device buffers
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes());
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_shuffled_dev_buf.ToDevice(b_shuffled_host.data());
|
||||
c_rslt_host.SetZero();
|
||||
c_dev_buf.ToDevice(c_rslt_host.data());
|
||||
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
|
||||
|
||||
auto scale_a_dev_ptr = ScaleA{static_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()),
|
||||
M / ScaleGranularityM};
|
||||
auto scale_b_dev_ptr = ScaleB{static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()),
|
||||
N / ScaleGranularityN};
|
||||
|
||||
// Build args
|
||||
ck_tile::ScaleFlatmmHostArgs<ScaleA, ScaleB> args{a_dev_buf.GetDeviceBuffer(),
|
||||
b_shuffled_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C,
|
||||
scale_a_dev_ptr,
|
||||
scale_b_dev_ptr};
|
||||
|
||||
// Compute hot_loop / tail_num
|
||||
using FlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmTraits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
using GemmPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, FlatmmShape, GemmTraits>;
|
||||
using BaseFlatmmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * k_grain;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split);
|
||||
const bool has_hot_loop = BaseFlatmmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseFlatmmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
if(expected_has_hot_loop.has_value())
|
||||
ASSERT_EQ(has_hot_loop, *expected_has_hot_loop)
|
||||
<< "has_hot_loop mismatch for (M=" << M << ", N=" << N << ", K=" << K << ")";
|
||||
if(expected_tail_num.has_value())
|
||||
ASSERT_EQ(tail_num, *expected_tail_num)
|
||||
<< "tail_num mismatch for (M=" << M << ", N=" << N << ", K=" << K << ")";
|
||||
|
||||
// Launch kernel (warmup=0, repeat=1 for correctness testing)
|
||||
// mx_flatmm_calc is explicitly instantiated in the linked object library;
|
||||
// suppress the -Wundefined-func-template warning that fires when the
|
||||
// compiler sees only the forward declaration in mx_flatmm.hpp.
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wundefined-func-template"
|
||||
BaseFlatmmPipeline::template TailHandler<true>(
|
||||
[&](auto has_hot_loop_, auto tail_num_) {
|
||||
constexpr auto has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_num_v = tail_num_.value;
|
||||
// SplitK (kbatch>1) is excluded: confirmed broken at the kernel level.
|
||||
// Always dispatch the kbatch=1 (SPLIT_K=false) path.
|
||||
mx_flatmm_calc<MXFlatmmArchTraits,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ScaleA,
|
||||
ScaleB,
|
||||
/*persistent=*/false,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
/*split_k=*/false,
|
||||
has_hot_loop_v,
|
||||
tail_num_v>(args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
},
|
||||
has_hot_loop,
|
||||
tail_num);
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
|
||||
// CPU reference
|
||||
ck_tile::HostTensor<CDataType> c_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, ck_tile::bool_constant<c_row_major>{}));
|
||||
c_ref.SetZero();
|
||||
|
||||
ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
|
||||
a_host, b_origin_host, c_ref, scale_a, scale_b);
|
||||
|
||||
const float rtol = 1e-2f;
|
||||
const float atol = 1e-2f;
|
||||
EXPECT_TRUE(
|
||||
ck_tile::check_err(c_rslt_host, c_ref, "MX Flatmm result mismatch", rtol, atol));
|
||||
}
|
||||
};
|
||||
20
test/ck_tile/flatmm/test_mx_flatmm_fixtures.hpp
Normal file
20
test/ck_tile/flatmm/test_mx_flatmm_fixtures.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "test_mx_flatmm_base.hpp"
|
||||
#include "mx_flatmm_arch_traits.hpp"
|
||||
|
||||
// Convenience type aliases for use in test .cpp files
|
||||
using FP4 = ck_tile::pk_fp4_t;
|
||||
using FP6 = ck_tile::pk_fp6x16_t;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using FP16 = ck_tile::fp16_t;
|
||||
|
||||
// Concrete test fixture — inherits all logic from TestMXFlatmmBase.
|
||||
// Tuple layout: <ADataType, BDataType, CDataType, MXFlatmmArchTraits>
|
||||
template <typename Tuple>
|
||||
class TestMXFlatmm : public TestMXFlatmmBase<Tuple>
|
||||
{
|
||||
};
|
||||
41
test/ck_tile/flatmm/test_mx_flatmm_fp4fp4.cpp
Normal file
41
test/ck_tile/flatmm/test_mx_flatmm_fp4fp4.cpp
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_mx_flatmm_fixtures.hpp"
|
||||
|
||||
// FP4 x FP4 -> FP16
|
||||
// N_Tile = 512 (MXfp4_FlatmmConfig16), so N must be a multiple of 512.
|
||||
// K must be a multiple of 32 (ScaleGranularityK) and 8 (FP4 PackedSize) -> multiple of 32.
|
||||
// clang-format off
|
||||
using FP4FP4Types = ::testing::Types<
|
||||
std::tuple<FP4, FP4, FP16, MXFlatmm_GFX950_FP4FP4_Traits>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestMXFlatmm, FP4FP4Types);
|
||||
|
||||
// K=256 -> num_loop=1: has_hot_loop=false, tail=Odd
|
||||
TYPED_TEST(TestMXFlatmm, SmallMNK)
|
||||
{
|
||||
this->run_test_with_validation(128, 512, 256, 1, false, ck_tile::TailNumber::Odd);
|
||||
}
|
||||
|
||||
// K=512 -> num_loop=2: has_hot_loop=false, tail=Even
|
||||
TYPED_TEST(TestMXFlatmm, MediumMNK)
|
||||
{
|
||||
this->run_test_with_validation(256, 1024, 512, 1, false, ck_tile::TailNumber::Even);
|
||||
}
|
||||
|
||||
// K=768 -> num_loop=3: has_hot_loop=true, tail=Odd
|
||||
TYPED_TEST(TestMXFlatmm, LargeK_HotLoopOdd)
|
||||
{
|
||||
this->run_test_with_validation(128, 512, 768, 1, true, ck_tile::TailNumber::Odd);
|
||||
}
|
||||
|
||||
// K=1024 -> num_loop=4: has_hot_loop=true, tail=Even
|
||||
TYPED_TEST(TestMXFlatmm, LargeK_HotLoopEven)
|
||||
{
|
||||
this->run_test_with_validation(128, 512, 1024, 1, true, ck_tile::TailNumber::Even);
|
||||
}
|
||||
40
test/ck_tile/flatmm/test_mx_flatmm_fp4fp8.cpp
Normal file
40
test/ck_tile/flatmm/test_mx_flatmm_fp4fp8.cpp
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_mx_flatmm_fixtures.hpp"
|
||||
|
||||
// FP4 x FP8 -> FP16
|
||||
// N_Tile = 256, K must be a multiple of lcm(32, 8) = 32.
|
||||
// clang-format off
|
||||
using FP4FP8Types = ::testing::Types<
|
||||
std::tuple<FP4, FP8, FP16, MXFlatmm_GFX950_FP4FP8_Traits>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestMXFlatmm, FP4FP8Types);
|
||||
|
||||
// K=256 -> num_loop=1: has_hot_loop=false, tail=Odd
|
||||
TYPED_TEST(TestMXFlatmm, SmallMNK)
|
||||
{
|
||||
this->run_test_with_validation(128, 256, 256, 1, false, ck_tile::TailNumber::Odd);
|
||||
}
|
||||
|
||||
// K=512 -> num_loop=2: has_hot_loop=false, tail=Even
|
||||
TYPED_TEST(TestMXFlatmm, MediumMNK)
|
||||
{
|
||||
this->run_test_with_validation(256, 512, 512, 1, false, ck_tile::TailNumber::Even);
|
||||
}
|
||||
|
||||
// K=768 -> num_loop=3: has_hot_loop=true, tail=Odd
|
||||
TYPED_TEST(TestMXFlatmm, LargeK_HotLoopOdd)
|
||||
{
|
||||
this->run_test_with_validation(128, 256, 768, 1, true, ck_tile::TailNumber::Odd);
|
||||
}
|
||||
|
||||
// K=1024 -> num_loop=4: has_hot_loop=true, tail=Even
|
||||
TYPED_TEST(TestMXFlatmm, LargeK_HotLoopEven)
|
||||
{
|
||||
this->run_test_with_validation(128, 256, 1024, 1, true, ck_tile::TailNumber::Even);
|
||||
}
|
||||
40
test/ck_tile/flatmm/test_mx_flatmm_fp6fp6.cpp
Normal file
40
test/ck_tile/flatmm/test_mx_flatmm_fp6fp6.cpp
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_mx_flatmm_fixtures.hpp"
|
||||
|
||||
// FP6 x FP6 -> FP16
|
||||
// N_Tile = 256, K must be a multiple of lcm(32, 16) = 32 (FP6 PackedSize=16, lcm(32,16)=32).
|
||||
// clang-format off
|
||||
using FP6FP6Types = ::testing::Types<
|
||||
std::tuple<FP6, FP6, FP16, MXFlatmm_GFX950_FP6FP6_Traits>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestMXFlatmm, FP6FP6Types);
|
||||
|
||||
// K=256 -> num_loop=1: has_hot_loop=false, tail=Odd
|
||||
TYPED_TEST(TestMXFlatmm, SmallMNK)
|
||||
{
|
||||
this->run_test_with_validation(128, 256, 256, 1, false, ck_tile::TailNumber::Odd);
|
||||
}
|
||||
|
||||
// K=512 -> num_loop=2: has_hot_loop=false, tail=Even
|
||||
TYPED_TEST(TestMXFlatmm, MediumMNK)
|
||||
{
|
||||
this->run_test_with_validation(256, 512, 512, 1, false, ck_tile::TailNumber::Even);
|
||||
}
|
||||
|
||||
// K=768 -> num_loop=3: has_hot_loop=true, tail=Odd
|
||||
TYPED_TEST(TestMXFlatmm, LargeK_HotLoopOdd)
|
||||
{
|
||||
this->run_test_with_validation(128, 256, 768, 1, true, ck_tile::TailNumber::Odd);
|
||||
}
|
||||
|
||||
// K=1024 -> num_loop=4: has_hot_loop=true, tail=Even
|
||||
TYPED_TEST(TestMXFlatmm, LargeK_HotLoopEven)
|
||||
{
|
||||
this->run_test_with_validation(128, 256, 1024, 1, true, ck_tile::TailNumber::Even);
|
||||
}
|
||||
40
test/ck_tile/flatmm/test_mx_flatmm_fp8fp4.cpp
Normal file
40
test/ck_tile/flatmm/test_mx_flatmm_fp8fp4.cpp
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_mx_flatmm_fixtures.hpp"
|
||||
|
||||
// FP8 x FP4 -> FP16
|
||||
// N_Tile = 256, K must be a multiple of lcm(32, 8) = 32.
|
||||
// clang-format off
|
||||
using FP8FP4Types = ::testing::Types<
|
||||
std::tuple<FP8, FP4, FP16, MXFlatmm_GFX950_FP8FP4_Traits>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestMXFlatmm, FP8FP4Types);
|
||||
|
||||
// K=256 -> num_loop=1: has_hot_loop=false, tail=Odd
|
||||
TYPED_TEST(TestMXFlatmm, SmallMNK)
|
||||
{
|
||||
this->run_test_with_validation(128, 256, 256, 1, false, ck_tile::TailNumber::Odd);
|
||||
}
|
||||
|
||||
// K=512 -> num_loop=2: has_hot_loop=false, tail=Even
|
||||
TYPED_TEST(TestMXFlatmm, MediumMNK)
|
||||
{
|
||||
this->run_test_with_validation(256, 512, 512, 1, false, ck_tile::TailNumber::Even);
|
||||
}
|
||||
|
||||
// K=768 -> num_loop=3: has_hot_loop=true, tail=Odd
|
||||
TYPED_TEST(TestMXFlatmm, LargeK_HotLoopOdd)
|
||||
{
|
||||
this->run_test_with_validation(128, 256, 768, 1, true, ck_tile::TailNumber::Odd);
|
||||
}
|
||||
|
||||
// K=1024 -> num_loop=4: has_hot_loop=true, tail=Even
|
||||
TYPED_TEST(TestMXFlatmm, LargeK_HotLoopEven)
|
||||
{
|
||||
this->run_test_with_validation(128, 256, 1024, 1, true, ck_tile::TailNumber::Even);
|
||||
}
|
||||
40
test/ck_tile/flatmm/test_mx_flatmm_fp8fp8.cpp
Normal file
40
test/ck_tile/flatmm/test_mx_flatmm_fp8fp8.cpp
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_mx_flatmm_fixtures.hpp"
|
||||
|
||||
// FP8 x FP8 -> FP16
|
||||
// N_Tile = 256, K must be a multiple of 32.
|
||||
// clang-format off
|
||||
using FP8FP8Types = ::testing::Types<
|
||||
std::tuple<FP8, FP8, FP16, MXFlatmm_GFX950_FP8FP8_Traits>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestMXFlatmm, FP8FP8Types);
|
||||
|
||||
// K=256 -> num_loop=1: has_hot_loop=false, tail=Odd
|
||||
TYPED_TEST(TestMXFlatmm, SmallMNK)
|
||||
{
|
||||
this->run_test_with_validation(128, 256, 256, 1, false, ck_tile::TailNumber::Odd);
|
||||
}
|
||||
|
||||
// K=512 -> num_loop=2: has_hot_loop=false, tail=Even
|
||||
TYPED_TEST(TestMXFlatmm, MediumMNK)
|
||||
{
|
||||
this->run_test_with_validation(256, 512, 512, 1, false, ck_tile::TailNumber::Even);
|
||||
}
|
||||
|
||||
// K=768 -> num_loop=3: has_hot_loop=true, tail=Odd
|
||||
TYPED_TEST(TestMXFlatmm, LargeK_HotLoopOdd)
|
||||
{
|
||||
this->run_test_with_validation(128, 256, 768, 1, true, ck_tile::TailNumber::Odd);
|
||||
}
|
||||
|
||||
// K=1024 -> num_loop=4: has_hot_loop=true, tail=Even
|
||||
TYPED_TEST(TestMXFlatmm, LargeK_HotLoopEven)
|
||||
{
|
||||
this->run_test_with_validation(128, 256, 1024, 1, true, ck_tile::TailNumber::Even);
|
||||
}
|
||||
@@ -52,7 +52,12 @@ bool compare_results(std::string instanceName,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
|
||||
std::abs(static_cast<float>(*std::max_element(c_m_n_host_result.mData.begin(),
|
||||
c_m_n_host_result.mData.end(),
|
||||
[](CDataType a, CDataType b) {
|
||||
return std::abs(static_cast<float>(a)) <
|
||||
std::abs(static_cast<float>(b));
|
||||
})));
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
|
||||
@@ -447,8 +447,12 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
c_m_n_host_ref.SetZero();
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const float max_accumulated_value = std::abs(static_cast<float>(*std::max_element(
|
||||
c_m_n_host_ref.mData.begin(),
|
||||
c_m_n_host_ref.mData.end(),
|
||||
[](CDataType a, CDataType b) {
|
||||
return std::abs(static_cast<float>(a)) < std::abs(static_cast<float>(b));
|
||||
})));
|
||||
const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
|
||||
Reference in New Issue
Block a user