Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 18:17:44 +00:00)

Merge commit '2723dbd33245b76bfe716c5adc8c9fb577a4b68f' into develop
@@ -1,3 +1,10 @@
 add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
 add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp)
+add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp)
+
+set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
+if(CK_USE_OCP_FP8)
+    list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
+endif()
+target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})
 
@@ -887,6 +887,58 @@ struct tile_window_with_static_lengths
         this->window_lengths_     = window_lengths;
         this->bottom_tensor_view_ = bottom_tensor_view;
     }
+
+    /**
+     * @brief Print tile window elements for debugging.
+     *
+     * @tparam DataType Element data type (e.g., fp16_t, float, bf8_t)
+     * @param start_i Starting row (inclusive)
+     * @param end_i Ending row (exclusive)
+     * @param start_j Starting column (inclusive)
+     * @param end_j Ending column (exclusive)
+     * @param label Optional output label
+     *
+     * @note Tested on fp16. Custom types may need adjustments.
+     * @example tile_window.template print_tile_window_range<fp16_t>(0, 4, 0, 8, "A");
+     */
+    template <typename DataType>
+    CK_TILE_DEVICE void print_tile_window_range(index_t start_i,
+                                                index_t end_i,
+                                                index_t start_j,
+                                                index_t end_j,
+                                                const char* label = "") const
+    {
+        const auto& tensor_view  = this->get_bottom_tensor_view();
+        const auto window_origin = this->get_window_origin();
+
+        printf("%s Window Range [%d:%d, %d:%d] (origin: %d, %d):\n",
+               label,
+               start_i,
+               end_i - 1,
+               start_j,
+               end_j - 1,
+               window_origin[0],
+               window_origin[1]);
+
+        for(index_t i = start_i; i < end_i; i++)
+        {
+            for(index_t j = start_j; j < end_j; j++)
+            {
+                // Create coordinate for this element relative to window origin
+                auto coord =
+                    make_tensor_coordinate(tensor_view.get_tensor_descriptor(),
+                                           make_tuple(window_origin[0] + i, window_origin[1] + j));
+
+                // Get the element using thread buffer type directly
+                using ThreadBuf = thread_buffer<DataType, 2>;
+                auto buf   = tensor_view.template get_vectorized_elements<ThreadBuf>(coord, 0);
+                auto value = buf.at(number<0>{}); // Extract first element from thread buffer
+                printf(" %s[%d,%d] = %f", label, i, j, static_cast<float>(value));
+            }
+            printf("\n");
+        }
+        printf("\n");
+    }
 };
 
 template <typename TensorView_, typename WindowLengths_>
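As a usage illustration for the new helper (mirroring the `@example` in the doc comment, not part of the diff), a guarded debug call from device code might look like the fragment below; the window name `a_window` and the thread-0 guard are assumptions:

```cpp
// Hypothetical debug call inside a CK Tile device kernel. `a_window` is an
// assumed tile_window_with_static_lengths over the A tensor; printing from a
// single thread avoids interleaved output.
if(threadIdx.x == 0 && blockIdx.x == 0)
{
    // Dump the top-left 4x8 sub-block of the window, labelled "A".
    a_window.template print_tile_window_range<ck_tile::fp16_t>(0, 4, 0, 8, "A");
}
```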
@@ -646,16 +646,13 @@ struct StreamKTilePartitioner
     * @brief Get length of loop iterations for stream-k loop
     */
     CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start,
-                                                 uint32_t iter_end,
-                                                 uint32_t total_iter_length) const noexcept
+                                                 uint32_t iter_end) const noexcept
     {
-        uint32_t iter_length_mod, iter_length_quo /*unused*/;
-        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
-        uint32_t total_iter_length_val = static_cast<uint32_t>(total_iter_length);
-        uint32_t current_iter_length =
-            min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod,
-                total_iter_length_val);
-        return current_iter_length;
+        // A WG's iter_end is either in the current C macro tile or not.
+        // If it is not, then the macro tile boundary is where the WG must stop.
+        uint32_t distance_to_tile_boundary =
+            k_iters_per_tile.get() - (iter_start % k_iters_per_tile.get());
+        return min(iter_start + distance_to_tile_boundary, iter_end) - iter_start;
     }
 
     /**
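To make the boundary arithmetic concrete, here is a host-side rehearsal of the new body with assumed numbers: `k_iters_per_tile = 8` and a WG whose global iteration range is [13, 20).

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Standalone re-implementation of GetCurrentIterLength's new logic, with an
// assumed k_iters_per_tile of 8 (each C macro tile needs 8 K-iterations).
int main()
{
    const std::uint32_t k_iters_per_tile = 8;
    const std::uint32_t iter_start = 13, iter_end = 20;

    // Iterations left until the next macro-tile boundary: 8 - (13 % 8) = 3.
    const std::uint32_t distance_to_tile_boundary =
        k_iters_per_tile - (iter_start % k_iters_per_tile);

    // Stop at the earlier of the tile boundary (iter 16) and this WG's own
    // iter_end (20): min(16, 20) - 13 = 3 iterations in this pass.
    const std::uint32_t current_iter_length =
        std::min(iter_start + distance_to_tile_boundary, iter_end) - iter_start;

    std::printf("current_iter_length = %u\n", current_iter_length); // prints 3
}
```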
@@ -672,9 +669,7 @@ struct StreamKTilePartitioner
     CK_TILE_DEVICE void
     GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept
     {
-        uint32_t tile_idx_val    = static_cast<uint32_t>(tile_idx);
-        uint32_t iter_offset_val = static_cast<uint32_t>(iter_offset);
-        k_iters_per_tile.divmod(iter, tile_idx_val, iter_offset_val);
+        k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
     }
 
     /**
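`divmod` splits a linear iteration index into a C-tile index and an in-tile K offset; a minimal sketch of the equivalent arithmetic (assuming `k_iters_per_tile` is 8):

```cpp
#include <cstdint>
#include <cstdio>

// Plain-arithmetic equivalent of k_iters_per_tile.divmod(iter, tile_idx, iter_offset),
// with an assumed k_iters_per_tile of 8.
int main()
{
    const std::uint32_t k_iters_per_tile = 8;
    const std::uint32_t iter = 21;
    const std::uint32_t tile_idx    = iter / k_iters_per_tile; // 2: third C macro tile
    const std::uint32_t iter_offset = iter % k_iters_per_tile; // 5: K offset inside it
    std::printf("tile_idx=%u iter_offset=%u\n", tile_idx, iter_offset);
}
```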
@@ -374,7 +374,7 @@ struct GroupedGemmKernel
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple =
             Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
-                {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);
+                {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);
 
         const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows =
@@ -436,7 +436,7 @@ struct GroupedGemmKernel
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple =
             Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
-                {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);
+                {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);
 
         const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows =
@@ -141,11 +141,17 @@ struct StreamKKernel
         return UniversalGemmKernel::BlockSize();
     }
 
-    CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args)
+    /// @brief Constructs kernel arguments for the Stream-K kernel.
+    /// @param host_args Stream-K host arguments.
+    /// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device.
+    /// The caller may supply their own value to assist with test reproducibility, etc.
+    /// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may
+    /// supply their own value to assist with test reproducibility, etc.
+    /// @return The kernel arguments for Stream-K.
+    CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
+                                                         int num_cu    = NumCU(),
+                                                         int occupancy = Occupancy())
     {
-        uint32_t occupancy = static_cast<uint32_t>(Occupancy());
-        uint32_t num_cu    = static_cast<uint32_t>(NumCU());
-
         return StreamKKernelArgs{{host_args.as_ptr,
                                   host_args.bs_ptr,
                                   host_args.ds_ptr,
@@ -166,14 +172,71 @@ struct StreamKKernel
                 TilePartitioner{static_cast<uint32_t>(host_args.M),
                                 static_cast<uint32_t>(host_args.N),
                                 static_cast<uint32_t>(host_args.K),
-                                num_cu,
-                                occupancy,
+                                static_cast<uint32_t>(num_cu),
+                                static_cast<uint32_t>(occupancy),
                                 host_args.num_sk_blocks}};
     }
 
-    CK_TILE_HOST static bool
-    IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs)
+    template <bool UseDefaultScheduler = true>
+    CK_TILE_DEVICE static void
+    RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
+            const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
+            const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
+            CDataType* c_ptr,
+            void* smem_ptr_0,
+            const typename UniversalGemmKernel::KernelArgs& kargs,
+            const index_t num_loop,
+            const index_t block_idx_m,
+            const index_t block_idx_n,
+            const index_t k_size)
+    {
+        // Create Gemm tensor views, pad views and tile windows
+        const auto& gemm_tensor_views_tuple =
+            UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
+                as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
+
+        const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
+        auto gemm_tile_windows =
+            UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+        // Run GEMM cooperatively by whole workgroup.
+        const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
+        const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
+        const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
+
+        // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
+        // has_hot_loop and tail_num here. This is a pattern similar to the one used by grouped
+        // GEMM. In this case, we call the GemmPipeline's operator() function that takes both
+        // has_hot_loop and tail_num.
+        const bool has_hot_loop   = GemmPipeline::BlockHasHotloop(num_loop);
+        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
+
+        const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
+                                                  bs_block_window[UniversalGemmKernel::I0],
+                                                  num_loop,
+                                                  has_hot_loop,
+                                                  tail_num,
+                                                  smem_ptr_0);
+
+        if(UseDefaultScheduler || (get_warp_id() == 0))
+        {
+            // Run Epilogue Pipeline
+            auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
+
+            EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
+        }
+    }
+
+    CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
     {
+        if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
+        {
+            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+            {
+                CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
+            }
+            return false;
+        }
         return UniversalGemmKernel::IsSupportedArgument(kargs);
     }
 
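The has_hot_loop / tail_num pair that RunGemm computes per call classifies the loop for the software-pipelined GEMM. The exact rules live in `GemmPipeline::BlockHasHotloop` and `GetBlockLoopTailNum`; the toy sketch below only illustrates the idea under an assumed prefetch depth of 2.

```cpp
#include <cstdint>
#include <cstdio>

// Toy classification of a pipelined loop (assumed prefetch depth of 2; not
// CK's actual rules): the loop is "hot" only when there are more iterations
// than prefetch stages, and the tail is whatever does not fill a full stage.
int main()
{
    const std::uint32_t prefetch_stages = 2;
    for(std::uint32_t num_loop : {1u, 2u, 5u, 8u})
    {
        const bool has_hot_loop = num_loop > prefetch_stages;
        const std::uint32_t tail_num =
            (num_loop % prefetch_stages == 0) ? prefetch_stages : num_loop % prefetch_stages;
        std::printf("num_loop=%u has_hot_loop=%d tail_num=%u\n", num_loop, has_hot_loop, tail_num);
    }
}
```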
@@ -199,9 +262,81 @@ struct StreamKKernel
         kargs.workspace_ptr = workspace_ptr;
     }
 
-    // Temporary placeholder to support the Occupancy() static function.
-    // Since the Occupancy function uses kentry, this class must have an operator() function
-    CK_TILE_DEVICE void operator()(StreamKKernelArgs /*kargs*/) const {}
+    /// @brief Entry point for the Stream-K Kernel, performing the main Stream-K loop.
+    CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const
+    {
+        // Allocate LDS
+        __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
+
+        uint32_t block_idx = ck_tile::get_block_1d_id();
+
+        bool is_padding_block =
+            __builtin_amdgcn_readfirstlane(block_idx >= kargs.tile_partitioner.sk_num_blocks &&
+                                           block_idx < kargs.tile_partitioner.dp_start_block_idx);
+
+        // Padding blocks make it such that the DP blocks are aligned with the number of CUs; they
+        // should not partake in the GEMM
+        if(is_padding_block)
+            return;
+
+        // Determine the K offset of the first and final macro tile in the A and B tensors along
+        // the K dimension.
+        uint32_t iter_start, iter_end;
+        kargs.tile_partitioner.GetBlockItr(block_idx, iter_start, iter_end);
+
+        // Main Stream-K loop
+        while(true)
+        {
+            // Determine the number of macro tiles in A and B this WG is responsible for in the
+            // current C macro tile.
+            uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
+                kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end));
+
+            // Determine the 1D tile_idx and the iter_offset for this WG.
+            // The tile_idx is the 1D macro tile index in the C tensor.
+            // The iter_offset is the starting macro tile index in the K dimension for the WG in
+            // the current iteration of the while loop.
+            uint32_t tile_idx, iter_offset;
+            kargs.tile_partitioner.GetTileIdxWithOffset(iter_start, tile_idx, iter_offset);
+
+            // Get the 2D tile index in the C tensor for this WG using the 1D index (i.e. tile_idx)
+            auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx);
+
+            // Get the offsets in A, B, C tensors.
+            index_t i_m = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I0] *
+                                               TilePartitioner::MPerBlock);
+            index_t i_n = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I1] *
+                                               TilePartitioner::NPerBlock);
+            index_t i_k = static_cast<index_t>(iter_offset) * TilePartitioner::KPerBlock;
+
+            // Determine the total size along the K dimension the WG is using in this iteration
+            // (used to construct tensor views).
+            index_t k_size = static_cast<index_t>(current_iter_length * TilePartitioner::KPerBlock);
+
+            // Update pointer offsets for A, B, and C.
+            const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k;
+            const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k;
+            CDataType* c_ptr       = static_cast<CDataType*>(kargs.e_ptr);
+
+            // Run the GEMM pipeline and Epilogue.
+            RunGemm({a_ptr},
+                    {b_ptr},
+                    {/*ds_ptr*/},
+                    c_ptr,
+                    smem_ptr_0,
+                    kargs,
+                    current_iter_length,
+                    i_m,
+                    i_n,
+                    k_size);
+
+            // Prepare for next Stream-K loop iteration.
+            iter_start += current_iter_length;
+            if(iter_end <= iter_start)
+                break;
+            block_sync_lds();
+        }
+    }
 
 private:
     CK_TILE_HOST static int NumCU()
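Putting the pieces together, the while loop walks a WG's [iter_start, iter_end) range one C-tile segment at a time. A host-side simulation with assumed numbers (8 K-iterations per C tile, a WG owning global iters [13, 30)):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Host-side simulation of one workgroup's Stream-K while loop, with assumed
// numbers: 8 K-iterations per C tile and a WG owning global iters [13, 30).
int main()
{
    const std::uint32_t k_iters_per_tile = 8;
    std::uint32_t iter_start     = 13;
    const std::uint32_t iter_end = 30;

    while(true)
    {
        // Same boundary logic as GetCurrentIterLength.
        const std::uint32_t dist = k_iters_per_tile - (iter_start % k_iters_per_tile);
        const std::uint32_t len  = std::min(iter_start + dist, iter_end) - iter_start;

        // Same split as GetTileIdxWithOffset.
        const std::uint32_t tile_idx    = iter_start / k_iters_per_tile;
        const std::uint32_t iter_offset = iter_start % k_iters_per_tile;

        std::printf("tile %u: K-iters [%u, %u)\n", tile_idx, iter_offset, iter_offset + len);

        iter_start += len;
        if(iter_end <= iter_start)
            break;
    }
    // Prints: tile 1: K-iters [5, 8) / tile 2: K-iters [0, 8) / tile 3: K-iters [0, 6)
}
```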
@@ -579,7 +579,7 @@ struct UniversalGemmKernel
                        const std::array<const void*, NumDTensor>& ds_ptr,
                        EDataType* e_ptr,
                        const KernelArgs& kargs,
-                       const SplitKBatchOffset& splitk_batch_offset)
+                       const index_t k_size)
     {
         static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
 
@@ -591,7 +591,7 @@ struct UniversalGemmKernel
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     static_cast<const AiDataType*>(as_ptr[i]),
-                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
+                    make_tuple(kargs.M, k_size),
                     make_tuple(kargs.stride_As[i], 1),
                     number<GemmPipeline::GetVectorSizeA()>{},
                     number<1>{});
@@ -600,7 +600,7 @@ struct UniversalGemmKernel
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     static_cast<const AiDataType*>(as_ptr[i]),
-                    make_tuple(splitk_batch_offset.splitted_k, kargs.M),
+                    make_tuple(k_size, kargs.M),
                     make_tuple(kargs.stride_As[i], 1),
                     number<GemmPipeline::GetVectorSizeA()>{},
                     number<1>{});
@@ -617,7 +617,7 @@ struct UniversalGemmKernel
             if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
             {
                 constexpr index_t K1 = GemmPipeline::GetSmemPackB();
-                const index_t K0     = splitk_batch_offset.splitted_k / K1;
+                const index_t K0     = k_size / K1;
                 constexpr index_t VectorSizeB =
                     std::min(K1, GemmPipeline::GetVectorSizeB());
                 const auto b_k0_n_k1_desc =
@@ -638,7 +638,7 @@ struct UniversalGemmKernel
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     bs_ptr[i],
-                    make_tuple(splitk_batch_offset.splitted_k, kargs.N),
+                    make_tuple(k_size, kargs.N),
                     make_tuple(kargs.stride_Bs[i], 1),
                     number<GemmPipeline::GetVectorSizeB()>{},
                     number<1>{});
@@ -649,7 +649,7 @@ struct UniversalGemmKernel
             if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
             {
                 constexpr index_t K1 = GemmPipeline::GetSmemPackB();
-                const index_t K0     = splitk_batch_offset.splitted_k / K1;
+                const index_t K0     = k_size / K1;
                 constexpr index_t VectorSizeB =
                     std::min(K1, GemmPipeline::GetVectorSizeB());
                 const auto b_k0_n_k1_desc =
@@ -672,7 +672,7 @@ struct UniversalGemmKernel
             {
                 index_t kFlatK =
                     GemmPipeline::BlockGemmShape::flatKPerWarp *
-                    (splitk_batch_offset.splitted_k /
+                    (k_size /
                      TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
                 index_t kFlatN = kargs.N * kargs.K / kFlatK;
 
@@ -687,7 +687,7 @@ struct UniversalGemmKernel
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     bs_ptr[i],
-                    make_tuple(kargs.N, splitk_batch_offset.splitted_k),
+                    make_tuple(kargs.N, k_size),
                     make_tuple(kargs.stride_Bs[i], 1),
                     number<GemmPipeline::GetVectorSizeB()>{},
                     number<1>{});
@@ -962,7 +962,7 @@ struct UniversalGemmKernel
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple =
             MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
-                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
+                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
 
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
@@ -1018,7 +1018,7 @@ struct UniversalGemmKernel
         // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
             MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
-                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
+                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
 
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
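Taken together, these hunks change `MakeGemmTensorViews` to accept a plain K extent instead of the `SplitKBatchOffset` struct: split-K callers now unwrap `.splitted_k`, while Stream-K passes its per-pass `k_size`. A minimal sketch of this decoupling pattern with hypothetical types (not CK's real signatures):

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins illustrating the interface change; only the call
// pattern mirrors the diff, none of these types are CK's.
struct SplitKBatchOffset
{
    std::int32_t splitted_k;
};

// After the refactor the view builder only needs the K extent itself.
void make_views(std::int32_t k_size) { std::printf("building views with K extent %d\n", k_size); }

int main()
{
    SplitKBatchOffset splitk_batch_offset{256};
    make_views(splitk_batch_offset.splitted_k); // split-K caller unwraps its struct
    make_views(5 * 32);                         // Stream-K caller: iters * KPerBlock
}
```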
@@ -3,7 +3,9 @@ add_subdirectory(gemm)
 add_subdirectory(gemm_weight_preshuffle)
 add_subdirectory(batched_gemm)
 add_subdirectory(grouped_gemm)
+add_subdirectory(grouped_gemm_preshuffle)
 add_subdirectory(gemm_multi_d)
+add_subdirectory(gemm_streamk)
 add_subdirectory(data_type)
 add_subdirectory(container)
 add_subdirectory(elementwise)
test/ck_tile/gemm_streamk/CMakeLists.txt (new file, 7 lines)
@@ -0,0 +1,7 @@
# Currently test_ck_tile_streamk is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
    #TODO: support all arches
    add_gtest_executable(test_ck_tile_streamk test_gemm_streamk.cpp)
else()
    message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
endif()
test/ck_tile/gemm_streamk/test_gemm_streamk.cpp (new file, 14 lines)
@@ -0,0 +1,14 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_types.hpp"
#include "test_gemm_streamk_util.hpp"
#include "gtest/gtest.h"

#define TEST_SUITE_NAME TestCkTileStreamK

TYPED_TEST_SUITE(TestCkTileStreamK, KernelTypesStreamK);

#include "test_gemm_streamk_cases.inc"

#undef TEST_SUITE_NAME
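For readers unfamiliar with the `.inc` pattern used here: defining the suite name as a macro before including a shared file of TYPED_TEST bodies lets several type lists reuse the same cases. A minimal self-contained analogue (hypothetical suite and test names, plain GoogleTest):

```cpp
// Minimal analogue of the suite/.inc pattern; in the real code the TYPED_TEST
// body lives in a shared .inc file that can be re-included under different
// TEST_SUITE_NAME definitions.
#include "gtest/gtest.h"

template <typename T>
class MySuite : public ::testing::Test
{
};
using MyTypes = ::testing::Types<int, float>;
TYPED_TEST_SUITE(MySuite, MyTypes);

#define TEST_SUITE_NAME MySuite
TYPED_TEST(TEST_SUITE_NAME, IsDefaultConstructible)
{
    TypeParam value{}; // instantiated once per type in MyTypes
    (void)value;
    SUCCEED();
}
#undef TEST_SUITE_NAME
```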
test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc (new file, 118 lines)
@@ -0,0 +1,118 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_DP)
{
    ck_tile::index_t M = 256;
    ck_tile::index_t N = 256;
    ck_tile::index_t K = 256;
    uint32_t num_sk_blocks = 0;

    this->Run(M, N, K, num_sk_blocks);
}

TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks4)
{
    ck_tile::index_t M = 256;
    ck_tile::index_t N = 256;
    ck_tile::index_t K = 256;
    uint32_t num_sk_blocks = 4;

    this->Run(M, N, K, num_sk_blocks);
}

// TODO: Re-enable this test once reduction is implemented
TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks12)
{
    GTEST_SKIP() << "Skipping this test: There are precision issues with atomics due to >=3 WGs "
                    "contributing to each macro tile in C";

    ck_tile::index_t M = 256;
    ck_tile::index_t N = 256;
    ck_tile::index_t K = 256;
    uint32_t num_sk_blocks = 12;

    this->Run(M, N, K, num_sk_blocks);
}

TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks8)
{
    ck_tile::index_t M = 256;
    ck_tile::index_t N = 256;
    ck_tile::index_t K = 256;
    uint32_t num_sk_blocks = 8;

    this->Run(M, N, K, num_sk_blocks);
}

TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_DP)
{
    ck_tile::index_t M = 512;
    ck_tile::index_t N = 512;
    ck_tile::index_t K = 512;
    uint32_t num_sk_blocks = 0;

    this->Run(M, N, K, num_sk_blocks);
}

TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks16)
{
    ck_tile::index_t M = 512;
    ck_tile::index_t N = 512;
    ck_tile::index_t K = 512;
    uint32_t num_sk_blocks = 16;

    this->Run(M, N, K, num_sk_blocks);
}

TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks8)
{
    ck_tile::index_t M = 512;
    ck_tile::index_t N = 512;
    ck_tile::index_t K = 512;
    uint32_t num_sk_blocks = 8;

    this->Run(M, N, K, num_sk_blocks);
}

TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_DP)
{
    ck_tile::index_t M = 3840;
    ck_tile::index_t N = 4096;
    ck_tile::index_t K = 4096;
    uint32_t num_sk_blocks = 0;

    this->Run(M, N, K, num_sk_blocks);
}

TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_SKBlocks64)
{
    ck_tile::index_t M = 3840;
    ck_tile::index_t N = 4096;
    ck_tile::index_t K = 4096;
    uint32_t num_sk_blocks = 64;

    this->Run(M, N, K, num_sk_blocks);
}

TYPED_TEST(TEST_SUITE_NAME, StreamK_Unsupported_Reduction)
{
    ck_tile::index_t M = 3840;
    ck_tile::index_t N = 4096;
    ck_tile::index_t K = 4096;
    uint32_t num_sk_blocks = 64;

    EXPECT_THROW(this->Run(M, N, K, num_sk_blocks, ck_tile::StreamKReductionStrategy::Reduction),
                 std::runtime_error);
}
test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp (new file, 25 lines)
@@ -0,0 +1,25 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <tuple>
#include <type_traits>

#include "gtest/gtest.h"

#include "ck_tile/host.hpp"

using F16  = ck_tile::half_t;
using F32  = float;
using BF16 = ck_tile::bf16_t;

using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

// clang-format off
using KernelTypesStreamK = ::testing::Types<
    //          ALayout  BLayout  CLayout  ADataType  BDataType  AccDataType  CDataType
    std::tuple< Row,     Col,     Row,     F16,       F16,       F32,         F16>,
    std::tuple< Row,     Col,     Row,     BF16,      BF16,      F32,         BF16>
    >;
// clang-format on
test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp (new file, 282 lines)
@@ -0,0 +1,282 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <gtest/gtest.h>
#include <hip/hip_runtime.h>
#include <iostream>
#include <string>
#include <tuple>

#include "ck_tile/host.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"

template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
    // Calculate thresholds
    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
        ck_tile::integer_divide_ceil(K, kbatch));
    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));

    // The logic below may need to become more advanced once bugs in the Stream-K Tile Partitioner
    // are resolved, because the number of WGs contributing to a macro tile in C may not be the
    // same for all macro tiles in C.

    // Calculate error due to more than 1 WG contributing to the same macro tile in C
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);
    // Use higher threshold
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}

template <typename Tuple>
class TestCkTileStreamK : public ::testing::Test
{
    protected:
    using ALayout     = std::tuple_element_t<0, Tuple>;
    using BLayout     = std::tuple_element_t<1, Tuple>;
    using CLayout     = std::tuple_element_t<2, Tuple>;
    using ADataType   = std::tuple_element_t<3, Tuple>;
    using BDataType   = std::tuple_element_t<4, Tuple>;
    using AccDataType = std::tuple_element_t<5, Tuple>;
    using CDataType   = std::tuple_element_t<6, Tuple>;
    using DsLayout    = ck_tile::tuple<>;
    using DsDataType  = ck_tile::tuple<>;

    template <ck_tile::StreamKReductionStrategy ReductionStrategy,
              bool PadM       = true,
              bool PadN       = true,
              bool PadK       = true,
              bool Preshuffle = false,
              bool TransposeC = false>
    void invoke_streamk(const ck_tile::StreamKHostArgs& args,
                        const ck_tile::stream_config& s,
                        int num_cu,
                        int occupancy)
    {
        constexpr ck_tile::index_t M_Tile = 128;
        constexpr ck_tile::index_t N_Tile = 128;
        constexpr ck_tile::index_t K_Tile = 32;

        constexpr ck_tile::index_t M_Warp = 2;
        constexpr ck_tile::index_t N_Warp = 2;
        constexpr ck_tile::index_t K_Warp = 1;

        constexpr ck_tile::index_t M_Warp_Tile = 32;
        constexpr ck_tile::index_t N_Warp_Tile = 32;
        constexpr ck_tile::index_t K_Warp_Tile = 16;

        constexpr bool kPadM      = PadM;
        constexpr bool kPadN      = PadN;
        constexpr bool kPadK      = PadK;
        constexpr bool preshuffle = Preshuffle;

        constexpr bool DoubleSmemBuffer   = false;
        constexpr int kBlockPerCu         = 1;
        constexpr bool StructuredSparsity = false;
        constexpr bool NumWaveGroup       = 1;

        using GemmShape =
            ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                                   ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                   ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

        using TilePartitioner = ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy>;

        using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
                                                                     kPadN,
                                                                     kPadK,
                                                                     DoubleSmemBuffer,
                                                                     ALayout,
                                                                     BLayout,
                                                                     CLayout,
                                                                     TransposeC,
                                                                     StructuredSparsity,
                                                                     false,
                                                                     NumWaveGroup,
                                                                     preshuffle>;

        const auto Run = [&](const auto memory_operation_) {
            constexpr auto memory_operation = memory_operation_.value;
            constexpr auto scheduler        = ck_tile::GemmPipelineScheduler::Intrawave;

            // We create the GEMM pipeline without specifying has_hot_loop or tail_num.
            // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
            // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
            // Kernel's RunGemm function. This is a pattern similar to the one used by grouped GEMM.
            using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                                               BDataType,
                                                                               AccDataType,
                                                                               GemmShape,
                                                                               GemmUniversalTraits,
                                                                               scheduler>;
            // For initial testing, we will just test with one pipeline.
            // More extensive testing is coming later and will test other pipelines.
            using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>;

            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<ADataType,
                                                 BDataType,
                                                 ck_tile::tuple<>,
                                                 AccDataType,
                                                 CDataType,
                                                 ck_tile::tuple<>,
                                                 CLayout,
                                                 ck_tile::element_wise::PassThrough,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 M_Warp,
                                                 N_Warp,
                                                 M_Warp_Tile,
                                                 N_Warp_Tile,
                                                 K_Warp_Tile,
                                                 UniversalGemmProblem::TransposeC,
                                                 memory_operation>>;

            using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

            auto kargs = Kernel::MakeKernelArgs(args, num_cu, occupancy);

            if(!Kernel::IsSupportedArgument(kargs))
            {
                EXPECT_TRUE(false);
            }

            dim3 grid_dims  = Kernel::GridSize(kargs.tile_partitioner);
            dim3 block_dims = Kernel::BlockSize();

            ck_tile::launch_kernel(
                s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grid_dims, block_dims, 0, kargs));
        };

        Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                       // Since we are doing stream K, in the case of
                                       // atomics, multiple workgroups may write to the same
                                       // output tile in the C tensor, so we must atomic add
                                       // the results (not set)
                                       ck_tile::memory_operation_enum::atomic_add>{});
    }

    public:
    // Since Stream-K is built on gfx9, the lower bound for CUs is 104. Thus, we default num_cu to
    // 104 and occupancy to 1 to ensure tests are reproducible on different architectures.
    void Run(ck_tile::index_t M,
             ck_tile::index_t N,
             ck_tile::index_t K,
             uint32_t num_sk_blocks = 0xffffffff,
             ck_tile::StreamKReductionStrategy reduction_strategy =
                 ck_tile::StreamKReductionStrategy::Atomic,
             int occupancy             = 1,
             int num_cu                = 104,
             ck_tile::index_t stride_A = 0,
             ck_tile::index_t stride_B = 0,
             ck_tile::index_t stride_C = 0)
    {
        using namespace ck_tile::literals;

        if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
        {
            throw std::runtime_error("Reduction Strategy is currently unsupported!\n");
        }

        auto f_host_tensor_descriptor = [](std::size_t row,
                                           std::size_t col,
                                           std::size_t stride,
                                           auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
            {
                return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

        auto f_get_default_stride =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(stride == 0)
                {
                    if constexpr(std::is_same_v<decltype(layout),
                                                ck_tile::tensor_layout::gemm::RowMajor>)
                    {
                        return col;
                    }
                    else
                    {
                        return row;
                    }
                }
                else
                    return stride;
            };

        stride_A = f_get_default_stride(M, K, stride_A, ALayout{});
        stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
        stride_C = f_get_default_stride(M, N, stride_C, CLayout{});

        ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
        ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
        ck_tile::HostTensor<CDataType> c_m_n_dev_result(
            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));

        ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, /*seed*/ 11939}(a_m_k);
        ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, /*seed*/ 11940}(b_k_n);

        ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
        ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
        ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

        a_m_k_dev_buf.ToDevice(a_m_k.data());
        b_k_n_dev_buf.ToDevice(b_k_n.data());
        c_m_n_dev_buf.SetZero();
        c_m_n_dev_result.SetZero();

        ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
                                      b_k_n_dev_buf.GetDeviceBuffer(),
                                      c_m_n_dev_buf.GetDeviceBuffer(),
                                      M,
                                      N,
                                      K,
                                      stride_A,
                                      stride_B,
                                      stride_C,
                                      reduction_strategy,
                                      num_sk_blocks};

        invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
            args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy);

        c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
        c_m_n_host_ref.SetZero();

        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
            a_m_k, b_k_n, c_m_n_host_ref);

        const float max_accumulated_value =
            *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
            K, /*kbatch*/ 1, max_accumulated_value);

        bool pass = ck_tile::check_err(c_m_n_dev_result,
                                       c_m_n_host_ref,
                                       "Error: Incorrect results!",
                                       rtol_atol.at(ck_tile::number<0>{}),
                                       rtol_atol.at(ck_tile::number<1>{}));

        EXPECT_TRUE(pass);
    };
};
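The `calculate_rtol_atol` helper near the top of this file combines a per-accumulation-length threshold with a multi-WG/atomics threshold by taking the element-wise maximum. A toy illustration with assumed numbers (the real values come from CK's `get_relative_threshold` / `get_absolute_threshold` helpers):

```cpp
#include <algorithm>
#include <cstdio>

// Assumed example thresholds; only the max-combination rule mirrors the code.
int main()
{
    const float rtol = 1e-3f, atol = 2e-2f;                 // accumulation-length terms
    const float rtol_split_k = 5e-3f, atol_split_k = 1e-2f; // multi-WG/atomics terms
    std::printf("combined rtol=%g atol=%g\n",
                std::max(rtol, rtol_split_k),  // 0.005
                std::max(atol, atol_split_k)); // 0.02
}
```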
test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt (new file, 9 lines)
@@ -0,0 +1,9 @@
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
    list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()

if(GPU_TARGETS MATCHES "gfx94|gfx95")
    add_gtest_executable(test_ck_tile_grouped_gemm_preshuffle test_grouped_gemm_preshuffle.cpp)
    target_compile_options(test_ck_tile_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp (new file, 58 lines)
@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <tuple>

#include "gtest/gtest.h"

#include "ck_tile/host.hpp"
#include "test_grouped_gemm_preshuffle_util.hpp"

using F16 = ck_tile::half_t;
using F8  = ck_tile::fp8_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

// Custom tuple-like structure for kernel configuration
template <typename ALayout_,
          typename BLayout_,
          typename CLayout_,
          typename ADataType_,
          typename BDataType_,
          typename AccDataType_,
          typename CDataType_,
          int M_Tile_val_,
          int N_Tile_val_,
          int K_Tile_val_,
          int BlockPerCu_val_>
struct KernelConfig
{
    using ALayoutType = ALayout_;
    using BLayoutType = BLayout_;
    using CLayoutType = CLayout_;
    using ADataType   = ADataType_;
    using BDataType   = BDataType_;
    using AccDataType = AccDataType_;
    using CDataType   = CDataType_;

    static constexpr int M_Tile_     = M_Tile_val_;
    static constexpr int N_Tile_     = N_Tile_val_;
    static constexpr int K_Tile_     = K_Tile_val_;
    static constexpr int BlockPerCu_ = BlockPerCu_val_;
};

// clang-format off
using KernelTypes = ::testing::Types<
    //           ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_Tile, N_Tile, K_Tile, BlockPerCu
    KernelConfig<    Row,     Col,     Row,       F16,       F16,         F32,       F16,     16,     64,    256,          1>,
    KernelConfig<    Row,     Col,     Row,        F8,        F8,         F32,       F16,     16,     64,    256,          1>,
    KernelConfig<    Row,     Col,     Row,       F16,       F16,         F32,       F16,    128,    128,    128,          2>,
    KernelConfig<    Row,     Col,     Row,        F8,        F8,         F32,       F16,    128,    128,    128,          2>
    >;
// clang-format on

TYPED_TEST_SUITE(TestCkTileGroupedGemmPreshuffle, KernelTypes);

#include "test_grouped_gemm_preshuffle_ut_cases.inc"
#include "test_grouped_gemm_preshuffle_prefill_cases.inc"
test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_prefill_cases.inc (new file, 61 lines)
@@ -0,0 +1,61 @@
#pragma once

// Test with prefill config struct
TYPED_TEST(TestCkTileGroupedGemmPreshuffle, PrefillVariant)
{
    const int group_count = 4;
    const int kbatch      = 1;
    std::vector<int> Ms;
    std::vector<int> Ns;
    std::vector<int> Ks;
    std::vector<int> stride_As;
    std::vector<int> stride_Bs;
    std::vector<int> stride_Cs;

    for(int i = 0; i < group_count; i++)
    {
        Ms.push_back(256 + 128 * i);
        Ns.push_back(256 + 128 * i);
        Ks.push_back(128 * (i + 1));

        stride_As.push_back(Ks[i]);
        stride_Bs.push_back(Ks[i]);
        stride_Cs.push_back(Ns[i]);
    }

    this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
}

TYPED_TEST(TestCkTileGroupedGemmPreshuffle, VariedDimensions)
{
    const int group_count = 6;
    const int kbatch      = 1;
    std::vector<int> Ms;
    std::vector<int> Ns;
    std::vector<int> Ks;
    std::vector<int> stride_As;
    std::vector<int> stride_Bs;
    std::vector<int> stride_Cs;

    std::vector<std::tuple<int, int, int>> test_cases = {{64, 128, 256},
                                                         {128, 256, 512},
                                                         {256, 512, 1024},
                                                         {512, 256, 128},
                                                         {128, 128, 128},
                                                         {64, 512, 256}};

    for(int i = 0; i < group_count; i++)
    {
        auto [M, N, K] = test_cases[i];
        Ms.push_back(M);
        Ns.push_back(N);
        Ks.push_back(K);

        stride_As.push_back(Ks[i]);
        stride_Bs.push_back(Ks[i]);
        stride_Cs.push_back(Ns[i]);
    }

    this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
}
test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_ut_cases.inc (new file, 53 lines)
@@ -0,0 +1,53 @@
#pragma once

// kPadK is not needed for these K values
TYPED_TEST(TestCkTileGroupedGemmPreshuffle, kPadKFalse)
{
    const int group_count = 4;
    const int kbatch      = 1;
    std::vector<int> Ms;
    std::vector<int> Ns;
    std::vector<int> Ks;
    std::vector<int> stride_As;
    std::vector<int> stride_Bs;
    std::vector<int> stride_Cs;

    for(int i = 0; i < group_count; i++)
    {
        Ms.push_back(256 + 256 * i);
        Ns.push_back(256 + 512 * i);
        Ks.push_back(512 + 256 * i);

        stride_As.push_back(Ks[i]);
        stride_Bs.push_back(Ks[i]);
        stride_Cs.push_back(Ns[i]);
    }

    this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
}

// kPadK must be true for these K values
TYPED_TEST(TestCkTileGroupedGemmPreshuffle, kPadKTrue)
{
    const int group_count = 4;
    const int kbatch      = 1;
    std::vector<int> Ms;
    std::vector<int> Ns;
    std::vector<int> Ks;
    std::vector<int> stride_As;
    std::vector<int> stride_Bs;
    std::vector<int> stride_Cs;

    for(int i = 0; i < group_count; i++)
    {
        Ms.push_back(256 + 256 * i);
        Ns.push_back(256 + 512 * i);
        Ks.push_back(512 + 128 * i);

        stride_As.push_back(Ks[i]);
        stride_Bs.push_back(Ks[i]);
        stride_Cs.push_back(Ns[i]);
    }

    this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
}
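The split between the two cases is divisibility by the K tile: for the K_Tile = 256 configurations in the kernel type list, `512 + 256*i` is always a multiple of the tile while `512 + 128*i` is not, which is why only the second set of K values actually exercises K padding. A quick check (K_Tile of 256 assumed, matching the 16x64x256 configs):

```cpp
#include <cstdio>

// Divisibility check behind the kPadKFalse / kPadKTrue naming above.
int main()
{
    const int K_Tile = 256; // assumption: the 16x64x256 configs in the type list
    for(int i = 0; i < 4; ++i)
    {
        const int k_false = 512 + 256 * i; // kPadKFalse K values: always divisible
        const int k_true  = 512 + 128 * i; // kPadKTrue K values: 640 and 896 are not
        std::printf("K=%-4d %%256=%-3d | K=%-4d %%256=%d\n",
                    k_false, k_false % K_Tile, k_true, k_true % K_Tile);
    }
}
```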
test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp (new file, 374 lines)
@@ -0,0 +1,374 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include <gtest/gtest.h>

#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"

template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(CK_GFX950_SUPPORT)
    if constexpr(M_Warp_Tile == 32)
        return sizeof(PrecType) == 2 ? 16 : 64;
    else
        return sizeof(PrecType) == 2 ? 32 : 128;
#else
    if constexpr(M_Warp_Tile == 32)
        return sizeof(PrecType) == 2 ? 16 : 32;
    else
        return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}

template <typename Tuple>
class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
{
    protected:
    using ALayout     = typename Tuple::ALayoutType;
    using BLayout     = typename Tuple::BLayoutType;
    using CLayout     = typename Tuple::CLayoutType;
    using ADataType   = typename Tuple::ADataType;
    using BDataType   = typename Tuple::BDataType;
    using AccDataType = typename Tuple::AccDataType;
    using CDataType   = typename Tuple::CDataType;
    using PrecType    = BDataType;
    using DsLayout    = ck_tile::tuple<>; // not used
    using DsDataType  = ck_tile::tuple<>; // not used

    static const bool kPadM = false;
    static const bool kPadN = false;
    static const bool kPadK = true; // preshuffle pipeline requires k padding

    static const int kBlockPerCu = Tuple::BlockPerCu_;

    // Tile dimensions from tuple
    static const ck_tile::index_t M_Tile = Tuple::M_Tile_;
    static const ck_tile::index_t N_Tile = Tuple::N_Tile_;
    static const ck_tile::index_t K_Tile = Tuple::K_Tile_;

    static const ck_tile::index_t M_Warp = 1;
    static const ck_tile::index_t N_Warp = 4;
    static const ck_tile::index_t K_Warp = 1;

    static const ck_tile::index_t M_Warp_Tile = 16;
    static const ck_tile::index_t N_Warp_Tile = 16;
    static const ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<BDataType, M_Warp_Tile>();

    static constexpr bool DoubleSmemBuffer = true;  // preshuffle v2 uses ping-pong smem
    static constexpr bool TransposeC       = false; // transpose c is not supported
    static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
    static constexpr ck_tile::index_t TileParitionerM01      = 4;

    template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
    auto calculate_rtol_atol(const ck_tile::index_t K,
                             const ck_tile::index_t kbatch,
                             const float max_accumulated_value)
    {
        using ComputeType =
            std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
        // Calculate thresholds
        const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
            ck_tile::integer_divide_ceil(K, kbatch));
        const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
            max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
        // Calculate error due to split_k accumulation
        const auto rtol_split_k =
            ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
        const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
            max_accumulated_value, kbatch);
        // Use higher threshold
        return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
    }

    using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
    inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
    {
        return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
    }

    template <typename T>
    auto shuffle_b(const ck_tile::HostTensor<T>& t)
    {
        assert(t.get_lengths().size() == 2);
        int n_ = t.get_lengths()[1];
        int k_ = t.get_lengths()[0];
        constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4;
        ck_tile::HostTensor<T> t_view(
            {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
        std::copy(t.begin(), t.end(), t_view.begin());
        return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
    }

    template <typename ALayout, typename BLayout, typename CLayout>
    void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
                             const ck_tile::stream_config& s,
                             void* kargs_ptr)
    {
        using GemmShape =
            ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                                   ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                   ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
        using TilePartitioner = ck_tile::
            GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;

        using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;

        // For testing purposes, we hardcode values here that are known to be compatible with the
        // pipeline.
        using GemmUniversalTraits =
            ck_tile::TileGemmUniversalTraits<kPadM,
                                             kPadN,
                                             kPadK,
                                             DoubleSmemBuffer,
                                             ALayout,
                                             BLayout,
                                             CLayout,
                                             TransposeC,
                                             /*UseStructuredSparsity*/ false,
                                             /*Persistent*/ false,
                                             /*NumWaveGroups*/ 1,
                                             /*Preshuffle*/ true>;
        using GemmPipelineProblem =
            ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

        using BaseGemmPipeline =
            ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>;

        const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile;
        const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile;
        const ck_tile::index_t num_loop =
            ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                                                       TileParitionerGroupNum,
                                                       TileParitionerM01>::GetLoopNum(K_split);
        const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

        float ave_time{0};

        const auto Run = [&](const auto has_hot_loop_,
                             const auto tail_number_,
                             const auto memory_operation_) {
            constexpr bool has_hot_loop_v   = has_hot_loop_.value;
            constexpr auto tail_number_v    = tail_number_.value;
            constexpr auto memory_operation = memory_operation_.value;
            using UniversalGemmProblem =
                ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                      BDataType,
                                                      AccDataType,
                                                      GemmShape,
                                                      GemmUniversalTraits,
                                                      ck_tile::GemmPipelineScheduler::Default,
                                                      has_hot_loop_v,
                                                      tail_number_v>;
            using GemmPipeline =
                ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<ADataType,
                                                 BDataType,
                                                 DsDataType,
                                                 AccDataType,
                                                 CDataType,
                                                 DsLayout,
                                                 CLayout,
                                                 ck_tile::element_wise::PassThrough,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 M_Warp,
                                                 N_Warp,
                                                 M_Warp_Tile,
                                                 N_Warp_Tile,
                                                 K_Warp_Tile,
                                                 UniversalGemmProblem::TransposeC,
                                                 memory_operation>>;
            using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
            auto kargs   = Kernel::MakeKargs(gemm_descs);
            EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
            const dim3 grids  = Kernel::GridSize(gemm_descs);
            const dim3 blocks = Kernel::BlockSize();

            ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
                                                         kargs.data(),
                                                         get_workspace_size(gemm_descs),
                                                         hipMemcpyHostToDevice,
                                                         s.stream_id_));

            ave_time = ck_tile::launch_kernel(
                s,
                ck_tile::make_kernel<kBlockPerCu>(
                    Kernel{},
                    grids,
                    blocks,
                    0,
                    ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
                    gemm_descs.size()));
            return ave_time;
        };

        const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
            if(gemm_descs[0].k_batch == 1)
            {
                Run(has_hot_loop_,
                    tail_number_,
                    ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                               ck_tile::memory_operation_enum::set>{});
            }
            else
            {
                // EXPECT TO FAIL because splitk is not supported
                EXPECT_FALSE(true);
            }
        };

        BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    }

    public:
    void Run(const std::vector<int>& Ms,
             const std::vector<int>& Ns,
             const std::vector<int>& Ks,
             std::vector<int>& stride_As,
             std::vector<int>& stride_Bs,
             std::vector<int>& stride_Cs,
             const int kbatch      = 1,
             const int group_count = 16)
    {
        using namespace ck_tile::literals;
        auto f_host_tensor_descriptor = [](std::size_t row,
                                           std::size_t col,
                                           std::size_t stride,
                                           auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
            {
                return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

        auto f_get_default_stride =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(stride == 0)
                {
                    if constexpr(std::is_same_v<decltype(layout),
                                                ck_tile::tensor_layout::gemm::RowMajor>)
                    {
                        return col;
                    }
                    else
                    {
                        return row;
                    }
                }
                else
                    return stride;
            };

        std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
        std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
        std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;

        a_m_k_tensors.reserve(group_count);
        b_k_n_tensors.reserve(group_count);
        c_m_n_tensors.reserve(group_count);

        std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
        std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
        std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;

        a_m_k_dev_buf.reserve(group_count);
        b_k_n_dev_buf.reserve(group_count);
        c_m_n_dev_buf.reserve(group_count);

        std::vector<grouped_gemm_kargs> gemm_descs;
        gemm_descs.reserve(group_count);

        for(int i = 0; i < group_count; ++i)
        {
            const ck_tile::index_t M = Ms[i];
            const ck_tile::index_t N = Ns[i];
            const ck_tile::index_t K = Ks[i];

            stride_As[i] = f_get_default_stride(M, K, stride_As[i], ALayout{});
            stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], BLayout{});
            stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});

            a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
                f_host_tensor_descriptor(M, K, stride_As[i], ALayout{})));
            b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
                f_host_tensor_descriptor(K, N, stride_Bs[i], BLayout{})));
            c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
                f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));

            ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
            ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);

            // Host-side preshuffle of B
            auto b_shuffle_host = shuffle_b(b_k_n_tensors[i]);

            a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
                a_m_k_tensors[i].get_element_space_size_in_bytes()));
            b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
                b_shuffle_host.get_element_space_size_in_bytes()));
            c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
                c_m_n_tensors[i].get_element_space_size_in_bytes()));

            a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
            b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
            c_m_n_dev_buf[i]->SetZero();
            c_m_n_tensors[i].SetZero();

            const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
            const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
            void* p_c       = c_m_n_dev_buf[i]->GetDeviceBuffer();

            gemm_descs.push_back(
                {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
        }

        ck_tile::DeviceMem gemm_workspace;
        gemm_workspace.Realloc(get_workspace_size(gemm_descs));

        invoke_grouped_gemm<ALayout, BLayout, CLayout>(gemm_descs,
                                                       ck_tile::stream_config{nullptr, false, 1},
                                                       gemm_workspace.GetDeviceBuffer());

        // Copy results back to host for validation
        for(int i = 0; i < group_count; i++)
        {
            c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
        }

        bool pass{true};
        for(int i = 0; i < group_count; ++i)
        {
            ck_tile::HostTensor<CDataType> c_m_n_host_ref(
                f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
            c_m_n_host_ref.SetZero();
            ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
                a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
            const float max_accumulated_value =
                *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
            const auto rtol_atol =
                calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
                    Ks[i], kbatch, max_accumulated_value);
            pass &= ck_tile::check_err(c_m_n_tensors[i],
                                       c_m_n_host_ref,
                                       "Error: Incorrect results!",
                                       rtol_atol.at(ck_tile::number<0>{}),
                                       rtol_atol.at(ck_tile::number<1>{}));
        }
        EXPECT_TRUE(pass);
    }
};