This commit is contained in:
Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Build the fp16 multiple-ABD GEMM example executable from its single source file.
add_executable(tile_example_gemm_multi_abd_fp16 gemm_multi_abd_fp16.cpp)

View File

@@ -0,0 +1,35 @@
# Multiple ABD GEMM
This folder contains an example of Multiple ABD GEMM using the ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942),
# or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# build the fp16 Multiple ABD GEMM example
make tile_example_gemm_multi_abd_fp16 -j
```
This will result in an executable `build/bin/tile_example_gemm_multi_abd_fp16`
## example
```
args:
-m M dimensions - (Default: 3840)
-n N dimensions - (Default: 4096)
-k K dimensions - (Default: 4096)
-as_layout Tensor A layout (default:R)
-bs_layout Tensor B layout (default:C)
-ds_layout Tensor D layout (default:R)
-e_layout Tensor E layout (default:R)
-stride_as Tensor A strides - (Default: 0)
-stride_bs Tensor B strides - (Default: 0)
-stride_e Tensor E strides - (Default: 0)
-stride_ds Tensor D strides - (Default: 0)
-v 0. No validation, 1. Validate against the CPU reference. (Default: 1)
-warmup Number of iterations before benchmark the kernel. (Default: 50)
-repeat Number of iterations to benchmark the kernel. (Default: 100)
-kbatch kbatch for SplitK. (Default: 1)
```

View File

@@ -0,0 +1,137 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_multi_abd_fp16.hpp"
#include "utils.hpp"
/// Assembles the multi-ABD GEMM kernel type described by `GemmConfig` and launches it.
///
/// @param args  Host-side problem description. It is forwarded to
///              Kernel::MakeKernelArgs, and its M, N and k_batch members size the grid.
/// @param s     Stream configuration handed to ck_tile::launch_kernel. When its
///              `log_level_` is positive the launch dimensions are printed first.
/// @return      Whatever ck_tile::launch_kernel returns for this launch.
/// @throws std::runtime_error when Kernel::IsSupportedArgument rejects the kernel arguments.
template <typename GemmConfig,
          typename AsDataType,
          typename BsDataType,
          typename DsDataType,
          typename AccDataType,
          typename EDataType,
          typename AsLayout,
          typename BsLayout,
          typename DsLayout,
          typename ELayout,
          typename AElementWise   = ck_tile::element_wise::PassThrough,
          typename BElementWise   = ck_tile::element_wise::PassThrough,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_config& s) -> float
{
    // Fixed (non-configurable) traits of this example.
    constexpr bool kPadM            = false;
    constexpr bool kPadN            = false;
    constexpr bool kPadK            = false;
    constexpr bool kTransposeC      = false;
    constexpr bool kDoubleSmem      = GemmConfig::DoubleSmemBuffer;
    constexpr int kBlocksPerCu      = 1;
    constexpr auto kScheduler       = GemmConfig::Scheduler;
    constexpr ck_tile::index_t kPartitionerGroupNum = 8;
    constexpr ck_tile::index_t kPartitionerM01      = 4;

    // Tile shape: the three sequences come straight from the GemmConfig constants.
    using BlockTileSeq =
        ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>;
    using BlockWarpsSeq =
        ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>;
    using WarpTileSeq = ck_tile::
        sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>;
    using GemmShape = ck_tile::TileGemmShape<BlockTileSeq, BlockWarpsSeq, WarpTileSeq>;

    using TilePartitioner = ck_tile::
        GemmSpatiallyLocalTilePartitioner<GemmShape, kPartitionerGroupNum, kPartitionerM01>;

    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
                                                                 kPadN,
                                                                 kPadK,
                                                                 kDoubleSmem,
                                                                 AsLayout,
                                                                 BsLayout,
                                                                 ELayout,
                                                                 kTransposeC>;

    using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<AsDataType,
                                                                       BsDataType,
                                                                       AccDataType,
                                                                       GemmShape,
                                                                       GemmUniversalTraits,
                                                                       kScheduler,
                                                                       AElementWise,
                                                                       BElementWise>;

    // The pipeline implementation is picked by the GemmConfig::Pipeline enumerator.
    using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
        UniversalGemmProblem>;

    using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<AsDataType,
                                                             BsDataType,
                                                             DsDataType,
                                                             AccDataType,
                                                             EDataType,
                                                             DsLayout,
                                                             ELayout,
                                                             CDEElementWise,
                                                             TilePartitioner::MPerBlock,
                                                             TilePartitioner::NPerBlock,
                                                             GemmConfig::M_Warp,
                                                             GemmConfig::N_Warp,
                                                             GemmConfig::M_Warp_Tile,
                                                             GemmConfig::N_Warp_Tile,
                                                             GemmConfig::K_Warp_Tile,
                                                             UniversalGemmProblem::TransposeC>;
    using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;

    using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;

    auto kernel_args     = Kernel::MakeKernelArgs(args);
    const dim3 grid_dim  = Kernel::GridSize(args.M, args.N, args.k_batch);
    const dim3 block_dim = Kernel::BlockSize();

    // Refuse to launch a configuration the kernel reports as unsupported.
    if(!Kernel::IsSupportedArgument(kernel_args))
    {
        throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
    }

    if(s.log_level_ > 0)
    {
        std::cout << "Launching kernel with args: grid: {" << grid_dim.x << ", " << grid_dim.y
                  << ", " << grid_dim.z << "}, blocks: {" << block_dim.x << ", " << block_dim.y
                  << ", " << block_dim.z << "}" << std::endl;
    }

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<kBlocksPerCu>(Kernel{}, grid_dim, block_dim, 0, kernel_args));
}
#include "run_gemm_multi_abd_fp16_example.inc"
/// Program entry point: runs the multiple-ABD GEMM example with the WMMA configuration
/// when CK_TILE_USE_WMMA is set, otherwise with GemmConfigV3.
///
/// @return 0 when the example reports success, 1 otherwise.
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
    const int status = run_multiple_abd_gemm_example<GemmConfigV3_Wmma>(argc, argv);
#else
    const int status = run_multiple_abd_gemm_example<GemmConfigV3>(argc, argv);
#endif
    // The runner yields the validation flag (1 = pass, 0 = fail) or -1 when argument
    // parsing fails. A plain `!status` would turn -1 into exit code 0, so only a
    // strictly positive status is treated as success.
    // NOTE(review): if the argument parser also reports failure for a help request,
    // that path now exits with 1 as well - confirm this is acceptable.
    return status > 0 ? 0 : 1;
}

View File

@@ -0,0 +1,179 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
// Element types of the individual input/output tensors of this example.
// Every A, B and D tensor as well as the output E uses ck_tile::half_t.
using A0DataType = ck_tile::half_t;
using A1DataType = ck_tile::half_t;
using B0DataType = ck_tile::half_t;
using B1DataType = ck_tile::half_t;
using D0DataType = ck_tile::half_t;
using D1DataType = ck_tile::half_t;
using EDataType = ck_tile::half_t;
// Tuples grouping the two A, two B and two D element types; their size() is used
// elsewhere to size the pointer and stride arrays.
using AsDataType = ck_tile::tuple<A0DataType, A1DataType>;
using BsDataType = ck_tile::tuple<B0DataType, B1DataType>;
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
// Type passed as the accumulator type to the pipeline and epilogue problems.
using AccDataType = float;
// Compile-time GEMM configuration selecting the MEMORY pipeline with the Interwave
// scheduler. The *_Tile, *_Warp and *_Warp_Tile constants are the three index
// sequences gemm_multi_abd passes to ck_tile::TileGemmShape.
struct GemmConfigMemory
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 8;
// Forwarded to ck_tile::TileGemmUniversalTraits by gemm_multi_abd.
static constexpr bool DoubleSmemBuffer = false;
// Key used to look up the pipeline type in PipelineTypeTraits.
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
// Compile-time GEMM configuration selecting the COMPUTE_V3 pipeline with the
// Intrawave scheduler. This is the configuration main() uses when
// CK_TILE_USE_WMMA is not set.
struct GemmConfigV3
{
// Compute friendly for Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
// Forwarded to ck_tile::TileGemmUniversalTraits by gemm_multi_abd.
static constexpr bool DoubleSmemBuffer = false;
// Key used to look up the pipeline type in PipelineTypeTraits.
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
// Compile-time GEMM configuration selecting the COMPUTE_V4 pipeline with the
// Intrawave scheduler. It is the only configuration here that sets
// DoubleSmemBuffer to true.
struct GemmConfigV4
{
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
// Forwarded to ck_tile::TileGemmUniversalTraits by gemm_multi_abd.
static constexpr bool DoubleSmemBuffer = true;
// Key used to look up the pipeline type in PipelineTypeTraits.
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
// Compile-time GEMM configuration selecting the COMPUTE_V3 pipeline with the
// Intrawave scheduler, using 128x128x64 tiles and 16x16x16 warp tiles. This is the
// configuration main() uses when CK_TILE_USE_WMMA is set.
struct GemmConfigV3_Wmma
{
// Compute friendly for Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
// Forwarded to ck_tile::TileGemmUniversalTraits by gemm_multi_abd.
static constexpr bool DoubleSmemBuffer = false;
// Key used to look up the pipeline type in PipelineTypeTraits.
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
// Maps a ck_tile::GemmPipeline enumerator to the pipeline class templates to use.
// The primary template is only declared; each supported enumerator has an explicit
// specialization, so an unsupported value fails to compile when it is used.
template <ck_tile::GemmPipeline PipelineId>
struct PipelineTypeTraits;
// Specialization for GemmPipeline::MEMORY: exposes the AgBgCrMem pipeline and its
// base pipeline as alias templates over the pipeline problem type.
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
// Specialization for GemmPipeline::COMPUTE_V3: exposes the AgBgCrCompV3 pipeline and
// its base pipeline as alias templates over the pipeline problem type.
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
// Specialization for GemmPipeline::COMPUTE_V4: exposes the AgBgCrCompV4 pipeline and
// its base pipeline as alias templates over the pipeline problem type.
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
/// Registers the example's command-line options with their defaults and parses argv.
///
/// Declared `inline` because this definition lives in a header; without it, including
/// the header from more than one translation unit produces duplicate definitions.
///
/// @param argc  Argument count as received by main().
/// @param argv  Argument vector as received by main().
/// @return      A std::tuple of (parse result, populated ck_tile::ArgParser).
inline auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3840", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("k", "4096", "k dimension")
        .insert("as_layout", "R", "As tensor data layout - Row by default")
        .insert("bs_layout", "C", "Bs tensor data layout - Col by default")
        .insert("ds_layout", "R", "Ds tensor data layout - Row by default")
        .insert("e_layout", "R", "E tensor data layout - Row by default")
        // A stride of 0 is later replaced through get_default_stride by the caller.
        .insert("stride_as", "0", "Tensor A stride")
        .insert("stride_bs", "0", "Tensor B stride")
        .insert("stride_ds", "0", "Tensor Ds stride")
        .insert("stride_e", "0", "Tensor E stride")
        .insert("v", "1", "0. No validation, 1. Validation on GPU")
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("kbatch", "1", "kbatch for SplitK");

    const bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Host-side argument bundle for the multi-ABD GEMM, parameterized by the number of
// A, B and D tensors (taken from the sizes of the data-type tuples).
using gemm_multi_abd_kargs =
ck_tile::GemmMultiABDHostArgs<AsDataType::size(), BsDataType::size(), DsDataType::size()>;
/// Declaration of the multi-ABD GEMM launcher.
///
/// Template and function parameter names follow the definition (EDataType / ELayout /
/// args) so the declaration and the definition read the same.
///
/// @param args  Host-side problem description (pointers, sizes, strides, k_batch).
/// @param s     Stream configuration used for the launch.
/// @return      The timing value produced by the kernel launch.
template <typename GemmConfig,
          typename AsDataType,
          typename BsDataType,
          typename DsDataType,
          typename AccDataType,
          typename EDataType,
          typename AsLayout,
          typename BsLayout,
          typename DsLayout,
          typename ELayout,
          typename AElementWise,
          typename BElementWise,
          typename CDEElementWise>
float gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_config& s);

View File

@@ -0,0 +1,311 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstddef>
/// Packs the device pointers, sizes and strides into the host argument struct, runs
/// gemm_multi_abd with timing enabled and prints a short performance report.
///
/// @param as_m_k_dev_buf  Device pointers of the A tensors.
/// @param bs_k_n_dev_buf  Device pointers of the B tensors.
/// @param ds_m_n_dev_buf  Device pointers of the D tensors.
/// @param e_m_n_dev_buf   Device pointer of the output tensor E.
/// @param M,N,K           GEMM problem sizes.
/// @param StrideAs,StrideBs,StrideDs,StrideE  Strides of the respective tensors.
/// @param n_warmup        Warm-up iteration count passed to the stream config.
/// @param n_repeat        Timed iteration count passed to the stream config.
/// @param k_batch         Split-K factor stored in the host arguments.
/// @return                The time value returned by gemm_multi_abd (printed as "ms").
template <typename GemmConfig,
          typename AsDataType,
          typename BsDataType,
          typename DsDataType,
          typename AccDataType,
          typename EDataType,
          typename AsLayout,
          typename BsLayout,
          typename DsLayout,
          typename ELayout,
          typename AElementWise   = ck_tile::element_wise::PassThrough,
          typename BElementWise   = ck_tile::element_wise::PassThrough,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm_multi_abd(const std::array<const void*, AsDataType::size()>& as_m_k_dev_buf,
                            const std::array<const void*, BsDataType::size()>& bs_k_n_dev_buf,
                            const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
                            void* e_m_n_dev_buf,
                            ck_tile::index_t M,
                            ck_tile::index_t N,
                            ck_tile::index_t K,
                            const std::array<ck_tile::index_t, AsDataType::size()>& StrideAs,
                            const std::array<ck_tile::index_t, BsDataType::size()>& StrideBs,
                            const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
                            ck_tile::index_t StrideE,
                            int n_warmup,
                            int n_repeat,
                            int k_batch)
{
    gemm_multi_abd_kargs host_args({as_m_k_dev_buf,
                                    bs_k_n_dev_buf,
                                    ds_m_n_dev_buf,
                                    e_m_n_dev_buf,
                                    k_batch,
                                    M,
                                    N,
                                    K,
                                    StrideAs,
                                    StrideBs,
                                    StrideDs,
                                    StrideE});

    const float ave_time = gemm_multi_abd<GemmConfig,
                                          AsDataType,
                                          BsDataType,
                                          DsDataType,
                                          AccDataType,
                                          EDataType,
                                          AsLayout,
                                          BsLayout,
                                          DsLayout,
                                          ELayout,
                                          AElementWise,
                                          BElementWise,
                                          CDEElementWise>(
        host_args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

    // Leading std::size_t operand keeps the products in size_t instead of index_t.
    const std::size_t flop = std::size_t(2) * M * N * K;
    // NOTE(review): only A0, B0 and E are counted here; the traffic of A1, B1, D0 and D1
    // is not part of the reported bandwidth - confirm whether that is intended.
    const std::size_t num_btype =
        sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;

    const float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    const float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Run Gemm Multiple-ABD kernel with:\n";
    std::cout << "M =" << M << " N =" << N << " K =" << K << "\n";
    std::cout << "StrideA = " << StrideAs[0] << " StrideB = " << StrideBs[0]
              << " StrideE = " << StrideE << "\n";
    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << "\n";

    return ave_time;
}
template <typename GemmConfig,
typename A0Layout,
typename A1Layout,
typename B0Layout,
typename B1Layout,
typename D0Layout,
typename D1Layout,
typename ELayout>
int run_gemm_multi_abd_example_with_layouts(int argc,
char* argv[],
const A0Layout a0_layout = A0Layout{},
const A1Layout a1_layout = A1Layout{},
const B0Layout b0_layout = B0Layout{},
const B1Layout b1_layout = B1Layout{},
const D0Layout d0_layout = D0Layout{},
const D1Layout d1_layout = D1Layout{},
const ELayout e_layout = ELayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
using AElementWiseFn = ck_tile::element_wise::AddScale;
using BElementWiseFn = ck_tile::element_wise::AddScale;
using CDEElementWiseFn = ck_tile::element_wise::MultiDMultiply;
using AsLayout = ck_tile::tuple<A0Layout, A1Layout>;
using BsLayout = ck_tile::tuple<B0Layout, B1Layout>;
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t StrideA = arg_parser.get_int("stride_as");
ck_tile::index_t StrideB = arg_parser.get_int("stride_bs");
ck_tile::index_t StrideD = arg_parser.get_int("stride_ds");
ck_tile::index_t StrideE = arg_parser.get_int("stride_e");
ck_tile::index_t StrideA0 = StrideA;
ck_tile::index_t StrideA1 = StrideA;
ck_tile::index_t StrideB0 = StrideB;
ck_tile::index_t StrideB1 = StrideB;
ck_tile::index_t StrideD0 = StrideD;
ck_tile::index_t StrideD1 = StrideD;
const int n_warmup = arg_parser.get_int("warmup");
const int n_repeat = arg_parser.get_int("repeat");
const int k_batch = arg_parser.get_int("kbatch");
StrideA0 = get_default_stride(M, N, StrideA0, is_row_major(a1_layout));
StrideA1 = get_default_stride(M, N, StrideA1, is_row_major(a1_layout));
StrideB0 = get_default_stride(K, N, StrideB0, is_row_major(b0_layout));
StrideB1 = get_default_stride(K, N, StrideB1, is_row_major(b1_layout));
StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout));
StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout));
StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout));
ck_tile::HostTensor<A0DataType> a0_m_k_tesnor(
host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout)));
ck_tile::HostTensor<A1DataType> a1_m_k_tesnor(
host_tensor_descriptor(M, K, StrideA1, is_row_major(a1_layout)));
ck_tile::HostTensor<B0DataType> b0_k_n_tensors(
host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout)));
ck_tile::HostTensor<B1DataType> b1_k_n_tensors(
host_tensor_descriptor(K, N, StrideB1, is_row_major(b1_layout)));
ck_tile::HostTensor<D0DataType> d0_m_n_tensors(
host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout)));
ck_tile::HostTensor<D1DataType> d1_m_n_tensors(
host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout)));
ck_tile::HostTensor<EDataType> e_m_n_device_result(
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
ck_tile::FillUniformDistribution<A0DataType>{-1.f, 1.f}(a0_m_k_tesnor);
ck_tile::FillUniformDistribution<A1DataType>{-1.f, 1.f}(a1_m_k_tesnor);
ck_tile::FillUniformDistribution<B0DataType>{-1.f, 1.f}(b0_k_n_tensors);
ck_tile::FillUniformDistribution<B1DataType>{-1.f, 1.f}(b1_k_n_tensors);
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n_tensors);
ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n_tensors);
ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tesnor.get_element_space_size_in_bytes());
ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tesnor.get_element_space_size_in_bytes());
ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data());
a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data());
b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data());
b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data());
d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data());
d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data());
e_m_n_dev_buf.SetZero();
e_m_n_device_result.SetZero();
std::array<const void*, DsDataType::size()> as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(),
a1_m_k_dev_buf.GetDeviceBuffer()};
std::array<const void*, DsDataType::size()> bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(),
b1_k_n_dev_buf.GetDeviceBuffer()};
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
d1_m_n_dev_buf.GetDeviceBuffer()};
std::array<ck_tile::index_t, AsDataType::size()> strideAs = {StrideA0, StrideA1};
std::array<ck_tile::index_t, BsDataType::size()> strideBs = {StrideB0, StrideB1};
std::array<ck_tile::index_t, DsDataType::size()> strideDs = {StrideD0, StrideD1};
invoke_gemm_multi_abd<GemmConfig,
AsDataType,
BsDataType,
DsDataType,
AccDataType,
EDataType,
AsLayout,
BsLayout,
DsLayout,
ELayout,
AElementWiseFn,
BElementWiseFn,
CDEElementWiseFn>(as_ptr_buf,
bs_ptr_buf,
ds_ptr_buf,
e_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
strideAs,
strideBs,
strideDs,
StrideE,
n_warmup,
n_repeat,
k_batch);
e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
ck_tile::HostTensor<A0DataType> a_m_k_host_ref_element_result(
host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout)));
ck_tile::HostTensor<B0DataType> b_k_n_host_ref_element_result(
host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout)));
ck_tile::HostTensor<EDataType> e_m_n_host_ref(
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
a_m_k_host_ref_element_result.SetZero();
b_k_n_host_ref_element_result.SetZero();
e_m_n_host_ref.SetZero();
ck_tile::reference_gemm_multiple_abd<AsDataType,
BsDataType,
DsDataType,
AccDataType,
EDataType,
AElementWiseFn,
BElementWiseFn,
CDEElementWiseFn>({a0_m_k_tesnor, a1_m_k_tesnor},
{b0_k_n_tensors, b1_k_n_tensors},
{d0_m_n_tensors, d1_m_n_tensors},
a_m_k_host_ref_element_result,
b_k_n_host_ref_element_result,
e_m_n_host_ref);
bool pass{true};
if(arg_parser.get_int("v"))
{
const float max_accumulated_value =
*std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
pass &= ck_tile::check_err(e_m_n_device_result,
e_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< std::endl;
std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}
/// Dispatches the example on the layout strings given on the command line.
///
/// Only A row-major ("R") combined with B column-major ("C") is handled; the example
/// is then run with row-major D0, D1 and E.
///
/// @return -1 if argument parsing fails, otherwise the result of
///         run_gemm_multi_abd_example_with_layouts.
/// @throws std::runtime_error for any other as_layout / bs_layout combination.
template <typename GemmConfig>
int run_multiple_abd_gemm_example(int argc, char* argv[])
{
    auto [parsed_ok, parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        return -1;
    }

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    const std::string a_layout_name = parser.get_str("as_layout");
    const std::string b_layout_name = parser.get_str("bs_layout");

    const bool is_supported = (a_layout_name == "R") && (b_layout_name == "C");
    if(!is_supported)
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    }

    // Argument order: A0, A1, B0, B1, D0, D1, E.
    return run_gemm_multi_abd_example_with_layouts<GemmConfig>(
        argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}, Row{}, Row{});
}

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
/// Computes the relative and absolute error thresholds used to validate the GEMM.
///
/// Declared `inline` because this definition lives in a header; without it, including
/// the header from more than one translation unit produces duplicate definitions.
///
/// @param K                      Reduction length of the GEMM.
/// @param kbatch                 Split-K factor; K is divided (rounding up) by it.
/// @param max_accumulated_value  Largest value of the reference result.
/// @return ck_tile tuple (rtol, atol), each the larger of the per-batch threshold and
///         the split-K accumulation threshold.
inline auto calculate_rtol_atol(const ck_tile::index_t K,
                                const ck_tile::index_t kbatch,
                                const float max_accumulated_value)
{
    // Pick the input type with the smallest sizeof among A0, B0 and D0.
    using ComputeTypeAB =
        std::conditional_t<sizeof(A0DataType) < sizeof(B0DataType), A0DataType, B0DataType>;
    using ComputeType =
        std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;

    // Thresholds for the accumulation performed inside one K batch.
    const auto k_per_batch = ck_tile::integer_divide_ceil(K, kbatch);
    const auto rtol =
        ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(k_per_batch);
    const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
        max_accumulated_value / kbatch, k_per_batch);

    // Thresholds for the error due to split_k accumulation.
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
    const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
        max_accumulated_value, kbatch);

    // Use the higher threshold of each pair.
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}