mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK_TILE] Stream-K GEMM Implementation (#2781)
* Change splitk_batch_offset parameter to k_size in UniversalGemmKernel::MakeGemmTensorViews function Prior to this change, the splitk_batch_offset parameter of MakeGemmTensorViews had type SplitKBatchOffset. But, the only member variable of the SplitKBatchOffset class used in the MakeGemmTensorViews function was splitted_k (an int32_t). The splitted_k value was used as part of defining the dimensions of the tensor view. That said, for Stream K, we do not need to use the SplitKBatchOffset class since we are not using Split K. Thus, this commit changes the splitk_batch_offset parameter to a int32_t called k_size. This will avoid the constraint of requiring a caller of MakeGemmTensorViews to use the SplitKBatchOffset class while still providing the same functionality. Calls to UniversalGemmKernel::MakeGemmTensorViews have been updated accordingly. * StreamK Kernel RunGemm Implementation Stream K cannot simply use UniversalGemmKernel's RunGemm for the following reasons: 1. The UniversalGemmKernel::RunGemm function computes num_loop based on a static function of the TilePartitioner. That said, for Stream K, num_loop must be computed using a member function (namely GetCurrentIterLength from PR #2708). 2. The UniversalGemmKernel::RunGemm function requires the use of a SplitKBatchOffset object which is not used for Stream K since we are not using Split K. Thus, this change adds a RunGemm function in the StreamKKernel class. * initial implementation for operator() for StreamKKernel: adding stream-k algorithm and calls to RunGemm * Fix indexing and offset issues for StreamK These changes do the following: - Ensure offsets along the M and N dimensions are multiplied by MPerblock or NPerBlock, respectively. This ensures tile window origins are at the correct locations. - Fix bug in the tile partitioner's GetTileIdxWithOffset. Now, we apply divmod to the given references to ensure correct values are available to the caller. - Added documentation in the Stream-K operator() * Initial gtests for Stream-K These changes add an initial gtest suite for the CK Tile Stream-K kernel. Currently, due to bugs in the StreamKTilePartitioner (which will be handled in a future PR), there are validation issues for certain cases which may differ on different architectures. Thus, we opted to run cases that are only fully data-parallel (skipping others). A guard was added to Stream-K's IsSupportedArgument method to ensure that callers are aware of this constraint. Additionally, to ensure testing reproducibility, options for setting the number of CUs and occupancy were added to MakeKernelArgs. * Use GemmPipeline operator() variant that takes hot loop and tail num In Stream-K, the num_loop value varies per WG and per iteration of a Stream-K loop. So instead, we use the version of the GemmPipeline's operator() function that takes in has_hot_loop and tail_num. This is similar to what is done in Grouped GEMM. * changes from review: comments, move readfirstlane, remove ifndef * Switch direction of C tensor traversal & add padding guard Prior to this change, WGs travelled backwards through their assigned macro tiles in the C tensor. For instance, if WG0 is responsible for C tiles 0 and 1, it would first visit tile 1 then tile 0. This means that the iter_end decrements in each iteration of the stream-K while loop. Since we are working with unsigned integers, the subtraction operation may not be safe. Thus, this change makes is such that WGs travel forward so that their iter_start is incremented and their iter_end remains fixed. Additionally, we added a guard against WGs that are neither sk_blocks nor dp_blocks to ensure such WGs do not participate in the GEMM. Together, these changes make is such that the algorithm is correct when sk_blocks is greater than zero. * Disable StreamK_M256_N256_K256_SKBlocks12 test case This instance involves >=3 WGs contributing to each macro tile in C. Due to the use of atomics, this is resulting in precision errors. These errors will not persist once the reduction strategy is implemented. We will re-enable this test then. --------- Co-authored-by: Astha Rai <astha.rai713@gmail.com>
This commit is contained in:
7
test/ck_tile/gemm_streamk/CMakeLists.txt
Normal file
7
test/ck_tile/gemm_streamk/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
# Currently test_ck_tile_streamk is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
#TODO: support all arches
|
||||
add_gtest_executable(test_ck_tile_streamk test_gemm_streamk.cpp)
|
||||
else()
|
||||
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
|
||||
endif()
|
||||
14
test/ck_tile/gemm_streamk/test_gemm_streamk.cpp
Normal file
14
test/ck_tile/gemm_streamk/test_gemm_streamk.cpp
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_types.hpp"
|
||||
#include "test_gemm_streamk_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamK
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamK, KernelTypesStreamK);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
118
test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc
Normal file
118
test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc
Normal file
@@ -0,0 +1,118 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_DP)
|
||||
{
|
||||
|
||||
ck_tile::index_t M = 256;
|
||||
ck_tile::index_t N = 256;
|
||||
ck_tile::index_t K = 256;
|
||||
uint32_t num_sk_blocks = 0;
|
||||
|
||||
this->Run(M, N, K, num_sk_blocks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks4)
|
||||
{
|
||||
|
||||
ck_tile::index_t M = 256;
|
||||
ck_tile::index_t N = 256;
|
||||
ck_tile::index_t K = 256;
|
||||
uint32_t num_sk_blocks = 4;
|
||||
|
||||
this->Run(M, N, K, num_sk_blocks);
|
||||
}
|
||||
|
||||
// TODO: Renable this test once reduction is implemented
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks12)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping this test: There are precision issues with atomics due to >=3 WGs "
|
||||
"contributing to each macro tile in C";
|
||||
|
||||
ck_tile::index_t M = 256;
|
||||
ck_tile::index_t N = 256;
|
||||
ck_tile::index_t K = 256;
|
||||
uint32_t num_sk_blocks = 12;
|
||||
|
||||
this->Run(M, N, K, num_sk_blocks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks8)
|
||||
{
|
||||
|
||||
ck_tile::index_t M = 256;
|
||||
ck_tile::index_t N = 256;
|
||||
ck_tile::index_t K = 256;
|
||||
uint32_t num_sk_blocks = 8;
|
||||
|
||||
this->Run(M, N, K, num_sk_blocks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_DP)
|
||||
{
|
||||
|
||||
ck_tile::index_t M = 512;
|
||||
ck_tile::index_t N = 512;
|
||||
ck_tile::index_t K = 512;
|
||||
uint32_t num_sk_blocks = 0;
|
||||
|
||||
this->Run(M, N, K, num_sk_blocks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks16)
|
||||
{
|
||||
|
||||
ck_tile::index_t M = 512;
|
||||
ck_tile::index_t N = 512;
|
||||
ck_tile::index_t K = 512;
|
||||
uint32_t num_sk_blocks = 16;
|
||||
|
||||
this->Run(M, N, K, num_sk_blocks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks8)
|
||||
{
|
||||
|
||||
ck_tile::index_t M = 512;
|
||||
ck_tile::index_t N = 512;
|
||||
ck_tile::index_t K = 512;
|
||||
uint32_t num_sk_blocks = 8;
|
||||
|
||||
this->Run(M, N, K, num_sk_blocks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_DP)
|
||||
{
|
||||
|
||||
ck_tile::index_t M = 3840;
|
||||
ck_tile::index_t N = 4096;
|
||||
ck_tile::index_t K = 4096;
|
||||
uint32_t num_sk_blocks = 0;
|
||||
|
||||
this->Run(M, N, K, num_sk_blocks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_SKBlocks64)
|
||||
{
|
||||
|
||||
ck_tile::index_t M = 3840;
|
||||
ck_tile::index_t N = 4096;
|
||||
ck_tile::index_t K = 4096;
|
||||
uint32_t num_sk_blocks = 64;
|
||||
|
||||
this->Run(M, N, K, num_sk_blocks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_Unsupported_Reduction)
|
||||
{
|
||||
|
||||
ck_tile::index_t M = 3840;
|
||||
ck_tile::index_t N = 4096;
|
||||
ck_tile::index_t K = 4096;
|
||||
uint32_t num_sk_blocks = 64;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, num_sk_blocks, ck_tile::StreamKReductionStrategy::Reduction),
|
||||
std::runtime_error);
|
||||
}
|
||||
25
test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp
Normal file
25
test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypesStreamK = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
282
test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp
Normal file
282
test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp
Normal file
@@ -0,0 +1,282 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// The logic below may need to become more advanced once bugs in Stream-K Tile Partitioner are
|
||||
// resolved. Because the number of WGs contributing to a macro tile in C may not be the same for
|
||||
// all macro tiles in C.
|
||||
|
||||
// Calculate error due to more than 1 WG contributing to the same macro tile in C
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamK : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
|
||||
template <ck_tile::StreamKReductionStrategy ReductionStrategy,
|
||||
bool PadM = true,
|
||||
bool PadN = true,
|
||||
bool PadK = true,
|
||||
bool Preshuffle = false,
|
||||
bool TransposeC = false>
|
||||
void invoke_streamk(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s,
|
||||
int num_cu,
|
||||
int occupancy)
|
||||
{
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 128;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
constexpr bool preshuffle = Preshuffle;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr bool StructuredSparsity = false;
|
||||
constexpr bool NumWaveGroup = 1;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC,
|
||||
StructuredSparsity,
|
||||
false,
|
||||
NumWaveGroup,
|
||||
preshuffle>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
// For initial testing, we will just test with one pipeline.
|
||||
// More extensive testing is coming later and will test other pipelines.
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args, num_cu, occupancy);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
EXPECT_TRUE(false);
|
||||
}
|
||||
|
||||
dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner);
|
||||
dim3 block_dims = Kernel::BlockSize();
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grid_dims, block_dims, 0, kargs));
|
||||
};
|
||||
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// Since we are doing stream K, in the case of
|
||||
// atomics, multiple workgroups may write to the same
|
||||
// output tile in the C tensor, so we must atomic add
|
||||
// the results (not set)
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
|
||||
public:
|
||||
// Since Stream-K is build on gfx9, the lower bound for CUs is 104. Thus, we default num_cu to
|
||||
// 104 and occupancy to 1 to ensure tests are reproducible on different architectures.
|
||||
void Run(ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
uint32_t num_sk_blocks = 0xffffffff,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy =
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
int occupancy = 1,
|
||||
int num_cu = 104,
|
||||
ck_tile::index_t stride_A = 0,
|
||||
ck_tile::index_t stride_B = 0,
|
||||
ck_tile::index_t stride_C = 0)
|
||||
{
|
||||
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
throw std::runtime_error("Reduction Strategy is current unsupported!\n");
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout),
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
stride_A = f_get_default_stride(M, K, stride_A, ALayout{});
|
||||
stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
|
||||
stride_C = f_get_default_stride(M, N, stride_C, CLayout{});
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, /*seed*/ 11939}(a_m_k);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, /*seed*/ 11940}(b_k_n);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
reduction_strategy,
|
||||
num_sk_blocks};
|
||||
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, /*kbatch*/ 1, max_accumulated_value);
|
||||
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
};
|
||||
};
|
||||
Reference in New Issue
Block a user