mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
This commit is contained in:
4
example/ck_tile/22_gemm_multi_abd/CMakeLists.txt
Normal file
4
example/ck_tile/22_gemm_multi_abd/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(tile_example_gemm_multi_abd_fp16 gemm_multi_abd_fp16.cpp)
|
||||
35
example/ck_tile/22_gemm_multi_abd/README.md
Normal file
35
example/ck_tile/22_gemm_multi_abd/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
#Multiple ABD GEMM
|
||||
|
||||
This folder contains example for Multiple ABD GEMM using ck_tile tile-programming implementation.
|
||||
|
||||
## build
|
||||
```
|
||||
#in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \
|
||||
leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
#The basic pipeline method on the gemm calculation
|
||||
make tile_example_gemm_multi_abd_fp16 -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_multi_abd_fp16`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m M dimensions - (Default: 3840)
|
||||
-n N dimensions - (Default: 4096)
|
||||
-k K dimensions - (Default: 4096)
|
||||
-as_layout Tensor A layout (default:R)
|
||||
-bs_layout Tensor B layout (default:C)
|
||||
-ds_layout Tensor D layout (default:R)
|
||||
-e_layout Tensor E layout (default:R)
|
||||
-stride_as Tensor A strides - (Default: 0)
|
||||
-stride_bs Tensor B strides - (Default: 0)
|
||||
-stride_e Tensor C strides - (Default: 0)
|
||||
-stride_ds Tensor D strides - (Default: 0)
|
||||
-validate 0. No validation, 1. Validation on GPU. (Default: 1)
|
||||
-warmup Number of iterations before benchmark the kernel. (Default: 10)
|
||||
-repeat Number of iterations to benchmark the kernel. (Default: 100)
|
||||
-kbatch kbatch for SplitK. (Default: 1)
|
||||
```
|
||||
137
example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp
Normal file
137
example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp
Normal file
@@ -0,0 +1,137 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_multi_abd_fp16.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_config& s) -> float
|
||||
{
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
ELayout,
|
||||
TransposeC>;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
AElementWise,
|
||||
BElementWise>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
|
||||
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
#include "run_gemm_multi_abd_fp16_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_multiple_abd_gemm_example<GemmConfigV3_Wmma>(argc, argv);
|
||||
#else
|
||||
return !run_multiple_abd_gemm_example<GemmConfigV3>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
179
example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp
Normal file
179
example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp
Normal file
@@ -0,0 +1,179 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
using A0DataType = ck_tile::half_t;
|
||||
using A1DataType = ck_tile::half_t;
|
||||
|
||||
using B0DataType = ck_tile::half_t;
|
||||
using B1DataType = ck_tile::half_t;
|
||||
|
||||
using D0DataType = ck_tile::half_t;
|
||||
using D1DataType = ck_tile::half_t;
|
||||
|
||||
using EDataType = ck_tile::half_t;
|
||||
|
||||
using AsDataType = ck_tile::tuple<A0DataType, A1DataType>;
|
||||
using BsDataType = ck_tile::tuple<B0DataType, B1DataType>;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
|
||||
using AccDataType = float;
|
||||
|
||||
struct GemmConfigMemory
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3
|
||||
{
|
||||
// Compute friendly for Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV4
|
||||
{
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3_Wmma
|
||||
{
|
||||
// Compute friendly for Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "4096", "k dimension")
|
||||
.insert("as_layout", "R", "As tensor data layout - Row by default")
|
||||
.insert("bs_layout", "C", "Bs tensor data layout - Col by default")
|
||||
.insert("ds_layout", "R", "Ds tensor data layout - Row by default")
|
||||
.insert("e_layout", "R", "E tensor data layout - Row by default")
|
||||
.insert("stride_as", "0", "Tensor A stride")
|
||||
.insert("stride_bs", "0", "Tensor B stride")
|
||||
.insert("stride_ds", "0", "Tensor Ds stride")
|
||||
.insert("stride_e", "0", "Tensor E stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on GPU")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("kbatch", "1", "kbatch for SplitK");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
using gemm_multi_abd_kargs =
|
||||
ck_tile::GemmMultiABDHostArgs<AsDataType::size(), BsDataType::size(), DsDataType::size()>;
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename AElementWise,
|
||||
typename BElementWise,
|
||||
typename CDEElementWise>
|
||||
float gemm_multi_abd(const gemm_multi_abd_kargs& kargs, const ck_tile::stream_config& s);
|
||||
@@ -0,0 +1,311 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm_multi_abd(const std::array<const void*, AsDataType::size()>& as_m_k_dev_buf,
|
||||
const std::array<const void*, BsDataType::size()>& bs_k_n_dev_buf,
|
||||
const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
|
||||
void* e_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
const std::array<ck_tile::index_t, AsDataType::size()>& StrideAs,
|
||||
const std::array<ck_tile::index_t, BsDataType::size()>& StrideBs,
|
||||
const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
|
||||
ck_tile::index_t StrideE,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
int k_batch)
|
||||
{
|
||||
gemm_multi_abd_kargs gemm_descs({as_m_k_dev_buf,
|
||||
bs_k_n_dev_buf,
|
||||
ds_m_n_dev_buf,
|
||||
e_m_n_dev_buf,
|
||||
k_batch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideDs,
|
||||
StrideE});
|
||||
|
||||
float ave_time = gemm_multi_abd<GemmConfig,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AElementWise,
|
||||
BElementWise,
|
||||
CDEElementWise>(
|
||||
gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Gemm Multiple-ABD"};
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
flop += std::size_t(2) * M * N * K;
|
||||
|
||||
num_btype +=
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm Multiple-ABD kernel with:\n";
|
||||
std::cout << "M =" << M << " N =" << N << " K =" << K << "\n";
|
||||
std::cout << "StrideA = " << StrideAs[0] << " StrideB = " << StrideBs[0]
|
||||
<< " StrideE = " << StrideE << "\n";
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< "\n";
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename A0Layout,
|
||||
typename A1Layout,
|
||||
typename B0Layout,
|
||||
typename B1Layout,
|
||||
typename D0Layout,
|
||||
typename D1Layout,
|
||||
typename ELayout>
|
||||
int run_gemm_multi_abd_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const A0Layout a0_layout = A0Layout{},
|
||||
const A1Layout a1_layout = A1Layout{},
|
||||
const B0Layout b0_layout = B0Layout{},
|
||||
const B1Layout b1_layout = B1Layout{},
|
||||
const D0Layout d0_layout = D0Layout{},
|
||||
const D1Layout d1_layout = D1Layout{},
|
||||
const ELayout e_layout = ELayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
using AElementWiseFn = ck_tile::element_wise::AddScale;
|
||||
using BElementWiseFn = ck_tile::element_wise::AddScale;
|
||||
using CDEElementWiseFn = ck_tile::element_wise::MultiDMultiply;
|
||||
using AsLayout = ck_tile::tuple<A0Layout, A1Layout>;
|
||||
using BsLayout = ck_tile::tuple<B0Layout, B1Layout>;
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t StrideA = arg_parser.get_int("stride_as");
|
||||
ck_tile::index_t StrideB = arg_parser.get_int("stride_bs");
|
||||
ck_tile::index_t StrideD = arg_parser.get_int("stride_ds");
|
||||
ck_tile::index_t StrideE = arg_parser.get_int("stride_e");
|
||||
|
||||
ck_tile::index_t StrideA0 = StrideA;
|
||||
ck_tile::index_t StrideA1 = StrideA;
|
||||
|
||||
ck_tile::index_t StrideB0 = StrideB;
|
||||
ck_tile::index_t StrideB1 = StrideB;
|
||||
|
||||
ck_tile::index_t StrideD0 = StrideD;
|
||||
ck_tile::index_t StrideD1 = StrideD;
|
||||
|
||||
const int n_warmup = arg_parser.get_int("warmup");
|
||||
const int n_repeat = arg_parser.get_int("repeat");
|
||||
const int k_batch = arg_parser.get_int("kbatch");
|
||||
|
||||
StrideA0 = get_default_stride(M, N, StrideA0, is_row_major(a1_layout));
|
||||
StrideA1 = get_default_stride(M, N, StrideA1, is_row_major(a1_layout));
|
||||
|
||||
StrideB0 = get_default_stride(K, N, StrideB0, is_row_major(b0_layout));
|
||||
StrideB1 = get_default_stride(K, N, StrideB1, is_row_major(b1_layout));
|
||||
|
||||
StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout));
|
||||
StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout));
|
||||
|
||||
StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout));
|
||||
|
||||
ck_tile::HostTensor<A0DataType> a0_m_k_tesnor(
|
||||
host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout)));
|
||||
ck_tile::HostTensor<A1DataType> a1_m_k_tesnor(
|
||||
host_tensor_descriptor(M, K, StrideA1, is_row_major(a1_layout)));
|
||||
|
||||
ck_tile::HostTensor<B0DataType> b0_k_n_tensors(
|
||||
host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout)));
|
||||
ck_tile::HostTensor<B1DataType> b1_k_n_tensors(
|
||||
host_tensor_descriptor(K, N, StrideB1, is_row_major(b1_layout)));
|
||||
|
||||
ck_tile::HostTensor<D0DataType> d0_m_n_tensors(
|
||||
host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout)));
|
||||
ck_tile::HostTensor<D1DataType> d1_m_n_tensors(
|
||||
host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout)));
|
||||
|
||||
ck_tile::HostTensor<EDataType> e_m_n_device_result(
|
||||
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
|
||||
|
||||
ck_tile::FillUniformDistribution<A0DataType>{-1.f, 1.f}(a0_m_k_tesnor);
|
||||
ck_tile::FillUniformDistribution<A1DataType>{-1.f, 1.f}(a1_m_k_tesnor);
|
||||
|
||||
ck_tile::FillUniformDistribution<B0DataType>{-1.f, 1.f}(b0_k_n_tensors);
|
||||
ck_tile::FillUniformDistribution<B1DataType>{-1.f, 1.f}(b1_k_n_tensors);
|
||||
|
||||
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n_tensors);
|
||||
ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n_tensors);
|
||||
|
||||
ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tesnor.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tesnor.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
|
||||
|
||||
a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data());
|
||||
a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data());
|
||||
|
||||
b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data());
|
||||
b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data());
|
||||
|
||||
d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data());
|
||||
d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data());
|
||||
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_device_result.SetZero();
|
||||
|
||||
std::array<const void*, DsDataType::size()> as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(),
|
||||
a1_m_k_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<const void*, DsDataType::size()> bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(),
|
||||
b1_k_n_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
|
||||
d1_m_n_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck_tile::index_t, AsDataType::size()> strideAs = {StrideA0, StrideA1};
|
||||
std::array<ck_tile::index_t, BsDataType::size()> strideBs = {StrideB0, StrideB1};
|
||||
std::array<ck_tile::index_t, DsDataType::size()> strideDs = {StrideD0, StrideD1};
|
||||
|
||||
invoke_gemm_multi_abd<GemmConfig,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AElementWiseFn,
|
||||
BElementWiseFn,
|
||||
CDEElementWiseFn>(as_ptr_buf,
|
||||
bs_ptr_buf,
|
||||
ds_ptr_buf,
|
||||
e_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
strideAs,
|
||||
strideBs,
|
||||
strideDs,
|
||||
StrideE,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
k_batch);
|
||||
|
||||
e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
|
||||
|
||||
ck_tile::HostTensor<A0DataType> a_m_k_host_ref_element_result(
|
||||
host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout)));
|
||||
ck_tile::HostTensor<B0DataType> b_k_n_host_ref_element_result(
|
||||
host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout)));
|
||||
ck_tile::HostTensor<EDataType> e_m_n_host_ref(
|
||||
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
|
||||
a_m_k_host_ref_element_result.SetZero();
|
||||
b_k_n_host_ref_element_result.SetZero();
|
||||
e_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm_multiple_abd<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AElementWiseFn,
|
||||
BElementWiseFn,
|
||||
CDEElementWiseFn>({a0_m_k_tesnor, a1_m_k_tesnor},
|
||||
{b0_k_n_tensors, b1_k_n_tensors},
|
||||
{d0_m_n_tensors, d1_m_n_tensors},
|
||||
a_m_k_host_ref_element_result,
|
||||
b_k_n_host_ref_element_result,
|
||||
e_m_n_host_ref);
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("v"))
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end());
|
||||
|
||||
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
|
||||
|
||||
pass &= ck_tile::check_err(e_m_n_device_result,
|
||||
e_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< std::endl;
|
||||
std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig>
|
||||
int run_multiple_abd_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string as_layout = arg_parser.get_str("as_layout");
|
||||
const std::string bs_layout = arg_parser.get_str("bs_layout");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(as_layout == "R" && bs_layout == "C")
|
||||
{
|
||||
return run_gemm_multi_abd_example_with_layouts<GemmConfig>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
38
example/ck_tile/22_gemm_multi_abd/utils.hpp
Normal file
38
example/ck_tile/22_gemm_multi_abd/utils.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeTypeAB =
|
||||
std::conditional_t<sizeof(A0DataType) < sizeof(B0DataType), A0DataType, B0DataType>;
|
||||
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
|
||||
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
Reference in New Issue
Block a user