mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
This commit is contained in:
9
example/ck_tile/19_gemm_multi_d/CMakeLists.txt
Normal file
9
example/ck_tile/19_gemm_multi_d/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(tile_example_gemm_multi_d_fp16 gemm_multi_d_fp16.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
target_compile_options(tile_example_gemm_multi_d_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
97
example/ck_tile/19_gemm_multi_d/README.md
Normal file
97
example/ck_tile/19_gemm_multi_d/README.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# Multiple D GEMM with CK Tile
|
||||
|
||||
This example demonstrates GEMM with multiple D tensors (multi-output GEMM) using the CK Tile programming model. This is useful for fused operations where the GEMM output is combined with multiple side inputs (e.g., bias, residual, or other elementwise sources).
|
||||
|
||||
---
|
||||
|
||||
## Algorithm and Math
|
||||
|
||||
Given:
|
||||
- $A$: $[M, K]$
|
||||
- $B$: $[K, N]$
|
||||
- $D_0, D_1, ..., D_n$: $[M, N]$ (multiple side inputs)
|
||||
- $E$: $[M, N]$ (output)
|
||||
|
||||
The operation:
|
||||
$$
|
||||
E = f(A \times B, D_0, D_1, ..., D_n)
|
||||
$$
|
||||
where $f$ is a fused elementwise function (e.g., add, multiply, activation).
|
||||
|
||||
- **Tilewise Multi-D GEMM**: Each thread block processes a tile of $E$, loading corresponding tiles from $A$, $B$, and all $D_i$, performing blockwise GEMM and fused elementwise operations.
|
||||
|
||||
---
|
||||
|
||||
## Tile Programming Model
|
||||
|
||||
- **Tiles**: Each thread block processes a tile of $E$.
|
||||
- **Pipeline**: Modular, supports different memory/computation pipelines and multi-D fusion.
|
||||
|
||||
---
|
||||
|
||||
## Features
|
||||
|
||||
- **Multiple D Inputs**: Supports arbitrary number of side inputs for fusion.
|
||||
- **Flexible Layouts**: Supports row/column-major and custom strides for all tensors.
|
||||
- **SplitK**: Supports K-batching for large K dimensions.
|
||||
- **Validation**: GPU validation and benchmarking options.
|
||||
|
||||
---
|
||||
|
||||
## Build & Run
|
||||
|
||||
```
|
||||
#in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \
|
||||
leave it blank
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
#The basic pipeline method on the gemm calculation
|
||||
make tile_example_gemm_multi_d_fp16 -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_multi_d_fp16`
|
||||
|
||||
### Arguments
|
||||
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:3840)
|
||||
-n n dimension (default:4096)
|
||||
-k k dimension (default:4096)
|
||||
-a_layout A tensor data layout - Row by default (default:R)
|
||||
-b_layout B tensor data layout - Col by default (default:C)
|
||||
-ds_layout Ds tensor data layout - Row by default (default:R)
|
||||
-e_layout E tensor data layout - Row by default (default:R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_ds Tensor Ds stride (default:0)
|
||||
-stride_e Tensor E stride (default:0)
|
||||
-v 0. No validation, 1. Validation on GPU (default:1)
|
||||
-warmup number of iterations before benchmark the kernel (default:50)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-kbatch kbatch for SplitK (default:1)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:cktile_gemm_multi_d_fp16.json)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Source Structure
|
||||
|
||||
- **Kernel**: [`gemm_multi_d_fp16.hpp`](gemm_multi_d_fp16.hpp) (tile-programming kernel template)
|
||||
- **Executable**: [`gemm_multi_d_fp16.cpp`](gemm_multi_d_fp16.cpp)
|
||||
- **Utils**: [`utils.hpp`](utils.hpp)
|
||||
- **Build**: `CMakeLists.txt`, `run_gemm_multi_d_fp16_example.inc`
|
||||
|
||||
---
|
||||
|
||||
## Related CK Tile Examples
|
||||
|
||||
- [03_gemm](../03_gemm/README.md): Single GEMM with tiles
|
||||
- [16_batched_gemm](../16_batched_gemm/README.md): Batched GEMM with tiles
|
||||
- [17_grouped_gemm](../17_grouped_gemm/README.md): Grouped GEMM with tiles
|
||||
|
||||
For distribution, see [`include/ck_tile/tile_engine/`](../../../include/ck_tile/tile_engine/) and [`include/ck_tile/tile_program/tile_distribution/`](../../../include/ck_tile/tile_program/tile_distribution/).
|
||||
|
||||
---
|
||||
[Back to CK Tile Examples](../README.md)
|
||||
134
example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp
Normal file
134
example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_multi_d_fp16.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& s) -> float
|
||||
{
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
|
||||
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
#include "run_gemm_multi_d_fp16_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_multiple_d_gemm_example<GemmConfigV3_Wmma>(argc, argv);
|
||||
#else
|
||||
return !run_multiple_d_gemm_example<GemmConfigV3>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
170
example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp
Normal file
170
example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp
Normal file
@@ -0,0 +1,170 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using D0DataType = ck_tile::half_t;
|
||||
using D1DataType = ck_tile::half_t;
|
||||
using EDataType = ck_tile::half_t;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
using AccDataType = float;
|
||||
|
||||
struct GemmConfigMemory
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3
|
||||
{
|
||||
// Compute friendly for Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV4
|
||||
{
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3_Wmma
|
||||
{
|
||||
// Compute friendly for Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "4096", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Col by default")
|
||||
.insert("ds_layout", "R", "Ds tensor data layout - Row by default")
|
||||
.insert("e_layout", "R", "E tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_ds", "0", "Tensor Ds stride")
|
||||
.insert("stride_e", "0", "Tensor E stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on GPU")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "cktile_gemm_multi_d_fp16.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
using gemm_multi_d_kargs = ck_tile::GemmMultiDHostArgs<DsDataType::size()>;
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise>
|
||||
float gemm_multi_d(const gemm_multi_d_kargs& kargs, const ck_tile::stream_config& s);
|
||||
@@ -0,0 +1,271 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm_multi_d(const void* a_m_k_dev_buf,
|
||||
const void* b_k_n_dev_buf,
|
||||
const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
|
||||
void* e_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t StrideA,
|
||||
ck_tile::index_t StrideB,
|
||||
const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
|
||||
ck_tile::index_t StrideE,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
int k_batch)
|
||||
{
|
||||
gemm_multi_d_kargs gemm_descs({a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
ds_m_n_dev_buf,
|
||||
e_m_n_dev_buf,
|
||||
k_batch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE});
|
||||
|
||||
float ave_time = gemm_multi_d<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise>(
|
||||
gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
typename D1Layout,
|
||||
typename ELayout>
|
||||
int run_multiple_d_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const D0Layout d0_layout = D0Layout{},
|
||||
const D1Layout d1_layout = D1Layout{},
|
||||
const ELayout e_layout = ELayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
using CDElementWiseFn = MultiplyMultiply;
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t StrideA = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t StrideB = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t StrideD = arg_parser.get_int("stride_ds");
|
||||
ck_tile::index_t StrideE = arg_parser.get_int("stride_e");
|
||||
|
||||
ck_tile::index_t StrideD0 = StrideD;
|
||||
ck_tile::index_t StrideD1 = StrideD;
|
||||
|
||||
const int n_warmup = arg_parser.get_int("warmup");
|
||||
const int n_repeat = arg_parser.get_int("repeat");
|
||||
const int k_batch = arg_parser.get_int("kbatch");
|
||||
|
||||
StrideA = get_default_stride(M, K, StrideA, is_row_major(a_layout));
|
||||
StrideB = get_default_stride(K, N, StrideB, is_row_major(b_layout));
|
||||
StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout));
|
||||
StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout));
|
||||
StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k_tesnor(
|
||||
host_tensor_descriptor(M, K, StrideA, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n_tensors(
|
||||
host_tensor_descriptor(K, N, StrideB, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<D0DataType> d0_m_n_tensors(
|
||||
host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout)));
|
||||
ck_tile::HostTensor<D1DataType> d1_m_n_tensors(
|
||||
host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout)));
|
||||
ck_tile::HostTensor<EDataType> e_m_n_device_result(
|
||||
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tesnor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors);
|
||||
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n_tensors);
|
||||
ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n_tensors);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k_tesnor.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_tesnor.mData.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_tensors.mData.data());
|
||||
d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data());
|
||||
d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data());
|
||||
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_device_result.SetZero();
|
||||
|
||||
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
|
||||
d1_m_n_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {StrideD0, StrideD1};
|
||||
|
||||
float ave_time = invoke_gemm_multi_d<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDElementWiseFn>(a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
ds_ptr_buf,
|
||||
e_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
stridesDs,
|
||||
StrideE,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
k_batch);
|
||||
|
||||
std::string op_name{"Gemm Multiple-D"};
|
||||
static constexpr ck_tile::index_t NumDTensor = DsDataType::size();
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
flop += std::size_t(2) * M * N * K;
|
||||
|
||||
ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
num_btype += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) * M * N;
|
||||
flop += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) * M * N;
|
||||
});
|
||||
|
||||
num_btype += sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm Multiple-D kernel with:\n";
|
||||
std::cout << "M =" << M << " N =" << N << " K =" << K << "\n";
|
||||
std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideE = " << StrideE
|
||||
<< "\n";
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< "\n";
|
||||
|
||||
e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
|
||||
|
||||
ck_tile::HostTensor<EDataType> e_m_n_host_ref(
|
||||
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
|
||||
e_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm_multiple_d<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
CDElementWiseFn>(
|
||||
a_m_k_tesnor, b_k_n_tensors, {d0_m_n_tensors, d1_m_n_tensors}, e_m_n_host_ref);
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("v"))
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end());
|
||||
|
||||
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
|
||||
|
||||
pass &= ck_tile::check_err(e_m_n_device_result,
|
||||
e_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< std::endl;
|
||||
std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_gemm_multi_d_fp16_json_results(arg_parser.get_str("jsonfile"),
|
||||
op_name,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideD0,
|
||||
StrideD1,
|
||||
StrideE,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig>
|
||||
int run_multiple_d_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string ds_layout = arg_parser.get_str("ds_layout");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C" && ds_layout == "R")
|
||||
{
|
||||
return run_multiple_d_gemm_example_with_layouts<GemmConfig>(
|
||||
argc, argv, Row{}, Col{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for provided tensors!");
|
||||
}
|
||||
}
|
||||
50
example/ck_tile/19_gemm_multi_d/utils.hpp
Normal file
50
example/ck_tile/19_gemm_multi_d/utils.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
struct MultiplyMultiply
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
|
||||
{
|
||||
const float x0_f = ck_tile::type_convert<float>(c) * ck_tile::type_convert<float>(d0) *
|
||||
ck_tile::type_convert<float>(d1);
|
||||
|
||||
e = ck_tile::type_convert<E>(x0_f);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeTypeAB =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
|
||||
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
Reference in New Issue
Block a user