[CK_TILE] Multiple-ABD GEMM example (#2788)

* Multi ABD - initial commit

* Clang-format fix

* block gemm, unify the name of CDataType

* Apply changes to mem-pipeline

* Rollback prefix for DType and Layout

* Gemm Kernel Basic, rename

* WMMA config

* Grouped GEMM

* Clang-format

* Dropout, name

* Review v2

* Move element_wise fns to unary, remove old fns

* clang-format

* Fix issue review

* WP operator adjust to universal gemm

* v2 prepare

* Remove unused comment

* Remove vectorsize

* Rollback

* Adjust pipeline for abd

* Shuffle argument

* CI-fail fix quant

* Fix ag_br pipeline

* Failing tests

* Typo

* Single argument support
This commit is contained in:
Mateusz Ozga
2025-09-19 01:14:11 +02:00
committed by GitHub
parent 14bbc545ea
commit 30ab1d6a71
41 changed files with 3603 additions and 552 deletions

View File

@@ -15,6 +15,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
* Added support for Stream-K version of mixed fp8/bf16 GEMM
* Added support for Multiple D GEMM
* Added support for Multiple ABD GEMM
* Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types
* Added support for FP16 2:4 structured sparsity to universal GEMM.
* Added support for Split K for grouped convolution backward data.

View File

@@ -0,0 +1 @@
# Multi-ABD GEMM fp16 example; EXCLUDE_FROM_ALL means it is built only on demand.
add_executable(tile_example_gemm_multi_abd_fp16 EXCLUDE_FROM_ALL gemm_multi_abd_fp16.cpp)

View File

@@ -0,0 +1,35 @@
#Multiple ABD GEMM
This folder contains an example of Multiple ABD GEMM using the ck_tile tile-programming implementation.
## build
```
#in the root of ck_tile
mkdir build && cd build
#you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or \
leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
#The basic pipeline method on the gemm calculation
make tile_example_gemm_multi_abd_fp16 -j
```
This will result in an executable `build/bin/tile_example_gemm_multi_abd_fp16`
## example
```
args:
-m M dimensions - (Default: 3840)
-n N dimensions - (Default: 4096)
-k K dimensions - (Default: 4096)
-as_layout Tensor A layout (default:R)
-bs_layout Tensor B layout (default:C)
-ds_layout Tensor D layout (default:R)
-e_layout Tensor E layout (default:R)
-stride_as Tensor A strides - (Default: 0)
-stride_bs Tensor B strides - (Default: 0)
-stride_e     Tensor E strides - (Default: 0)
-stride_ds Tensor D strides - (Default: 0)
-v            0. No validation, 1. Validation on GPU. (Default: 1)
-warmup       Number of iterations before benchmarking the kernel. (Default: 50)
-repeat Number of iterations to benchmark the kernel. (Default: 100)
-kbatch kbatch for SplitK. (Default: 1)
```

View File

@@ -0,0 +1,184 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_multi_abd_fp16.hpp"
#include "utils.hpp"
/// @brief Compose and launch the multi-ABD GEMM kernel; returns average time in ms.
///
/// Builds a tile partitioner, a pipeline (selected via GemmConfig::Pipeline)
/// and a CShuffle epilogue into a GemmKernelMultiABD, dispatches over the
/// runtime hot-loop / tail-number / memory-operation variants, and launches.
///
/// @param args host-side problem description (pointers, M/N/K, strides, k_batch)
/// @param s    stream config controlling warmup/repeat counts and logging
/// @return average kernel time in milliseconds from launch_kernel
/// @throws std::runtime_error if Kernel::IsSupportedArgument rejects the args
template <typename GemmConfig,
          typename AsDataType,
          typename BsDataType,
          typename DsDataType,
          typename AccDataType,
          typename EDataType,
          typename AsLayout,
          typename BsLayout,
          typename DsLayout,
          typename ELayout,
          typename AElementWise   = ck_tile::element_wise::PassThrough,
          typename BElementWise   = ck_tile::element_wise::PassThrough,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_config& s) -> float
{
    // Block tile, warp layout and per-warp tile sizes come from GemmConfig.
    constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
    constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
    constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;

    constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
    constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
    constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;

    constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
    constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
    constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;

    constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;

    // No M/N/K padding in this example — presumably the problem sizes must be
    // tile-divisible; IsSupportedArgument below rejects them otherwise (TODO confirm).
    constexpr bool kPadM = false;
    constexpr bool kPadN = false;
    constexpr bool kPadK = false;

    constexpr bool TransposeC = false;

    constexpr int kBlockPerCu                         = 1;
    constexpr ck_tile::index_t TileParitionerGroupNum = 8;
    constexpr ck_tile::index_t TileParitionerM01      = 4;

    using GemmShape =
        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
    using TilePartitioner = ck_tile::
        GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;

    using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, AsLayout, BsLayout, ELayout>;
    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
                                                                 kPadN,
                                                                 kPadK,
                                                                 DoubleSmemBuffer,
                                                                 AsLayout,
                                                                 BsLayout,
                                                                 ELayout,
                                                                 TransposeC>;
    using GemmPipelineProblem =
        ck_tile::GemmPipelineProblem<AsDataType, BsDataType, AccDataType, GemmShape, Traits>;

    using BaseGemmPipeline = typename PipelineTypeTraits<
        GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;

    // Per-split K extent rounded up to whole K tiles; drives the compile-time
    // hot-loop / tail-number dispatch below.
    const ck_tile::index_t k_grain     = args.k_batch * K_Tile;
    const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * K_Tile;
    const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

    float ave_time{0};

    // Instantiate and launch the kernel for one compile-time variant
    // (hot-loop flag, tail number, epilogue memory operation).
    const auto Run =
        [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
            constexpr bool has_hot_loop_v   = has_hot_loop_.value;
            constexpr auto tail_number_v    = tail_number_.value;
            constexpr auto scheduler        = GemmConfig::Scheduler;
            constexpr auto memory_operation = memory_operation_.value;

            // The A/B element-wise ops are fused into the pipeline problem so
            // they are applied while loading the operands.
            using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<AsDataType,
                                                                               BsDataType,
                                                                               AccDataType,
                                                                               GemmShape,
                                                                               GemmUniversalTraits,
                                                                               scheduler,
                                                                               has_hot_loop_v,
                                                                               tail_number_v,
                                                                               AElementWise,
                                                                               BElementWise>;

            using GemmPipeline = typename PipelineTypeTraits<
                GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
            // The CDE element-wise op (combining acc with the D tensors) runs
            // in the CShuffle epilogue.
            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<AsDataType,
                                                 BsDataType,
                                                 DsDataType,
                                                 AccDataType,
                                                 EDataType,
                                                 DsLayout,
                                                 ELayout,
                                                 CDEElementWise,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 M_Warp,
                                                 N_Warp,
                                                 M_Warp_Tile,
                                                 N_Warp_Tile,
                                                 K_Warp_Tile,
                                                 UniversalGemmProblem::TransposeC,
                                                 memory_operation>>;
            using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
            auto kargs   = Kernel::MakeKernelArgs(args);

            const dim3 grids  = Kernel::GridSize(args.M, args.N, args.k_batch);
            const dim3 blocks = Kernel::BlockSize();

            if(!Kernel::IsSupportedArgument(kargs))
            {
                throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
            }

            if(s.log_level_ > 0)
            {
                std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
                          << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
                          << blocks.y << ", " << blocks.z << "}" << std::endl;
            }

            ave_time = ck_tile::launch_kernel(
                s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
            return ave_time;
        };

    // SplitK: k_batch == 1 writes results directly (set); otherwise partial
    // results are combined with atomic adds — NOTE(review): this presumably
    // requires the E buffer to be zero-initialized by the caller; confirm.
    const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
        if(args.k_batch == 1)
        {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::set>{});
        }
        else
        {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::atomic_add>{});
        }
    };

    BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    return ave_time;
}
#include "run_gemm_multi_abd_fp16_example.inc"
/// Entry point: run the example with the WMMA config when compiled for WMMA
/// hardware, otherwise with the default compute-v3 (MFMA) config.
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
    const int example_result = run_multiple_abd_gemm_example<GemmConfigV3_Wmma>(argc, argv);
#else
    const int example_result = run_multiple_abd_gemm_example<GemmConfigV3>(argc, argv);
#endif
    // Map "truthy == pass" onto the conventional process exit code (0 == success).
    return example_result != 0 ? 0 : 1;
}

View File

@@ -0,0 +1,186 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
// Pipeline variants selectable at compile time via CK_TILE_PIPELINE_DEFAULT
// (see PipelineTypeTraits below for the mapping to concrete pipeline types).
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3

#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif

// Element types for this example: two A inputs (M x K), two B inputs (K x N),
// two D inputs (M x N) and one E output (M x N), all fp16; accumulate in fp32.
using A0DataType = ck_tile::half_t;
using A1DataType = ck_tile::half_t;
using B0DataType = ck_tile::half_t;
using B1DataType = ck_tile::half_t;
using D0DataType = ck_tile::half_t;
using D1DataType = ck_tile::half_t;
using EDataType  = ck_tile::half_t;

// Tuple aggregates consumed by the multi-ABD kernel templates.
using AsDataType = ck_tile::tuple<A0DataType, A1DataType>;
using BsDataType = ck_tile::tuple<B0DataType, B1DataType>;
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;

using AccDataType = float;
// Bandwidth-oriented tile configuration paired with the memory pipeline.
struct GemmConfigMemory
{
    // Memory friendly for Interwave scheduler
    static constexpr ck_tile::index_t M_Tile = 128; // block tile M
    static constexpr ck_tile::index_t N_Tile = 32;  // block tile N
    static constexpr ck_tile::index_t K_Tile = 64;  // block tile K

    static constexpr ck_tile::index_t M_Warp = 4; // warps along M
    static constexpr ck_tile::index_t N_Warp = 1; // warps along N
    static constexpr ck_tile::index_t K_Warp = 1; // warps along K

    static constexpr ck_tile::index_t M_Warp_Tile = 32; // per-warp tile M
    static constexpr ck_tile::index_t N_Warp_Tile = 32; // per-warp tile N
    static constexpr ck_tile::index_t K_Warp_Tile = 8;  // per-warp tile K

    static constexpr bool DoubleSmemBuffer    = false;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
    static constexpr auto Scheduler            = ck_tile::GemmPipelineScheduler::Interwave;
};
// Default compute-oriented configuration (compute pipeline v3).
struct GemmConfigV3
{
    // Compute friendly for Intrawave scheduler
    static constexpr ck_tile::index_t M_Tile = 256; // block tile M
    static constexpr ck_tile::index_t N_Tile = 256; // block tile N
    static constexpr ck_tile::index_t K_Tile = 64;  // block tile K

    static constexpr ck_tile::index_t M_Warp = 2; // warps along M
    static constexpr ck_tile::index_t N_Warp = 2; // warps along N
    static constexpr ck_tile::index_t K_Warp = 1; // warps along K

    static constexpr ck_tile::index_t M_Warp_Tile = 32; // per-warp tile M
    static constexpr ck_tile::index_t N_Warp_Tile = 32; // per-warp tile N
    static constexpr ck_tile::index_t K_Warp_Tile = 16; // per-warp tile K

    static constexpr bool DoubleSmemBuffer    = false;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
    static constexpr auto Scheduler            = ck_tile::GemmPipelineScheduler::Intrawave;
};
// Compute-oriented configuration for pipeline v4, which double-buffers LDS
// (hence DoubleSmemBuffer = true and a smaller K tile).
struct GemmConfigV4
{
    // Compute friendly for Intrawave scheduler
    // Using the ping pong reader in the lds level
    static constexpr ck_tile::index_t M_Tile = 256; // block tile M
    static constexpr ck_tile::index_t N_Tile = 256; // block tile N
    static constexpr ck_tile::index_t K_Tile = 32;  // block tile K

    static constexpr ck_tile::index_t M_Warp = 2; // warps along M
    static constexpr ck_tile::index_t N_Warp = 2; // warps along N
    static constexpr ck_tile::index_t K_Warp = 1; // warps along K

    static constexpr ck_tile::index_t M_Warp_Tile = 32; // per-warp tile M
    static constexpr ck_tile::index_t N_Warp_Tile = 32; // per-warp tile N
    static constexpr ck_tile::index_t K_Warp_Tile = 16; // per-warp tile K

    static constexpr bool DoubleSmemBuffer    = true;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
    static constexpr auto Scheduler            = ck_tile::GemmPipelineScheduler::Intrawave;
};
// Compute-oriented configuration for WMMA-capable hardware: 16x16x16 warp
// tiles and smaller block tiles than the MFMA GemmConfigV3.
struct GemmConfigV3_Wmma
{
    // Compute friendly for Intrawave scheduler
    static constexpr ck_tile::index_t M_Tile = 128; // block tile M
    static constexpr ck_tile::index_t N_Tile = 128; // block tile N
    static constexpr ck_tile::index_t K_Tile = 64;  // block tile K

    static constexpr ck_tile::index_t M_Warp = 2; // warps along M
    static constexpr ck_tile::index_t N_Warp = 2; // warps along N
    static constexpr ck_tile::index_t K_Warp = 1; // warps along K

    static constexpr ck_tile::index_t M_Warp_Tile = 16; // per-warp tile M
    static constexpr ck_tile::index_t N_Warp_Tile = 16; // per-warp tile N
    static constexpr ck_tile::index_t K_Warp_Tile = 16; // per-warp tile K

    static constexpr bool DoubleSmemBuffer    = false;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
    static constexpr auto Scheduler            = ck_tile::GemmPipelineScheduler::Intrawave;
};
/// Compile-time map from a pipeline id (CK_TILE_PIPELINE_*) to the concrete
/// ck_tile pipeline type and its "universal" base, which is queried for
/// hot-loop / tail-number properties before instantiating the full pipeline.
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;

// Memory pipeline (used by GemmConfigMemory).
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};

// Compute pipeline v3 (default; used by GemmConfigV3 and GemmConfigV3_Wmma).
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};

// Compute pipeline v4 (used by GemmConfigV4 with DoubleSmemBuffer = true).
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "4096", "k dimension")
.insert("as_layout", "R", "As tensor data layout - Row by default")
.insert("bs_layout", "C", "Bs tensor data layout - Col by default")
.insert("ds_layout", "R", "Ds tensor data layout - Row by default")
.insert("e_layout", "R", "E tensor data layout - Row by default")
.insert("stride_as", "0", "Tensor A stride")
.insert("stride_bs", "0", "Tensor B stride")
.insert("stride_ds", "0", "Tensor Ds stride")
.insert("stride_e", "0", "Tensor E stride")
.insert("v", "1", "0. No validation, 1. Validation on GPU")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("kbatch", "1", "kbatch for SplitK");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// Host-args alias specialized to this example's operand counts (2 As, 2 Bs, 2 Ds).
using gemm_multi_abd_kargs =
    ck_tile::GemmMultiABDHostArgs<AsDataType::size(), BsDataType::size(), DsDataType::size()>;

/// Forward declaration of the launcher defined in gemm_multi_abd_fp16.cpp.
/// Template parameter names follow the definition (EDataType/ELayout rather
/// than the former CDataType/CLayout) for consistency.
template <typename GemmConfig,
          typename AsDataType,
          typename BsDataType,
          typename DsDataType,
          typename AccDataType,
          typename EDataType,
          typename AsLayout,
          typename BsLayout,
          typename DsLayout,
          typename ELayout,
          typename AElementWise,
          typename BElementWise,
          typename CDEElementWise>
float gemm_multi_abd(const gemm_multi_abd_kargs& kargs, const ck_tile::stream_config& s);

View File

@@ -0,0 +1,311 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstddef>
template <typename GemmConfig,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename AccDataType,
typename EDataType,
typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AElementWise = ck_tile::element_wise::PassThrough,
typename BElementWise = ck_tile::element_wise::PassThrough,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm_multi_abd(const std::array<const void*, AsDataType::size()>& as_m_k_dev_buf,
const std::array<const void*, BsDataType::size()>& bs_k_n_dev_buf,
const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
void* e_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
const std::array<ck_tile::index_t, AsDataType::size()>& StrideAs,
const std::array<ck_tile::index_t, BsDataType::size()>& StrideBs,
const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
ck_tile::index_t StrideE,
int n_warmup,
int n_repeat,
int k_batch)
{
gemm_multi_abd_kargs gemm_descs({as_m_k_dev_buf,
bs_k_n_dev_buf,
ds_m_n_dev_buf,
e_m_n_dev_buf,
k_batch,
M,
N,
K,
StrideAs,
StrideBs,
StrideDs,
StrideE});
float ave_time = gemm_multi_abd<GemmConfig,
AsDataType,
BsDataType,
DsDataType,
AccDataType,
EDataType,
AsLayout,
BsLayout,
DsLayout,
ELayout,
AElementWise,
BElementWise,
CDEElementWise>(
gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Gemm Multiple-ABD"};
std::size_t flop = 0, num_btype = 0;
flop += std::size_t(2) * M * N * K;
num_btype +=
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Run Gemm Multiple-ABD kernel with:\n";
std::cout << "M =" << M << " N =" << N << " K =" << K << "\n";
std::cout << "StrideA = " << StrideAs[0] << " StrideB = " << StrideBs[0]
<< " StrideE = " << StrideE << "\n";
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< "\n";
return ave_time;
}
template <typename GemmConfig,
typename A0Layout,
typename A1Layout,
typename B0Layout,
typename B1Layout,
typename D0Layout,
typename D1Layout,
typename ELayout>
int run_gemm_multi_abd_example_with_layouts(int argc,
char* argv[],
const A0Layout a0_layout = A0Layout{},
const A1Layout a1_layout = A1Layout{},
const B0Layout b0_layout = B0Layout{},
const B1Layout b1_layout = B1Layout{},
const D0Layout d0_layout = D0Layout{},
const D1Layout d1_layout = D1Layout{},
const ELayout e_layout = ELayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
using AElementWiseFn = ck_tile::element_wise::AddScale;
using BElementWiseFn = ck_tile::element_wise::AddScale;
using CDEElementWiseFn = ck_tile::element_wise::MultiDMultiply;
using AsLayout = ck_tile::tuple<A0Layout, A1Layout>;
using BsLayout = ck_tile::tuple<B0Layout, B1Layout>;
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t StrideA = arg_parser.get_int("stride_as");
ck_tile::index_t StrideB = arg_parser.get_int("stride_bs");
ck_tile::index_t StrideD = arg_parser.get_int("stride_ds");
ck_tile::index_t StrideE = arg_parser.get_int("stride_e");
ck_tile::index_t StrideA0 = StrideA;
ck_tile::index_t StrideA1 = StrideA;
ck_tile::index_t StrideB0 = StrideB;
ck_tile::index_t StrideB1 = StrideB;
ck_tile::index_t StrideD0 = StrideD;
ck_tile::index_t StrideD1 = StrideD;
const int n_warmup = arg_parser.get_int("warmup");
const int n_repeat = arg_parser.get_int("repeat");
const int k_batch = arg_parser.get_int("kbatch");
StrideA0 = get_default_stride(M, N, StrideA0, is_row_major(a1_layout));
StrideA1 = get_default_stride(M, N, StrideA1, is_row_major(a1_layout));
StrideB0 = get_default_stride(K, N, StrideB0, is_row_major(b0_layout));
StrideB1 = get_default_stride(K, N, StrideB1, is_row_major(b1_layout));
StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout));
StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout));
StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout));
ck_tile::HostTensor<A0DataType> a0_m_k_tesnor(
host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout)));
ck_tile::HostTensor<A1DataType> a1_m_k_tesnor(
host_tensor_descriptor(M, K, StrideA1, is_row_major(a1_layout)));
ck_tile::HostTensor<B0DataType> b0_k_n_tensors(
host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout)));
ck_tile::HostTensor<B1DataType> b1_k_n_tensors(
host_tensor_descriptor(K, N, StrideB1, is_row_major(b1_layout)));
ck_tile::HostTensor<D0DataType> d0_m_n_tensors(
host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout)));
ck_tile::HostTensor<D1DataType> d1_m_n_tensors(
host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout)));
ck_tile::HostTensor<EDataType> e_m_n_device_result(
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
ck_tile::FillUniformDistribution<A0DataType>{-1.f, 1.f}(a0_m_k_tesnor);
ck_tile::FillUniformDistribution<A1DataType>{-1.f, 1.f}(a1_m_k_tesnor);
ck_tile::FillUniformDistribution<B0DataType>{-1.f, 1.f}(b0_k_n_tensors);
ck_tile::FillUniformDistribution<B1DataType>{-1.f, 1.f}(b1_k_n_tensors);
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n_tensors);
ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n_tensors);
ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tesnor.get_element_space_size_in_bytes());
ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tesnor.get_element_space_size_in_bytes());
ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data());
a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data());
b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data());
b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data());
d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data());
d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data());
e_m_n_dev_buf.SetZero();
e_m_n_device_result.SetZero();
std::array<const void*, DsDataType::size()> as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(),
a1_m_k_dev_buf.GetDeviceBuffer()};
std::array<const void*, DsDataType::size()> bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(),
b1_k_n_dev_buf.GetDeviceBuffer()};
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
d1_m_n_dev_buf.GetDeviceBuffer()};
std::array<ck_tile::index_t, AsDataType::size()> strideAs = {StrideA0, StrideA1};
std::array<ck_tile::index_t, BsDataType::size()> strideBs = {StrideB0, StrideB1};
std::array<ck_tile::index_t, DsDataType::size()> strideDs = {StrideD0, StrideD1};
invoke_gemm_multi_abd<GemmConfig,
AsDataType,
BsDataType,
DsDataType,
AccDataType,
EDataType,
AsLayout,
BsLayout,
DsLayout,
ELayout,
AElementWiseFn,
BElementWiseFn,
CDEElementWiseFn>(as_ptr_buf,
bs_ptr_buf,
ds_ptr_buf,
e_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
strideAs,
strideBs,
strideDs,
StrideE,
n_warmup,
n_repeat,
k_batch);
e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
ck_tile::HostTensor<A0DataType> a_m_k_host_ref_element_result(
host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout)));
ck_tile::HostTensor<B0DataType> b_k_n_host_ref_element_result(
host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout)));
ck_tile::HostTensor<EDataType> e_m_n_host_ref(
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
a_m_k_host_ref_element_result.SetZero();
b_k_n_host_ref_element_result.SetZero();
e_m_n_host_ref.SetZero();
ck_tile::reference_gemm_multiple_abd<AsDataType,
BsDataType,
DsDataType,
AccDataType,
EDataType,
AElementWiseFn,
BElementWiseFn,
CDEElementWiseFn>({a0_m_k_tesnor, a1_m_k_tesnor},
{b0_k_n_tensors, b1_k_n_tensors},
{d0_m_n_tensors, d1_m_n_tensors},
a_m_k_host_ref_element_result,
b_k_n_host_ref_element_result,
e_m_n_host_ref);
bool pass{true};
if(arg_parser.get_int("v"))
{
const float max_accumulated_value =
*std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
pass &= ck_tile::check_err(e_m_n_device_result,
e_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< std::endl;
std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}
/// @brief Dispatch the example based on the requested A/B layouts.
///
/// Only the row-major-A / column-major-B combination is wired up; anything
/// else throws. D and E are always row-major here.
template <typename GemmConfig>
int run_multiple_abd_gemm_example(int argc, char* argv[])
{
    auto [parsed_ok, parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        return -1;
    }

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    const std::string a_layout_str = parser.get_str("as_layout");
    const std::string b_layout_str = parser.get_str("bs_layout");

    // Guard clause: reject every combination except R/C up front.
    if(a_layout_str != "R" || b_layout_str != "C")
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    }
    return run_gemm_multi_abd_example_with_layouts<GemmConfig>(
        argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}, Row{}, Row{});
}

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
/// @brief Derive relative/absolute error tolerances for verifying the GEMM result.
///
/// @param K                     reduction length of the GEMM
/// @param kbatch                SplitK factor used for the run
/// @param max_accumulated_value largest value in the reference output
/// @return ck_tile tuple of (rtol, atol)
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // Use the narrowest participating input type (A0/B0/D0) to bound the error.
    using ComputeTypeAB =
        std::conditional_t<sizeof(A0DataType) < sizeof(B0DataType), A0DataType, B0DataType>;
    using ComputeType =
        std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
    // Calculate thresholds
    const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
        ck_tile::integer_divide_ceil(K, kbatch));
    const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
    // Calculate error due to split_k accumulation
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
    const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
        max_accumulated_value, kbatch);
    // Use higher threshold
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}

View File

@@ -21,6 +21,7 @@ add_subdirectory(18_flatmm)
add_subdirectory(19_gemm_multi_d)
add_subdirectory(20_grouped_convolution)
add_subdirectory(21_elementwise)
add_subdirectory(22_gemm_multi_abd)
add_subdirectory(35_batched_transpose)
add_subdirectory(38_block_scale_gemm)
add_subdirectory(39_copy)

View File

@@ -26,6 +26,29 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
/**
 * @brief Load a tile while applying an elementwise function across a tuple of windows.
 *
 * @note Extension of load_tile(): it takes a tuple of tile windows and an
 *       elementwise functor; for each A = A0, A1, ..., AN the functor is
 *       applied during a single read, so only the combined value lands in
 *       registers.
 */
template <typename TileWindow_,
          typename ElementWise_,
          index_t i_access               = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
                                               ElementWise_ elementwise,
                                               number<i_access> = {},
                                               bool_constant<oob_conditional_check> = {})
{
    // TODO: tile windows should work with an unknown number of params.
    // The elementwise load API only works when the input is a tuple of windows;
    // delegate to the first window's load(), passing the whole tuple through.
    return tile_window[number<0>{}].load(
        tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename TileWindow_,
index_t i_access = -1,

View File

@@ -120,6 +120,116 @@ struct tile_window_with_static_distribution
return dst_tensor;
}
/**
 * @brief Load a tile with an elementwise function, returning a new tensor.
 *
 * @note During value loading, the elementwise function is applied across the
 *       values A0, A1, ..., AN read from each window of the tuple. All values
 *       are read by the same thread in the same vectorized access, which
 *       reduces the amount of data kept in registers.
 *
 * Allocates the destination distributed tensor and forwards to the in-place
 * overload below.
 */
template <typename TileWindow_,
          typename ElementWise_,
          index_t i_access_unsupport_ = -1,
          bool oob_conditional_check  = true>
CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
                         ElementWise_ elementwise,
                         number<i_access_unsupport_> = {},
                         bool_constant<oob_conditional_check> = {}) const
{
    constexpr auto tile_dstr = typename Base::TileDstr{};
    auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
    load(dst_tensor,
         tile_window,
         elementwise,
         number<i_access_unsupport_>{},
         bool_constant<oob_conditional_check>{});
    return dst_tensor;
}
/**
 * @brief Load a tile with an elementwise function into an existing tensor.
 *
 * @param dst_tensor  destination distributed tensor written in place
 * @param tile_window tuple of tile windows (A0, A1, ..., AN); all windows are
 *                    walked with the coordinates of the first one, so they are
 *                    presumably expected to share a distribution — TODO confirm
 * @param elementwise functor combining one scalar per window into dst
 */
template <typename DistributedTensor,
          typename TileWindow_,
          typename ElementWise_,
          index_t i_access_unsupport_ = -1,
          bool oob_conditional_check  = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
                         const TileWindow_& tile_window,
                         ElementWise_ elementwise,
                         number<i_access_unsupport_> = {},
                         bool_constant<oob_conditional_check> = {}) const
{
    using Traits   = typename Base::Traits;
    using vector_t = typename Traits::vector_t;
    using SFC_Ys   = typename Traits::SFC_Ys;

    constexpr auto tile_dstr   = typename Base::TileDstr{};
    constexpr auto sizeOfTuple = TileWindow_::size();

    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        /// TODO: use structure binding (to be captured later) if compiled in C++20
        // Coordinates are taken from the first window and reused for every
        // window in the tuple.
        auto window_adaptor_thread_coord =
            tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord =
            tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

            // read from bottom tensor: one vector per window at the same coordinate
            const auto idx_vec_value = generate_tuple(
                [&](auto jj) {
                    return tile_window[number<jj>{}]
                        .get_bottom_tensor_view()
                        .template get_vectorized_elements<vector_t>(
                            bottom_tensor_thread_coord,
                            0,
                            bool_constant<oob_conditional_check>{});
                },
                number<sizeOfTuple>{});

            // write into distributed tensor: apply the elementwise op lane by lane
            static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                        : idx_ys_start[jj];
                    },
                    number<Base::NDimY>{});
                constexpr index_t d =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                    Traits::PackedSize;
                // Expand the tuple of loaded vectors into one scalar argument
                // per window; elementwise() writes the combined value into dst.
                ck_tile::apply(
                    [&](auto&&... t) {
                        elementwise(dst_tensor.get_thread_buffer().template at<d>(),
                                    t.template get_as<
                                        typename Base::DataType>()[j / Traits::PackedSize]...);
                    },
                    idx_vec_value);
            });

            // move thread coordinate (skipped after the final access)
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
                    idx_diff_ys);

                Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
            }
        });
    });
}
template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
@@ -857,6 +967,39 @@ CK_TILE_DEVICE void move_tile_window(
window.move(step);
}
/**
 * @brief Move a tile window wrapped in a single-element tuple by the given step.
 *
 * Overload for the tuple-of-one-window form used by the multi-ABD code paths;
 * every window held by the tuple is shifted by the same bottom-tensor step.
 */
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord>
CK_TILE_DEVICE void move_tile_window(
    tuple<tile_window_with_static_distribution<TensorView_,
                                               WindowLengths_,
                                               StaticTileDistribution_,
                                               NumCoord>>& window,
    const typename tile_window_with_static_distribution<TensorView_,
                                                        WindowLengths_,
                                                        StaticTileDistribution_,
                                                        NumCoord>::BottomTensorIndex& step)
{
    using WindowTuple = tuple<tile_window_with_static_distribution<TensorView_,
                                                                   WindowLengths_,
                                                                   StaticTileDistribution_,
                                                                   NumCoord>>;
    static constexpr auto kNumWindows = WindowTuple::size();
    static_for<0, kNumWindows, 1>{}([&](auto idx) { window[number<idx>{}].move(step); });
}
/**
 * @brief Move every tile window held in a tuple-like collection by the same step.
 *
 * Enabled only for tuple-like window collections (detected via is_tuple).
 *
 * @note The step is taken by const reference (the original used a non-const
 *       lvalue reference, which rejected temporaries and const steps); this is
 *       backward compatible and matches the other move_tile_window overloads.
 */
template <typename TileWindowWithStaticDistributionType,
          typename StepType,
          typename std::enable_if_t<
              is_detected<is_tuple, TileWindowWithStaticDistributionType>::value>* = nullptr>
CK_TILE_DEVICE void move_tile_window(TileWindowWithStaticDistributionType& window,
                                     const StepType& step)
{
    static constexpr auto N = TileWindowWithStaticDistributionType::size();
    static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
}
/**
* @brief This class provides description of tile windowed view on the device memory.
*

View File

@@ -261,6 +261,81 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
/// @brief Host reference implementation of a GEMM with multiple A, B and D operands.
///
/// First fuses all A operands into @p a_m_k via @p a_element_op and all B
/// operands into @p b_k_n via @p b_element_op, then computes
/// C(m, n) = acc_element_op(sum_k A(m, k) * B(k, n), Ds(m, n)...).
///
/// @param as_m_k         Input A operand tensors, each of shape M x K.
/// @param bs_k_n         Input B operand tensors, each of shape K x N.
/// @param ds_m_n         Input D operand tensors, each of shape M x N.
/// @param a_m_k          Output of the A elementwise fusion (M x K).
/// @param b_k_n          Output of the B elementwise fusion (K x N).
/// @param c_m_n          Output tensor (M x N).
/// @param a_element_op   Elementwise functor combining the A operands.
/// @param b_element_op   Elementwise functor combining the B operands.
/// @param acc_element_op Elementwise functor combining accumulator and Ds.
template <typename AsDataType,
          typename BsDataType,
          typename DsDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp,
          typename BElementOp,
          typename CDElementOp,
          typename ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>,
          typename BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k,
                            const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n,
                            const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                            HostTensor<ADataType>& a_m_k,
                            HostTensor<BDataType>& b_k_n,
                            HostTensor<CDataType>& c_m_n,
                            const AElementOp& a_element_op = {},
                            const BElementOp& b_element_op = {},
                            const CDElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t K = a_m_k.get_length(1);
    const std::size_t N = b_k_n.get_length(1);

    // Reference tuples over the operand arrays so they can be expanded as packs.
    auto a_operand_refs =
        generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number<AsDataType::size()>{});
    auto b_operand_refs =
        generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number<BsDataType::size()>{});
    auto d_operand_refs =
        generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number<DsDataType::size()>{});

    // Fuse the A operands elementwise into a_m_k.
    make_ParallelTensorFunctor(
        [&](auto row, auto col) {
            ck_tile::apply([&](auto&&... src) { a_element_op(a_m_k(row, col), src(row, col)...); },
                           a_operand_refs);
        },
        M,
        K)(std::thread::hardware_concurrency());

    // Fuse the B operands elementwise into b_k_n.
    make_ParallelTensorFunctor(
        [&](auto row, auto col) {
            ck_tile::apply([&](auto&&... src) { b_element_op(b_k_n(row, col), src(row, col)...); },
                           b_operand_refs);
        },
        K,
        N)(std::thread::hardware_concurrency());

    // Plain GEMM over the fused operands, followed by the C/D elementwise op.
    auto gemm_at = [&](auto row, auto col) {
        AccDataType acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            acc += ck_tile::type_convert<AccDataType>(a_m_k(row, k)) *
                   ck_tile::type_convert<AccDataType>(b_k_n(k, col));
        }

        CDataType out = 0;
        ck_tile::apply(
            [&](auto&&... d) {
                acc_element_op(out,
                               ck_tile::type_convert<float>(acc),
                               ck_tile::type_convert<float>(d(row, col))...);
            },
            d_operand_refs);
        c_m_n(row, col) = ck_tile::type_convert<CDataType>(out);
    };
    make_ParallelTensorFunctor(gemm_at, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
typename DsDataType,

View File

@@ -392,6 +392,23 @@ struct PassThrough
}
};
/// @brief Elementwise functor computing a = scale * (as[0] + as[1] + ...).
///
/// Accumulates every input operand in float precision, multiplies by
/// @ref scale and converts back to the output element type.
struct AddScale
{
    template <typename E, typename... As>
    CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const As&... as) const
    {
        // Sum all input operands in float precision via a fold expression.
        float result = 0.0f;
        ((result += ck_tile::type_convert<float>(as)), ...);
        a = ck_tile::type_convert<E>(scale * result);
    }

    // Uniform scaling factor applied to the accumulated sum.
    float scale = 1.0f;
};
struct MultiDMultiply
{
template <typename E, typename C, typename... Ds>

View File

@@ -28,8 +28,8 @@ struct GetDataType<T>
using type = typename T::DataType; // Use T::ScaleN::DataType
};
template <typename ADataType_,
typename BDataType_,
template <typename AsDataType_,
typename BsDataType_,
typename DsDataType_,
typename AccDataType_,
typename ODataType_,
@@ -51,8 +51,8 @@ template <typename ADataType_,
bool TiledMMAPermuteN_ = false>
struct CShuffleEpilogueProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
@@ -83,12 +83,27 @@ template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<AsDataType>,
remove_cvref_t<tuple<AsDataType>>>;
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A

View File

@@ -28,8 +28,8 @@ struct Default2DEpilogueProblem
static constexpr index_t NumDTensor = 0;
};
template <typename ADataType_,
typename BDataType_,
template <typename AsDataType_,
typename BsDataType_,
typename DsDataType_,
typename AccDataType_,
typename ODataType_,
@@ -53,8 +53,8 @@ struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataTyp
UseRawStore_,
MemoryOperation_>
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using CLayout = remove_cvref_t<CLayout_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
@@ -157,14 +157,28 @@ struct Default2DEpilogue
template <typename Problem_, typename Policy_ = void>
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
{
using Problem = remove_cvref_t<Problem_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using Problem = remove_cvref_t<Problem_>;
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<AsDataType>,
remove_cvref_t<tuple<AsDataType>>>;
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;

View File

@@ -31,6 +31,7 @@
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"

View File

@@ -90,10 +90,10 @@ struct BatchedGemmKernel
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, CDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
"C/CLayout and C/EDataType must be scalars.");
struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<>
{

View File

@@ -89,7 +89,7 @@ struct GemmKernel
/// @brief Specify the layout configurations for A, B, E and D
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
/// @brief Specify the data type configurations for A, B, E and D
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
@@ -106,10 +106,10 @@ struct GemmKernel
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, ELayout>::value &&
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, EDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
"C/CLayout and C/EDataType must be scalars.");
static constexpr index_t NumATensor = 1;
static constexpr index_t NumBTensor = 1;

View File

@@ -0,0 +1,193 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
/// @brief The MultiABD GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref GemmKernelMultiABD "GemmKernelMultiABD" when creating
/// the kernel arguments object. It contains all information required to build proper
/// kernel arguments and launch the kernel on the GPU. This structure defines the GEMM
/// problem configuration by stating all required information like M, N, K sizes and the
/// respective strides.
/// NumATensor describes the number of A tensors; at least one is required.
/// NumBTensor describes the number of B tensors; at least one is required.
/// NumDTensor describes the number of D tensors; zero is allowed (D tensors are optional).
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
struct GemmMultiABDHostArgs
{
    /// @brief Bundles operand pointers, problem sizes and strides for a
    ///        multiple-ABD GEMM launch.
    CK_TILE_HOST GemmMultiABDHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
                                      const std::array<const void*, NumBTensor>& bs_ptr_,
                                      const std::array<const void*, NumDTensor>& ds_ptr_,
                                      void* e_ptr_,
                                      index_t k_batch_,
                                      index_t M_,
                                      index_t N_,
                                      index_t K_,
                                      const std::array<index_t, NumATensor>& stride_As_,
                                      const std::array<index_t, NumBTensor>& stride_Bs_,
                                      const std::array<index_t, NumDTensor>& stride_Ds_,
                                      index_t stride_E_)
        : as_ptr(as_ptr_),
          bs_ptr(bs_ptr_),
          ds_ptr(ds_ptr_),
          e_ptr(e_ptr_),
          M(M_),
          N(N_),
          K(K_),
          stride_As(stride_As_),
          stride_Bs(stride_Bs_),
          stride_Ds(stride_Ds_),
          stride_E(stride_E_),
          k_batch(k_batch_)
    {
    }

    // Device pointers to the A, B and D operand tensors.
    const std::array<const void*, NumATensor> as_ptr;
    const std::array<const void*, NumBTensor> bs_ptr;
    const std::array<const void*, NumDTensor> ds_ptr;
    // e_ptr and c_ptr alias the same output buffer (E vs C naming convention).
    union
    {
        void* e_ptr;
        void* c_ptr;
    };
    // GEMM problem sizes.
    index_t M;
    index_t N;
    index_t K;
    // Leading-dimension strides, one per operand tensor.
    const std::array<index_t, NumATensor> stride_As;
    const std::array<index_t, NumBTensor> stride_Bs;
    const std::array<index_t, NumDTensor> stride_Ds;
    // stride_E and stride_C alias the output stride (E vs C naming convention).
    union
    {
        index_t stride_E;
        index_t stride_C;
    };
    // Number of K-dimension splits (split-K factor).
    index_t k_batch;
};
/// @brief GEMM kernel with multiple A, B and D operand tensors.
///
/// Thin wrapper around UniversalGemmKernel that validates the tuple-based
/// A/B/D type configuration and translates GemmMultiABDHostArgs into
/// universal GEMM kernel arguments.
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernelMultiABD
{
    /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
    /// functions.
    using UniversalGemmKernel =
        UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;

    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    /// @brief Specify the layout configurations for A, B, E and D
    using AsLayout = remove_cvref_t<typename GemmPipeline::AsLayout>;
    using BsLayout = remove_cvref_t<typename GemmPipeline::BsLayout>;
    using CLayout  = remove_cvref_t<typename GemmPipeline::CLayout>;
    using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;

    /// @brief Specify the data type configurations for A, B, E and D
    using AsDataType = remove_cvref_t<typename GemmPipeline::AsDataType>;
    using BsDataType = remove_cvref_t<typename GemmPipeline::BsDataType>;
    using EDataType  = remove_cvref_t<typename EpiloguePipeline::ODataType>;
    using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;

    /// @brief ALayout and ADataType are expected to be a tuple, not a scalar.
    static_assert(is_detected<is_tuple, AsLayout>::value &&
                      is_detected<is_tuple, AsDataType>::value,
                  "ALayout and ADataType must be a tuple.");

    /// @brief BLayout and BDataType are expected to be a tuple, not a scalar.
    static_assert(is_detected<is_tuple, BsLayout>::value &&
                      is_detected<is_tuple, BsDataType>::value,
                  "BLayout and BDataType must be a tuple.");

    /// @brief CLayout and EDataType are expected to be scalars, not a tuple.
    static_assert(!is_detected<is_tuple, CLayout>::value &&
                      !is_detected<is_tuple, EDataType>::value,
                  "CLayout and EDataType must be a scalar.");

    /// @brief DsLayout and DsDataType are expected to be tuple, not a scalar.
    static_assert(is_detected<is_tuple, DsLayout>::value &&
                      is_detected<is_tuple, DsDataType>::value &&
                      DsLayout::size() == DsDataType::size() && DsLayout::size() > 0,
                  "DsLayout and DsDataType must be tuples and must have the same size.");

    /// @brief NumATensor, NumBTensor and NumDTensor are deduced from the sizes of the
    /// user-provided data-type tuples.
    static constexpr index_t NumATensor = AsDataType::size();
    static constexpr index_t NumBTensor = BsDataType::size();
    static constexpr index_t NumDTensor = DsDataType::size();

    /// @brief Returns the kernel name of the underlying universal GEMM kernel.
    CK_TILE_HOST static auto GetName() -> const std::string
    {
        return UniversalGemmKernel::GetName();
    }

    /// @brief Computes the launch grid size for the given problem dimensions.
    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
    {
        return UniversalGemmKernel::GridSize(M, N, KBatch);
    }

    CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
    {
        return UniversalGemmKernel::MaxOccupancyGridSize(s);
    }

    CK_TILE_HOST static constexpr auto BlockSize() -> dim3
    {
        return UniversalGemmKernel::BlockSize();
    }

    /// @brief Translates the multi-ABD host arguments into universal GEMM kernel arguments.
    CK_TILE_HOST static constexpr auto
    MakeKernelArgs(const GemmMultiABDHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs) ->
        typename UniversalGemmKernel::KernelArgs
    {
        /// @brief Universal GEMM requires array objects and corresponding stride information for
        /// matrices A, B, and D.
        return UniversalGemmKernel::MakeKernelArgs(
            UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>(hostArgs.as_ptr,
                                                                      hostArgs.bs_ptr,
                                                                      hostArgs.ds_ptr,
                                                                      hostArgs.e_ptr,
                                                                      hostArgs.k_batch,
                                                                      hostArgs.M,
                                                                      hostArgs.N,
                                                                      hostArgs.K,
                                                                      hostArgs.stride_As,
                                                                      hostArgs.stride_Bs,
                                                                      hostArgs.stride_Ds,
                                                                      hostArgs.stride_E));
    }

    /// @brief Rejects unsupported configurations before delegating validation to the
    /// universal GEMM kernel.
    CK_TILE_HOST static auto
    IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
    {
        // Currently MultiABD kernel doesn't support k_batch > 1
        if(kargs.k_batch > 1)
        {
            return false;
        }
        return UniversalGemmKernel::IsSupportedArgument(kargs);
    }

    /// @brief Device entry point; delegates execution to the universal GEMM kernel.
    CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
    {
        UniversalGemmKernel{}.template operator()(kargs);
    }
};
} // namespace ck_tile

View File

@@ -95,7 +95,7 @@ struct GemmKernelMultiD
/// @brief Specify the layout configurations for A, B, E and D
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
/// @brief Specify the data type configurations for A, B, E and D
@@ -114,10 +114,10 @@ struct GemmKernelMultiD
!is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars.");
/// @brief ELayout and EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, ELayout>::value &&
/// @brief CLayout and EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, EDataType>::value,
"ELayout and EDataType must be scalars.");
"CLayout and EDataType must be scalars.");
/// @brief DsLayout and DsDataType are expected to be tuple, not a scalar.
static_assert(is_detected<is_tuple, DsLayout>::value &&

View File

@@ -120,10 +120,10 @@ struct GroupedGemmKernel
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, CDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
"C/CLayout and C/EDataType must be scalars.");
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
@@ -364,12 +364,8 @@ struct GroupedGemmKernel
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template

View File

@@ -157,23 +157,23 @@ struct UniversalGemmKernel
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
static constexpr bool ADataTypeIsTuple =
is_detected<is_tuple, typename GemmPipeline::ADataType>::value;
is_detected<is_tuple, typename GemmPipeline::AsDataType>::value;
static constexpr bool BDataTypeIsTuple =
is_detected<is_tuple, typename GemmPipeline::BDataType>::value;
is_detected<is_tuple, typename GemmPipeline::BsDataType>::value;
static constexpr bool DDataTypeIsTuple =
is_detected<is_tuple, typename EpiloguePipeline::DsDataType>::value;
static constexpr bool ALayoutIsTuple =
is_detected<is_tuple, typename GemmPipeline::ALayout>::value;
is_detected<is_tuple, typename GemmPipeline::AsLayout>::value;
static constexpr bool BLayoutIsTuple =
is_detected<is_tuple, typename GemmPipeline::BLayout>::value;
is_detected<is_tuple, typename GemmPipeline::BsLayout>::value;
static constexpr bool DLayoutIsTuple =
is_detected<is_tuple, typename EpiloguePipeline::DsLayout>::value;
using AsLayout = std::conditional_t<ALayoutIsTuple,
remove_cvref_t<typename GemmPipeline::ALayout>,
remove_cvref_t<typename GemmPipeline::AsLayout>,
remove_cvref_t<tuple<typename GemmPipeline::ALayout>>>;
using BsLayout = std::conditional_t<BLayoutIsTuple,
remove_cvref_t<typename GemmPipeline::BLayout>,
remove_cvref_t<typename GemmPipeline::BsLayout>,
remove_cvref_t<tuple<typename GemmPipeline::BLayout>>>;
using DsLayout = std::conditional_t<DLayoutIsTuple,
@@ -181,11 +181,11 @@ struct UniversalGemmKernel
remove_cvref_t<tuple<typename EpiloguePipeline::DsLayout>>>;
using AsDataType = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<typename GemmPipeline::ADataType>,
remove_cvref_t<typename GemmPipeline::AsDataType>,
remove_cvref_t<tuple<typename GemmPipeline::ADataType>>>;
using BsDataType = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<typename GemmPipeline::BDataType>,
remove_cvref_t<typename GemmPipeline::BsDataType>,
remove_cvref_t<tuple<typename GemmPipeline::BDataType>>>;
using DsDataType =
@@ -193,9 +193,12 @@ struct UniversalGemmKernel
remove_cvref_t<typename EpiloguePipeline::DsDataType>,
remove_cvref_t<tuple<typename EpiloguePipeline::DsDataType>>>;
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
// Get the persistent kernel if the pipeline has it available
@@ -483,7 +486,7 @@ struct UniversalGemmKernel
bool DTesnorIsValid = {true};
static_for<0, NumDTensor, 1>{}([&](auto index) {
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
if(std::is_same_v<DiLayout, ELayout> == false)
if(std::is_same_v<DiLayout, CLayout> == false)
{
DTesnorIsValid = false;
}
@@ -529,7 +532,7 @@ struct UniversalGemmKernel
}
});
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
@@ -724,7 +727,7 @@ struct UniversalGemmKernel
// TODO: enable vector write for C in ColMajor
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
@@ -818,7 +821,7 @@ struct UniversalGemmKernel
// TODO vector write in for C in ColMajor
const auto& e_pad_view = [&]() {
const auto& e_tensor_view = views.at(I3);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
@@ -975,8 +978,8 @@ struct UniversalGemmKernel
const auto& bs_block_window = gemm_tile_windows.at(I1);
const auto& ds_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile =
GemmPipeline{}(as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
const auto& c_block_tile = GemmPipeline{}.template operator()(
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
if(UseDefaultScheduler || (get_warp_id() == 0))
{
@@ -1031,8 +1034,13 @@ struct UniversalGemmKernel
const auto& bs_block_window = gemm_tile_windows.at(I1);
const auto& ds_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}(
as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1);
const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
AElementWise{},
bs_block_window,
BElementWise{},
num_loop,
smem_ptr_0,
smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);

View File

@@ -11,12 +11,17 @@ namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmPipelineAgBgCrImplBase
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
@@ -57,6 +62,13 @@ struct GemmPipelineAgBgCrImplBase
store_tile(lds_tile_window, block_tile_tmp);
}
template <typename DstTileWindow, typename SrcBlockTile>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile) const
{
store_tile(lds_tile_window, src_block_tile);
}
template <typename DstBlockTile, typename SrcTileWindow, bool LoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
const SrcTileWindow& lds_tile_window,
@@ -88,23 +100,100 @@ struct GemmPipelineAgBgCrImplBase
return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
}
/// @brief Creates per-operand DRAM tile windows for loading the A tensors.
///
/// Tuple overload: builds one load window per A operand contained in
/// @p dram_block_window_tmp, each shifted by @p offset and using the
/// policy-defined A DRAM tile distribution.
template <typename DramBlockWindowTmp,
          typename std::enable_if_t<is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
              nullptr>
CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
                                              const array<index_t, 2>& offset = {0, 0}) const
{
    constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
    using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
    using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;

    // A DRAM tile windows for load, one per A operand.
    auto a_copy_dram_window = generate_tuple(
        [&](auto idx) {
            return make_tile_window(
                dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
                make_tuple(YPerTile{}, XPerTile{}),
                dram_block_window_tmp[number<idx>{}].get_window_origin() + offset,
                Policy::template MakeADramTileDistribution<Problem>());
        },
        number<DramBlockWindowTmp::size()>{});

    // Plain return allows NRVO; `return std::move(local)` would pessimize it.
    return a_copy_dram_window;
}
/// @brief Creates the DRAM tile window for loading a single A tensor.
///
/// Non-tuple overload: builds one load window for @p dram_block_window_tmp,
/// shifted by @p offset and using the policy-defined A DRAM tile distribution.
template <typename DramBlockWindowTmp,
          typename std::enable_if_t<!is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
              nullptr>
CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
                                              const array<index_t, 2>& offset = {0, 0}) const
{
    constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
    using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
    using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;

    // A DRAM tile window for load
    auto a_copy_dram_window =
        make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(),
                         make_tuple(YPerTile{}, XPerTile{}),
                         dram_block_window_tmp.get_window_origin() + offset,
                         Policy::template MakeADramTileDistribution<Problem>());

    // Plain return allows NRVO; `return std::move(local)` would pessimize it.
    return a_copy_dram_window;
}
/// @brief Creates per-operand DRAM tile windows for loading the B tensors.
///
/// Tuple overload: builds one load window per B operand contained in
/// @p dram_block_window_tmp, each shifted by @p offset and using the
/// policy-defined B DRAM tile distribution.
template <typename DramBlockWindowTmp,
          typename std::enable_if_t<is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
              nullptr>
CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
                                              const array<index_t, 2>& offset = {0, 0}) const
{
    constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
    using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
    using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;

    // B DRAM tile windows for load, one per B operand.
    auto b_copy_dram_window = generate_tuple(
        [&](auto idx) {
            return make_tile_window(
                dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
                make_tuple(YPerTile{}, XPerTile{}),
                dram_block_window_tmp[number<idx>{}].get_window_origin() + offset,
                Policy::template MakeBDramTileDistribution<Problem>());
        },
        number<DramBlockWindowTmp::size()>{});

    // Plain return allows NRVO; `return std::move(local)` would pessimize it.
    return b_copy_dram_window;
}
/// @brief Creates the DRAM tile window for loading a single B tensor.
///
/// Non-tuple overload: builds one load window for @p dram_block_window_tmp,
/// shifted by @p offset and using the policy-defined B DRAM tile distribution.
template <typename DramBlockWindowTmp,
          typename std::enable_if_t<!is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
              nullptr>
CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
                                              const array<index_t, 2>& offset = {0, 0}) const
{
    constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
    using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
    using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;

    // B DRAM tile window for load
    auto b_copy_dram_window =
        make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(),
                         make_tuple(YPerTile{}, XPerTile{}),
                         dram_block_window_tmp.get_window_origin() + offset,
                         Policy::template MakeBDramTileDistribution<Problem>());

    // Plain return allows NRVO; `return std::move(local)` would pessimize it.
    return b_copy_dram_window;
}
template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr&,
const array<index_t, 2>& offset = {0, 0}) const
{
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
a_dram_block_window_tmp.get_window_origin() + offset,
Policy::template MakeADramTileDistribution<Problem>());
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
// A LDS tile window for store
auto a_lds_shape = []() {
@@ -138,16 +227,8 @@ struct GemmPipelineAgBgCrImplBase
const BLdsLoadTileDistr&,
const array<index_t, 2>& offset = {0, 0}) const
{
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile{}, XPerTile{}),
b_dram_block_window_tmp.get_window_origin() + offset,
Policy::template MakeBDramTileDistribution<Problem>());
// A DRAM tile window for load
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
// TODO: Do we really need those two tile windows???
// They're exactly same...

View File

@@ -107,14 +107,23 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
@@ -386,17 +395,25 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
@@ -449,17 +466,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_block_tile;
BBlockTile b_block_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
@@ -470,45 +476,61 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// Load tile — during value loading, an elementwise function is executed for each A0,
// A1, … AN. The values A0, A1, … AN are read by the same thread.
auto elementwise_As_res =
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
// as input.
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
// Load tile — during value loading, an elementwise function is executed for each B0,
// B1, … BN. The values B0, B1, … BN are read by the same thread.
auto elementwise_Bs_res =
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
// as input.
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// LDS write 0
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// global read 1
elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
__builtin_amdgcn_sched_barrier(0);
@@ -520,38 +542,42 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
block_sync_lds();
if constexpr(is_a_col_major && !is_a_load_tr_v())
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
elementwise_As_res =
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
elementwise_Bs_res =
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
@@ -574,27 +600,26 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
block_sync_lds();
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// __builtin_amdgcn_sched_barrier(0);
@@ -602,13 +627,16 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
@@ -628,9 +656,13 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
* @note This is used by the persistent gemm kernel variants that don't determine
* hot loop and tail number on the host side, e.g. grouped gemm kernel.
*/
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
@@ -639,7 +671,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr bool hot_loop = hot_loop_.value;
constexpr auto tail_num = tail_num_.value;
constexpr auto PassThrough = [](const auto& x) { return x; };
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
PassThrough,
@@ -658,20 +690,97 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
* @note This is used by the kernel variants that are able to determine
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
*/
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem);
}
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
!is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
a_element_func,
ck_tile::make_tuple(b_dram_block_window_tmp),
b_element_func,
num_loop,
p_smem);
}
/**
* @brief Single-input operator(): this function runs the pipeline by wrapping it with
* the tail handler.
*
* @note This is used by the persistent gemm kernel variants that don't determine
* hot loop and tail number on the host side, e.g. grouped gemm kernel.
*/
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
has_hot_loop,
tail_number,
p_smem);
}
/**
* @brief Single-input operator(): this function runs the pipeline using compile-time
* known hot loop and tail number.
* @param num_loop The number of loop iterations. This is determined at runtime due to e.g.
* SplitK.
* @note This is used by the kernel variants that are able to determine
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
*/
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -97,11 +97,24 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
using Base = BaseGemmPipelineAgBgCrCompV4<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
static constexpr index_t APackedSize =
@@ -109,10 +122,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
@@ -244,18 +253,26 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
@@ -279,29 +296,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
////////////// global window & register /////////////////
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// A register tile for global load
constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution();
constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution();
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr));
ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
@@ -312,8 +306,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
// global prefetch 0
// global read 0
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
////////////// LDS desc, window & register /////////////////
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
@@ -343,34 +336,75 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// Generating a tuple with tile_windows for values A0, A1, ... AN
auto a_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
// Load tile — during value loading, an elementwise function is executed for each A0,
// A1, … AN. The values A0, A1, … AN are read by the same thread.
auto elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
// as input.
move_tile_window(a_tile_windows, a_dram_tile_window_step);
// Generating a tuple with tile_windows for values B0, B1, ... BN
auto b_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
// Load tile — during value loading, an elementwise function is executed for each B0,
// B1, … BN. The values B0, B1, … BN are read by the same thread.
auto elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
// as input.
move_tile_window(b_tile_windows, b_dram_tile_window_step);
// LDS write 0
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
}
// global read 1
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
move_tile_window(a_tile_windows, a_dram_tile_window_step);
elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
move_tile_window(b_tile_windows, b_dram_tile_window_step);
block_sync_lds();
constexpr auto ALdsTileDistr =
@@ -423,27 +457,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func);
Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res);
}
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
// Load tile — during value loading, an elementwise function is executed for each A0,
// A1, … AN. The values A0, A1, … AN are read by the same thread.
elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
move_tile_window(a_tile_windows, a_dram_tile_window_step);
elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
move_tile_window(b_tile_windows, b_dram_tile_window_step);
if(HasHotLoop)
{
@@ -461,31 +500,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(
a_copy_lds_window0, a_global_load_tile, a_element_func);
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(
b_copy_lds_window0, b_global_load_tile, b_element_func);
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
}
Base::GlobalPrefetch(
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
elementwise_As_res =
load_tile_with_elementwise(a_tile_windows, a_element_func);
move_tile_window(a_tile_windows, a_dram_tile_window_step);
elementwise_Bs_res =
load_tile_with_elementwise(b_tile_windows, b_element_func);
move_tile_window(b_tile_windows, b_dram_tile_window_step);
// gemm
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
@@ -501,32 +541,34 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(
a_copy_lds_window1, a_global_load_tile, a_element_func);
Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(
b_copy_lds_window1, b_global_load_tile, b_element_func);
Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res);
}
block_sync_lds();
Base::GlobalPrefetch(
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
elementwise_As_res =
load_tile_with_elementwise(a_tile_windows, a_element_func);
move_tile_window(a_tile_windows, a_dram_tile_window_step);
elementwise_Bs_res =
load_tile_with_elementwise(b_tile_windows, b_element_func);
move_tile_window(b_tile_windows, b_dram_tile_window_step);
// gemm
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
@@ -548,23 +590,23 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
}
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
@@ -606,13 +648,17 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
public:
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0,
@@ -628,27 +674,34 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
p_smem_1);
}
public:
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
[](auto& e, const ADataType& a) { e = a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
[](auto& e, const BDataType& b) { e = b; },
num_loop,
p_smem_0,
p_smem_1);
}
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
@@ -658,7 +711,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr bool hot_loop = hot_loop_.value;
constexpr auto tail_num = tail_num_.value;
constexpr auto PassThrough = [](const auto& x) { return x; };
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
PassThrough,
@@ -670,5 +723,69 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
/// @brief Single-window entry point with user-supplied element-wise functions.
///
/// Promotes each plain DRAM block window into a one-element tuple and
/// delegates to the tuple-based (multi-ABD) overload. Enabled via SFINAE only
/// when neither window argument is already a tuple.
template <typename ADramBlockWindowTmp,
          typename BDramBlockWindowTmp,
          typename AElementFunction,
          typename BElementFunction,
          typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
                                        !is_detected<is_tuple, BDramBlockWindowTmp>::value,
                                    bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                               const AElementFunction& a_element_func,
                               const BDramBlockWindowTmp& b_dram_block_window_tmp,
                               const BElementFunction& b_element_func,
                               index_t num_loop,
                               void* p_smem_0,
                               void* p_smem_1) const
{
    // Wrap the scalar windows so the multi-ABD code path handles both cases.
    const auto as_windows = ck_tile::make_tuple(a_dram_block_window_tmp);
    const auto bs_windows = ck_tile::make_tuple(b_dram_block_window_tmp);
    return operator()(
        as_windows, a_element_func, bs_windows, b_element_func, num_loop, p_smem_0, p_smem_1);
}
/// @brief Single-window entry point without element-wise functions.
///
/// Promotes each plain DRAM block window into a one-element tuple and
/// forwards to the tuple-based overload, which supplies identity element
/// functions. Enabled only when neither window argument is a tuple.
template <typename ADramBlockWindowTmp,
          typename BDramBlockWindowTmp,
          typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
                                        !is_detected<is_tuple, BDramBlockWindowTmp>::value,
                                    bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                               const BDramBlockWindowTmp& b_dram_block_window_tmp,
                               const index_t num_loop,
                               void* __restrict__ p_smem_0,
                               void* __restrict__ p_smem_1) const
{
    // Delegate to the multi-ABD overload with singleton tuples.
    const auto as_windows = ck_tile::make_tuple(a_dram_block_window_tmp);
    const auto bs_windows = ck_tile::make_tuple(b_dram_block_window_tmp);
    return operator()(as_windows, bs_windows, num_loop, p_smem_0, p_smem_1);
}
/// @brief Single-window entry point with runtime hot-loop/tail dispatch.
///
/// Promotes each plain DRAM block window into a one-element tuple and forwards
/// to the tuple-based overload that resolves has_hot_loop/tail_number at
/// runtime. Enabled only when neither window argument is a tuple.
template <typename ADramBlockWindowTmp,
          typename BDramBlockWindowTmp,
          typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
                                        !is_detected<is_tuple, BDramBlockWindowTmp>::value,
                                    bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                               const BDramBlockWindowTmp& b_dram_block_window_tmp,
                               index_t num_loop,
                               bool has_hot_loop,
                               TailNumber tail_number,
                               void* __restrict__ p_smem_0,
                               void* __restrict__ p_smem_1) const
{
    // Delegate to the multi-ABD overload with singleton tuples.
    const auto as_windows = ck_tile::make_tuple(a_dram_block_window_tmp);
    const auto bs_windows = ck_tile::make_tuple(b_dram_block_window_tmp);
    return operator()(
        as_windows, bs_windows, num_loop, has_hot_loop, tail_number, p_smem_0, p_smem_1);
}
};
} // namespace ck_tile

View File

@@ -41,15 +41,24 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
using Base = BaseGemmPipelineAgBgCrCompV5<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
@@ -121,17 +130,25 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename AsDramBlockWindowTmp,
typename AElementFunction,
typename BDramBlockWindowTmp,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
typename BsDramBlockWindowTmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* __restrict__ p_smem_0) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
@@ -209,14 +226,16 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
BGemmTile b_tile_0, b_tile_1;
// Register tile for A and B.
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTileDistr =
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
using BBlockTileDistr =
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile;
ABlockTile elementwise_As_res;
BBlockTile elementwise_Bs_res;
// Block GEMM
auto block_gemm = BlockGemm();
@@ -248,33 +267,45 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
// define ping, pong steps here as lambda functions.
auto MemoryOpsStep = [&](auto idx) {
// Memory read half here.
Base::GlobalPrefetch(
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
// Load tile — during value loading, an elementwise function is executed for each
// A0, A1, … AN. The values A0, A1, … AN are read by the same thread.
elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
// Move each A — the enhanced function move_tile_window is executed, which takes a
// tuple as input.
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
// Load tile — during value loading, an elementwise function is executed for each
// B0, B1, … BN. The values B0, B1, … BN are read by the same thread.
elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
// Move each B — the enhanced function move_tile_window is executed, which takes a
// tuple as input.
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func);
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func);
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
if(idx == 0)
@@ -351,13 +382,17 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
public:
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_0) const
@@ -371,21 +406,62 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
p_smem_0);
}
public:
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
/// @brief Tuple (multi-ABD) entry point with identity element-wise functions.
///
/// Runs the pipeline with the statically known HasHotLoop/TailNum configuration
/// and pass-through element functions: each loaded A/B element is assigned
/// unchanged to its destination. Enabled via SFINAE only when both window
/// arguments are tuples of DRAM block windows.
///
/// @param a_dram_block_window_tmp Tuple of A-operand DRAM block windows.
/// @param b_dram_block_window_tmp Tuple of B-operand DRAM block windows.
/// @param num_loop               Number of K-loop iterations.
/// @param p_smem_0               LDS workspace buffer.
template <typename AsDramBlockWindowTmp,
          typename BsDramBlockWindowTmp,
          typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
                                        is_detected<is_tuple, BsDramBlockWindowTmp>::value,
                                    bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
                               const BsDramBlockWindowTmp& b_dram_block_window_tmp,
                               const index_t num_loop,
                               void* __restrict__ p_smem_0) const
{
    // Stale pre-refactor "return a;" style lambdas that were left interleaved
    // here have been removed; the element functions now use the assignment
    // (destination, source) signature expected by the multi-ABD pipeline impl.
    return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
        a_dram_block_window_tmp,
        [](auto& e, const ADataType& a) { e = a; },
        b_dram_block_window_tmp,
        [](auto& e, const BDataType& b) { e = b; },
        num_loop,
        p_smem_0);
}
/// @brief Single-window entry point with user-supplied element-wise functions.
///
/// Promotes each plain DRAM block window into a one-element tuple and
/// delegates to the tuple-based (multi-ABD) overload. Enabled via SFINAE only
/// when neither window argument is already a tuple.
template <typename ADramBlockWindowTmp,
          typename BDramBlockWindowTmp,
          typename AElementFunction,
          typename BElementFunction,
          typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
                                        !is_detected<is_tuple, BDramBlockWindowTmp>::value,
                                    bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                               const AElementFunction& a_element_func,
                               const BDramBlockWindowTmp& b_dram_block_window_tmp,
                               const BElementFunction& b_element_func,
                               index_t num_loop,
                               void* p_smem_0) const
{
    // Wrap the scalar windows so the multi-ABD code path handles both cases.
    const auto as_windows = ck_tile::make_tuple(a_dram_block_window_tmp);
    const auto bs_windows = ck_tile::make_tuple(b_dram_block_window_tmp);
    return operator()(as_windows, a_element_func, bs_windows, b_element_func, num_loop, p_smem_0);
}
/// @brief Single-window entry point without element-wise functions.
///
/// Promotes each plain DRAM block window into a one-element tuple and forwards
/// to the tuple-based overload, which supplies identity element functions.
/// Enabled only when neither window argument is a tuple.
template <typename ADramBlockWindowTmp,
          typename BDramBlockWindowTmp,
          typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
                                        !is_detected<is_tuple, BDramBlockWindowTmp>::value,
                                    bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                               const BDramBlockWindowTmp& b_dram_block_window_tmp,
                               const index_t num_loop,
                               void* __restrict__ p_smem_0) const
{
    // Delegate to the multi-ABD overload with singleton tuples.
    const auto as_windows = ck_tile::make_tuple(a_dram_block_window_tmp);
    const auto bs_windows = ck_tile::make_tuple(b_dram_block_window_tmp);
    return operator()(as_windows, bs_windows, num_loop, p_smem_0);
}
};
} // namespace ck_tile

View File

@@ -157,14 +157,23 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
@@ -236,17 +245,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
@@ -310,8 +327,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTileDistr =
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
using BBlockTileDistr =
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
@@ -334,10 +353,21 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// prefetch
// global read 0
Base::GlobalPrefetch(
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
// Load tile — during value loading, an elementwise function is executed for each A0,
// A1, … AN. The values A0, A1, … AN are read by the same thread.
a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
// as input.
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
// Load tile — during value loading, an elementwise function is executed for each B0,
// B1, … BN. The values B0, B1, … BN are read by the same thread.
b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
// as input.
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -348,32 +378,35 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}));
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}));
}
// Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window,
a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window,
b_dram_tile_window_step);
a_block_tiles.at(number<prefetch_idx>{}) =
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
b_block_tiles.at(number<prefetch_idx>{}) =
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
});
// main body
@@ -397,14 +430,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
transpose_tile2d(
a_shuffle_tmp,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
@@ -413,22 +445,23 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
transpose_tile2d(
b_shuffle_tmp,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
}
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window,
a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window,
b_dram_tile_window_step);
a_block_tiles.at(number<prefetch_idx>{}) =
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
b_block_tiles.at(number<prefetch_idx>{}) =
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
});
i += PrefetchStages;
@@ -450,26 +483,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
a_block_tiles.get(number<prefetch_idx>{}));
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}),
b_element_func);
b_block_tiles.get(number<prefetch_idx>{}));
}
});
@@ -526,17 +557,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
@@ -600,8 +639,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTileDistr =
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
using BBlockTileDistr =
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
@@ -623,10 +664,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// prefetch
// global read 0
Base::GlobalPrefetch(
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
// Load tile — during value loading, an elementwise function is executed for each A0,
// A1, … AN. The values A0, A1, … AN are read by the same thread.
a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
// as input.
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
// Load tile — during value loading, an elementwise function is executed for each B0,
// B1, … BN. The values B0, B1, … BN are read by the same thread.
b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
// as input.
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -637,32 +690,35 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}));
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}));
}
// Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window,
a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window,
b_dram_tile_window_step);
a_block_tiles.at(number<prefetch_idx>{}) =
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
b_block_tiles.at(number<prefetch_idx>{}) =
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
});
// main body
@@ -687,14 +743,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
transpose_tile2d(
a_shuffle_tmp,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
@@ -703,22 +758,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
transpose_tile2d(
b_shuffle_tmp,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
}
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window,
a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window,
b_dram_tile_window_step);
a_block_tiles.at(number<prefetch_idx>{}) =
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
b_block_tiles.at(number<prefetch_idx>{}) =
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
});
i += PrefetchStages;
@@ -740,26 +797,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
}
else
{
Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}),
a_element_func);
a_block_tiles.get(number<prefetch_idx>{}));
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
}
else
{
Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}),
b_element_func);
b_block_tiles.get(number<prefetch_idx>{}));
}
});
@@ -813,13 +868,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
@@ -833,9 +891,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
p_smem);
}
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
@@ -844,7 +906,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
constexpr bool hot_loop = hot_loop_.value;
constexpr auto tail_num = tail_num_.value;
constexpr auto PassThrough = [](const auto& x) { return x; };
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
PassThrough,
@@ -856,20 +918,82 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
/// @brief Tuple (multi-ABD) entry point with identity element-wise functions.
///
/// Runs the pipeline with the statically known HasHotLoop/TailNum configuration
/// and pass-through element functions: each loaded A/B element is assigned
/// unchanged to its destination. Enabled via SFINAE only when both window
/// arguments are tuples of DRAM block windows.
///
/// @param a_dram_block_window_tmp Tuple of A-operand DRAM block windows.
/// @param b_dram_block_window_tmp Tuple of B-operand DRAM block windows.
/// @param num_loop               Number of K-loop iterations.
/// @param p_smem                 LDS workspace buffer.
template <typename AsDramBlockWindowTmp,
          typename BsDramBlockWindowTmp,
          typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
                                        is_detected<is_tuple, BsDramBlockWindowTmp>::value,
                                    bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
                               const BsDramBlockWindowTmp& b_dram_block_window_tmp,
                               index_t num_loop,
                               void* p_smem) const
{
    // BUGFIX: the B-operand identity lambda previously took `const ADataType&`
    // (copy-paste from the A lambda); it must take the B element type so the
    // pass-through works when ADataType != BDataType. Stale pre-refactor
    // "return a;" style lambdas interleaved here were also removed.
    return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
        a_dram_block_window_tmp,
        [](auto& e, const ADataType& a) { e = a; },
        b_dram_block_window_tmp,
        [](auto& e, const BDataType& b) { e = b; },
        num_loop,
        p_smem);
}
// Non-tuple convenience overload (with user element-wise functions).
// Wraps the single A and B DRAM block windows into 1-element tuples and
// forwards, together with the caller-provided element functions, to the
// tuple-based overload. Enabled only when neither window type is a tuple.
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
// Lift the scalar windows into tuples so the multi-ABD path handles both
// the single- and multi-operand cases uniformly.
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
a_element_func,
ck_tile::make_tuple(b_dram_block_window_tmp),
b_element_func,
num_loop,
p_smem);
}
// Non-tuple convenience overload (runtime hot-loop / tail-number variant).
// Wraps the single A and B DRAM block windows into 1-element tuples and
// forwards the runtime loop configuration (has_hot_loop, tail_number) to the
// tuple-based overload. Enabled only when neither window type is a tuple.
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
has_hot_loop,
tail_number,
p_smem);
}
// Non-tuple convenience overload (compile-time loop configuration).
// Wraps the single A and B DRAM block windows into 1-element tuples and
// forwards to the tuple-based overload that uses the pipeline's static
// HasHotLoop / TailNum settings. Enabled only when neither window type is a
// tuple.
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -15,14 +15,23 @@ namespace ck_tile {
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAGmemBGmemCRegV1
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
@@ -81,17 +90,25 @@ struct GemmPipelineAGmemBGmemCRegV1
return Policy::template GetSmemSize<Problem>();
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
@@ -133,22 +150,30 @@ struct GemmPipelineAGmemBGmemCRegV1
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
auto as_copy_dram_window = generate_tuple(
[&](auto idx) {
return make_tile_window(
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
auto bs_copy_dram_window = generate_tuple(
[&](auto idx) {
return make_tile_window(
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
},
number<BsLayout::size()>{});
// B LDS tile window for store
auto b_copy_lds_window = make_tile_window(
@@ -182,13 +207,22 @@ struct GemmPipelineAGmemBGmemCRegV1
// prefetch
// global read 0
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);
// Load tile — during value loading, an elementwise function is executed for each A0,
// A1, … AN. The values A0, A1, … AN are read by the same thread.
auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
// Load tile — during value loading, an elementwise function is executed for each B0,
// B1, … BN. The values B0, B1, … BN are read by the same thread.
auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
{
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
// as input.
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
// as input.
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -198,13 +232,12 @@ struct GemmPipelineAGmemBGmemCRegV1
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
store_tile(a_copy_lds_window, a_shuffle_tmp);
}
else
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
store_tile(a_copy_lds_window, elementwise_As_res);
}
// LDS write 0
@@ -212,13 +245,12 @@ struct GemmPipelineAGmemBGmemCRegV1
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
store_tile(b_copy_lds_window, b_block_tile_tmp);
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
store_tile(b_copy_lds_window, b_shuffle_tmp);
}
else
{
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
store_tile(b_copy_lds_window, elementwise_Bs_res);
}
}
@@ -226,8 +258,8 @@ struct GemmPipelineAGmemBGmemCRegV1
while(iCounter > 0)
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
block_sync_lds();
@@ -237,22 +269,20 @@ struct GemmPipelineAGmemBGmemCRegV1
block_sync_lds();
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp_loop, a_block_tile);
store_tile(a_copy_lds_window,
tile_elementwise_in(a_element_func, a_shuffle_tmp_loop));
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
}
else
{
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
store_tile(a_copy_lds_window, elementwise_As_res);
}
// LDS write i + 1
@@ -260,14 +290,12 @@ struct GemmPipelineAGmemBGmemCRegV1
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp_loop, b_block_tile);
store_tile(b_copy_lds_window,
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
}
else
{
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
store_tile(b_copy_lds_window, elementwise_Bs_res);
}
iCounter--;
@@ -284,20 +312,40 @@ struct GemmPipelineAGmemBGmemCRegV1
return c_block_tile;
}
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType & a) { return a; },
[](auto& e, const ADataType & a) { e = a; },
b_dram_block_window_tmp,
[](const BDataType & b) { return b; },
[](auto& e, const BDataType & b) { e = b; },
num_loop,
p_smem);
}
// Non-tuple convenience overload.
// Wraps the single A and B DRAM block windows into 1-element tuples and
// forwards to the tuple-based overload (which supplies identity element-wise
// functions). Enabled only when neither window type is a tuple.
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -15,30 +15,66 @@ namespace ck_tile {
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
struct GemmPipelineAGmemBGmemCRegV2
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
// Vector load width for the A operand, taken from the Problem configuration.
// The IsWave32Host parameter is accepted for interface compatibility with
// other pipelines but is unused here.
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Problem::VectorSizeA;
}
// Vector load width for the B operand, taken from the Problem configuration.
// The IsWave32Host parameter is accepted for interface compatibility with
// other pipelines but is unused here.
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Problem::VectorSizeB;
}
// Vector width for C and the LDS packing factors for A/B, all forwarded from
// the Problem / Policy configuration.
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool Preshuffle = Problem::Preshuffle;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
static constexpr bool DoubleSmemBuffer = false;
/// Human-readable pipeline identifier (used e.g. for logging / kernel
/// naming): "pipeline_AGmemBGmemCRegV2_MxNxKxBlockSize".
///
/// The rendered diff left the pre-rename `kBlockSize` argument line adjacent
/// to the new `BlockSize` one; only the renamed member is used here.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    // clang-format off
    return concat('_', "pipeline_AGmemBGmemCRegV2",
                  concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize));
    // clang-format on
}
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
@@ -56,17 +92,31 @@ struct GemmPipelineAGmemBGmemCRegV2
BPackedSize;
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
// Total LDS (shared memory) byte count required by this pipeline, as computed
// by the Policy for the given Problem.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
using ADramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
using BDramBlockWindowTmp =
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
@@ -98,32 +148,40 @@ struct GemmPipelineAGmemBGmemCRegV2
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
auto as_copy_dram_window = generate_tuple(
[&](auto idx) {
return make_tile_window(
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
as_copy_dram_window[number<0>{}].get_tile_distribution());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
auto bs_copy_dram_window = generate_tuple(
[&](auto idx) {
return make_tile_window(
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
},
number<BsLayout::size()>{});
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
bs_copy_dram_window[number<0>{}].get_tile_distribution());
// Block GEMM
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
@@ -153,28 +211,30 @@ struct GemmPipelineAGmemBGmemCRegV2
// prefetch
// global read 0
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);
// Load tile — during value loading, an elementwise function is executed for each A0,
// A1, … AN. The values A0, A1, … AN are read by the same thread.
auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
// Load tile — during value loading, an elementwise function is executed for each B0,
// B1, … BN. The values B0, B1, … BN are read by the same thread.
auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
{
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
store_tile(a_copy_lds_window, elementwise_As_res);
// global read 1
a_block_tile = load_tile(a_copy_dram_window);
elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
// LDS write 0
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
store_tile(b_copy_lds_window, elementwise_Bs_res);
// global read 1
b_block_tile = load_tile(b_copy_dram_window);
elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
}
index_t iCounter = num_loop - 2;
@@ -189,20 +249,18 @@ struct GemmPipelineAGmemBGmemCRegV2
block_sync_lds();
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
store_tile(a_copy_lds_window, elementwise_As_res);
// global read i + 2
a_block_tile = load_tile(a_copy_dram_window);
elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
// LDS write i + 1
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
store_tile(b_copy_lds_window, elementwise_Bs_res);
// global read i + 2
b_block_tile = load_tile(b_copy_dram_window);
elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
iCounter--;
@@ -218,11 +276,9 @@ struct GemmPipelineAGmemBGmemCRegV2
block_sync_lds();
// LDS write num_loop - 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
store_tile(a_copy_lds_window, elementwise_As_res);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
store_tile(b_copy_lds_window, elementwise_Bs_res);
block_sync_lds();
@@ -241,12 +297,28 @@ struct GemmPipelineAGmemBGmemCRegV2
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType & a) { return a; },
[](auto& e, const ADataType & a) { e = a; },
b_dram_block_window_tmp,
[](const BDataType & b) { return b; },
[](auto& e, const BDataType & b) { e = b; },
num_loop,
p_smem);
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
ck_tile::make_tuple(b_dram_block_window_tmp),
num_loop,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -5,16 +5,19 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
template <typename AsDataType_,
typename BsDataType_,
typename EDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = ADataType_,
typename ComputeDataType_ = AsDataType_,
typename AElementWise_ = ck_tile::element_wise::PassThrough,
typename BElementWise_ = ck_tile::element_wise::PassThrough,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
@@ -22,18 +25,49 @@ struct GemmPipelineProblemBase
{
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>; // actually AccDataType
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
static constexpr bool FixedVectorSize = FixedVectorSize_;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename Traits::ALayout>;
using BLayout = remove_cvref_t<typename Traits::BLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
using AElementWise = remove_cvref_t<AElementWise_>;
using BElementWise = remove_cvref_t<BElementWise_>;
using AsLayout = remove_cvref_t<typename Traits::AsLayout>;
using BsLayout = remove_cvref_t<typename Traits::BsLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr bool ComputeDataTypeIsTuple = is_detected<is_tuple, ComputeDataType_>::value;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
static constexpr bool ALayoutIsTuple = is_detected<is_tuple, AsLayout>::value;
static constexpr bool BLayoutIsTuple = is_detected<is_tuple, BsLayout>::value;
using ComputeDataTypeTuple = std::conditional_t<ComputeDataTypeIsTuple,
remove_cvref_t<ComputeDataType_>,
remove_cvref_t<tuple<ComputeDataType_>>>;
using AsLayoutTuple = std::
conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
using BsLayoutTuple = std::
conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<AsDataType>,
remove_cvref_t<tuple<AsDataType>>>;
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
using ComputeDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, ComputeDataTypeTuple>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayoutTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayoutTuple>>;
static constexpr bool TransposeC = Traits::TransposeC;
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
@@ -66,7 +100,7 @@ struct GemmPipelineProblemBase
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
if constexpr(std::is_same_v<AsLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
@@ -84,7 +118,7 @@ struct GemmPipelineProblemBase
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<BsLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
@@ -125,7 +159,7 @@ struct GemmPipelineProblemBase
{
return VectorSizeA_;
}
else if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
else if constexpr(std::is_same_v<AsLayout, tensor_layout::gemm::RowMajor>)
{
return kPadK ? 1 : GetAlignmentA();
}
@@ -140,7 +174,7 @@ struct GemmPipelineProblemBase
{
return VectorSizeB_;
}
else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
else if constexpr(std::is_same_v<BsLayout, tensor_layout::gemm::ColumnMajor>)
{
return kPadN ? 1 : GetAlignmentB();
}
@@ -161,35 +195,40 @@ struct GemmPipelineProblemBase
}();
};
// Alias for GemmPipelineProblem
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
template <typename AsDataType_,
typename BsDataType_,
typename EDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = ADataType_,
typename AElementWise_ = ck_tile::element_wise::PassThrough,
typename BElementWise_ = ck_tile::element_wise::PassThrough,
typename ComputeDataType_ = AsDataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
using GemmPipelineProblem = GemmPipelineProblemBase<AsDataType_,
BsDataType_,
EDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_,
AElementWise_,
BElementWise_,
FixedVectorSize_,
VectorSizeA_,
VectorSizeB_>;
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
template <typename AsDataType_,
typename BsDataType_,
typename EDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_,
typename AElementWise_ = ck_tile::element_wise::PassThrough,
typename BElementWise_ = ck_tile::element_wise::PassThrough,
typename ComputeDataType_ = AsDataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
@@ -197,18 +236,48 @@ struct UniversalGemmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>; // actually AccDataType
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
using AElementWise = remove_cvref_t<AElementWise_>;
using BElementWise = remove_cvref_t<BElementWise_>;
static constexpr bool FixedVectorSize = FixedVectorSize_;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename Traits::ALayout>;
using BLayout = remove_cvref_t<typename Traits::BLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
using AsLayout = remove_cvref_t<typename Traits::AsLayout>;
using BsLayout = remove_cvref_t<typename Traits::BsLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr bool ComputeDataTypeIsTuple = is_detected<is_tuple, ComputeDataType_>::value;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
static constexpr bool ALayoutIsTuple = is_detected<is_tuple, AsLayout>::value;
static constexpr bool BLayoutIsTuple = is_detected<is_tuple, BsLayout>::value;
using ComputeDataTypeTuple = std::conditional_t<ComputeDataTypeIsTuple,
remove_cvref_t<ComputeDataType_>,
remove_cvref_t<tuple<ComputeDataType_>>>;
using AsLayoutTuple = std::
conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
using BsLayoutTuple = std::
conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<AsDataType>,
remove_cvref_t<tuple<AsDataType>>>;
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
using ComputeDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, ComputeDataTypeTuple>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayoutTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayoutTuple>>;
static constexpr bool TransposeC = Traits::TransposeC;
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;

View File

@@ -356,11 +356,14 @@ struct UniversalGemmBasePolicy
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem,
@@ -382,11 +385,14 @@ struct UniversalGemmBasePolicy
template <typename Problem, bool IsWave32Host = false>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem,
@@ -482,8 +488,6 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
@@ -491,6 +495,8 @@ struct UniversalGemmBasePolicy
Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using ALayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
// Tile: MPerBlock X KPerBlock
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
@@ -518,8 +524,6 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
@@ -527,6 +531,8 @@ struct UniversalGemmBasePolicy
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using BLayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
// Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
@@ -554,7 +560,8 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ALayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
@@ -574,7 +581,8 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BLayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;

View File

@@ -10,8 +10,8 @@ namespace ck_tile {
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
typename ALayout_,
typename BLayout_,
typename AsLayout_,
typename BsLayout_,
typename CLayout_,
index_t NumWaveGroups_ = 1>
struct TileGemmTraits
@@ -23,9 +23,9 @@ struct TileGemmTraits
// TODO this can't be hardcoded here! Should be in policy!
static constexpr int _VectorSize = 16;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
using AsLayout = AsLayout_;
using BsLayout = BsLayout_;
using CLayout = CLayout_;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
@@ -36,8 +36,8 @@ template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool DoubleSmemBuffer_,
typename ALayout_,
typename BLayout_,
typename AsLayout_,
typename BsLayout_,
typename CLayout_,
bool TransposeC_ = false,
bool UseStructuredSparsity_ = false,
@@ -52,9 +52,9 @@ struct TileGemmUniversalTraits
static constexpr int _VectorSize = 16;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
using AsLayout = AsLayout_;
using BsLayout = BsLayout_;
using CLayout = CLayout_;
static constexpr bool TransposeC = TransposeC_;
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
@@ -67,8 +67,8 @@ template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool DoubleSmemBuffer_,
typename ALayout_,
typename BLayout_,
typename AsLayout_,
typename BsLayout_,
typename CLayout_,
bool TransposeC_ = false,
bool UseStructuredSparsity_ = false>
@@ -76,8 +76,8 @@ using PersistentTileGemmUniversalTraits = TileGemmUniversalTraits<kPadM_,
kPadN_,
kPadK_,
DoubleSmemBuffer_,
ALayout_,
BLayout_,
AsLayout_,
BsLayout_,
CLayout_,
TransposeC_,
UseStructuredSparsity_,

View File

@@ -37,15 +37,24 @@ template <typename Problem, typename PipelinePolicy = UniversalWeightPreshuffleP
struct WeightPreshufflePipelineAGmemBGmemCRegV1
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>
{
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using BlockWeightPreshuffle =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockWeightPreshuffle<Problem>())>;
@@ -188,7 +197,12 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
}
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
@@ -455,7 +469,33 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
return c_block_tile;
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, ADramBlockWindowTmp>::value &&
is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
[[maybe_unused]] const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
[[maybe_unused]] const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
return operator()(
a_dram_block_window_tmp[number<0>{}],
[](const ADataType & a) { return a; },
b_flat_dram_block_window_tmp[number<0>{}],
num_loop,
p_smem);
}
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
@@ -463,7 +503,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType & a) { return a; },
[](auto& e, const ADataType & a) { e = a; },
b_flat_dram_block_window_tmp,
num_loop,
p_smem);

View File

@@ -53,14 +53,23 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
{
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using BlockWeightPreshuffle =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockWeightPreshuffle<Problem>())>;
@@ -502,7 +511,10 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
template <TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction>
typename AElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
@@ -1001,8 +1013,37 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
return c_block_tile;
}
// called from universal gemm kernel
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename std::enable_if_t<is_detected<is_tuple, ADramBlockWindowTmp>::value &&
is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
[[maybe_unused]] const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
[[maybe_unused]] const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
{
return operator()<TailNum>(
a_dram_block_window_tmp[number<0>{}],
[](const ADataType& a) { return a; },
b_flat_dram_block_window_tmp[number<0>{}],
num_loop,
p_smem_ping,
p_smem_pong);
}
// called from general gemm kernel
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
@@ -1019,9 +1060,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
}
// called from grouped gemm kernel
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_flat_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
TailNumber tail_number,
void* __restrict__ p_smem_0,

View File

@@ -44,6 +44,10 @@ struct TileGemmQuantTraits
using AQLayout = AQLayout_;
using BQLayout = BQLayout_;
// TODO: This should be replaced with a single value
using AsLayout = ALayout_;
using BsLayout = BLayout_;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;

View File

@@ -5,6 +5,7 @@ add_subdirectory(batched_gemm)
add_subdirectory(grouped_gemm)
add_subdirectory(grouped_gemm_preshuffle)
add_subdirectory(gemm_multi_d)
add_subdirectory(gemm_multi_abd)
add_subdirectory(gemm_streamk)
add_subdirectory(data_type)
add_subdirectory(container)

View File

@@ -0,0 +1,12 @@
# Currently ck_tile is only built on gfx9
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
# Propagate the OCP FP8 interpretation flag so host reference code and the
# kernels agree on the FP8 encoding used by the F8 test instantiations.
if(CK_USE_OCP_FP8)
    list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
# The multi-ABD GEMM tests are only registered for gfx94x/gfx95x targets;
# other architectures do not build these kernels.
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
    add_gtest_executable(test_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp)
    add_gtest_executable(test_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp)
    target_compile_definitions(test_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
    target_compile_definitions(test_gemm_multi_abd_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -0,0 +1,40 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_gemm_multi_abd_util.hpp"
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using F8 = ck_tile::fp8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
    // Has cshuffle epilogue enabled
    // Tuple element order (must match TestCkTileGemmMultiABD's tuple_element indices):
    // A0Layout, A1Layout, B0Layout, B1Layout, D0Layout, D1Layout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       BF16,         BF16,       F32,       F16,     AddScale,      AddScale,  ElementWiseAddAdd, std::true_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       F32,          F32,        F32,       F16,     AddScale,      AddScale,  ElementWiseAddAdd, std::true_type>,
    // NOTE(review): the row below is byte-identical to the row above — presumably a
    // copy-paste leftover; confirm whether a different type combination was intended.
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       F32,          F32,        F32,       F16,     AddScale,      AddScale,  ElementWiseAddAdd, std::true_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F8,       F8,         F8,         F8,        BF16,         BF16,       F32,       F32,     AddScale,      AddScale,  ElementWiseAddAdd, std::true_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F8,       F8,         F8,         F8,        F8,           F8,         F32,       F32,     AddScale,      AddScale,  ElementWiseAddAdd, std::true_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       F16,          F16,        F32,       F16,     AddScale,      AddScale,  MultiplyMultiply, std::true_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       BF16,         BF16,       F32,       F32,     AddScale,      AddScale,  MultiplyMultiply, std::true_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       F32,          F32,        F32,       F32,     AddScale,      AddScale,  MultiplyMultiply, std::true_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       F32,          F32,        F32,       F16,     AddScale,      AddScale,  MultiplyMultiply, std::true_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F8,       F8,         F8,         F8,        BF16,         BF16,       F32,       F32,     AddScale,      AddScale,  MultiplyMultiply, std::true_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F8,       F8,         F8,         F8,        F8,           F8,         F32,       F32,     AddScale,      AddScale,  MultiplyMultiply, std::true_type>
    >;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes);
#include "test_gemm_multi_abd_ut_cases_cshuffle.inc"

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_gemm_multi_abd_util.hpp"
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using F8 = ck_tile::fp8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
    // Has cshuffle epilogue disabled
    // Tuple element order (must match TestCkTileGemmMultiABD's tuple_element indices):
    // A0Layout, A1Layout, B0Layout, B1Layout, D0Layout, D1Layout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       F32,          F32,        F32,       F16,     AddScale,      AddScale,  ElementWiseAddAdd, std::false_type>,
    // NOTE(review): the row below is byte-identical to the row above — presumably a
    // copy-paste leftover; confirm whether a different type combination was intended.
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       F32,          F32,        F32,       F16,     AddScale,      AddScale,  ElementWiseAddAdd, std::false_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F8,       F8,         F8,         F8,        BF16,         BF16,       F32,       F32,     AddScale,      AddScale,  ElementWiseAddAdd, std::false_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       F32,          F32,        F32,       F32,     AddScale,      AddScale,  ElementWiseAddAdd, std::false_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       BF16,         BF16,       F32,       BF16,    AddScale,      AddScale,  ElementWiseAddAdd, std::false_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F8,       F8,         F8,         F8,        BF16,         BF16,       F32,       BF16,    AddScale,      AddScale,  ElementWiseAddAdd, std::false_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       F16,          F16,        F32,       F16,     AddScale,      AddScale,  MultiplyMultiply, std::false_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       F32,          F32,        F32,       F16,     AddScale,      AddScale,  MultiplyMultiply, std::false_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F8,       F8,         F8,         F8,        BF16,         BF16,       F32,       F32,     AddScale,      AddScale,  MultiplyMultiply, std::false_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       F32,          F32,        F32,       F32,     AddScale,      AddScale,  MultiplyMultiply, std::false_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F16,      F16,        F16,        F16,       BF16,         BF16,       F32,       BF16,    AddScale,      AddScale,  MultiplyMultiply, std::false_type>,
    std::tuple<       Row,       Row,      Col,      Col,     Row,      Row,     Row,       F8,       F8,         F8,         F8,        BF16,         BF16,       F32,       BF16,    AddScale,      AddScale,  MultiplyMultiply, std::false_type>
    >;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes);
#include "test_gemm_multi_abd_ut_cases_default2d.inc"

View File

@@ -0,0 +1,211 @@
#pragma once
// Split-K == 1 problem sizes for the multi-ABD GEMM with the CShuffle epilogue.
// Each case is expected to run to completion and pass numerical verification
// (TestCkTileGemmMultiABD::Run returns true).
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256)
{
    constexpr int M      = 256;
    constexpr int N      = 512;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256)
{
    constexpr int M      = 512;
    constexpr int N      = 256;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256)
{
    constexpr int M      = 512;
    constexpr int N      = 512;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256)
{
    constexpr int M      = 256;
    constexpr int N      = 256;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256)
{
    constexpr int M      = 512;
    constexpr int N      = 768;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256)
{
    constexpr int M      = 512;
    constexpr int N      = 1280;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256)
{
    constexpr int M      = 256;
    constexpr int N      = 1280;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256)
{
    constexpr int M      = 768;
    constexpr int N      = 512;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256)
{
    constexpr int M      = 1280;
    constexpr int N      = 512;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x256x256)
{
    constexpr int M      = 1280;
    constexpr int N      = 256;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
// Split-K == 2 cases: Run is expected to throw std::runtime_error — presumably
// split-K is unsupported for this kernel configuration (TODO confirm against
// the kernel's IsSupportedArgument logic).
// NOTE(review): in all tests below except the first, the _MxNxK name suffix
// ends in x256 while the K constant is 512 (the cases look copied from the
// kBatch=1 list with only K bumped). Additionally the _512x512x256 case runs
// the same (512, 512, 512) problem as the first test. Confirm whether the
// names or the K values are the intended ones; renaming naively would create
// a duplicate test name.
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512)
{
    constexpr int M      = 512;
    constexpr int N      = 512;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x512x256)
{
    constexpr int M      = 256;
    constexpr int N      = 512;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x256x256)
{
    constexpr int M      = 512;
    constexpr int N      = 256;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x256)
{
    constexpr int M      = 512;
    constexpr int N      = 512;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x256x256)
{
    constexpr int M      = 256;
    constexpr int N      = 256;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x768x256)
{
    constexpr int M      = 512;
    constexpr int N      = 768;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x1280x256)
{
    constexpr int M      = 512;
    constexpr int N      = 1280;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x1280x256)
{
    constexpr int M      = 256;
    constexpr int N      = 1280;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_768x512x256)
{
    constexpr int M      = 768;
    constexpr int N      = 512;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_1280x512x256)
{
    constexpr int M      = 1280;
    constexpr int N      = 512;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_1280x256x256)
{
    constexpr int M      = 1280;
    constexpr int N      = 256;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}

View File

@@ -0,0 +1,211 @@
#pragma once
// Split-K == 1 problem sizes for the multi-ABD GEMM with the default 2D
// epilogue. Each case is expected to run to completion and pass numerical
// verification (TestCkTileGemmMultiABD::Run returns true).
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x512x256)
{
    constexpr int M      = 256;
    constexpr int N      = 512;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x256x256)
{
    constexpr int M      = 512;
    constexpr int N      = 256;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x512x256)
{
    constexpr int M      = 512;
    constexpr int N      = 512;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x256x256)
{
    constexpr int M      = 256;
    constexpr int N      = 256;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x768x256)
{
    constexpr int M      = 512;
    constexpr int N      = 768;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x1280x256)
{
    constexpr int M      = 512;
    constexpr int N      = 1280;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x1280x256)
{
    constexpr int M      = 256;
    constexpr int N      = 1280;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_768x512x256)
{
    constexpr int M      = 768;
    constexpr int N      = 512;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_1280x512x256)
{
    constexpr int M      = 1280;
    constexpr int N      = 512;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_1280x256x256)
{
    constexpr int M      = 1280;
    constexpr int N      = 256;
    constexpr int K      = 256;
    constexpr int kBatch = 1;
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
// Split-K == 2 cases: Run is expected to throw std::runtime_error — presumably
// split-K is unsupported for this kernel configuration (TODO confirm against
// the kernel's IsSupportedArgument logic).
// NOTE(review): in all tests below except the first, the _MxNxK name suffix
// ends in x256 while the K constant is 512 (the cases look copied from the
// kBatch=1 list with only K bumped). Additionally the _512x512x256 case runs
// the same (512, 512, 512) problem as the first test. Confirm whether the
// names or the K values are the intended ones; renaming naively would create
// a duplicate test name.
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x512x512)
{
    constexpr int M      = 512;
    constexpr int N      = 512;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x512x256)
{
    constexpr int M      = 256;
    constexpr int N      = 512;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x256x256)
{
    constexpr int M      = 512;
    constexpr int N      = 256;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x512x256)
{
    constexpr int M      = 512;
    constexpr int N      = 512;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x256x256)
{
    constexpr int M      = 256;
    constexpr int N      = 256;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x768x256)
{
    constexpr int M      = 512;
    constexpr int N      = 768;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x1280x256)
{
    constexpr int M      = 512;
    constexpr int N      = 1280;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x1280x256)
{
    constexpr int M      = 256;
    constexpr int N      = 1280;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_768x512x256)
{
    constexpr int M      = 768;
    constexpr int N      = 512;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_1280x512x256)
{
    constexpr int M      = 1280;
    constexpr int N      = 512;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_1280x256x256)
{
    constexpr int M      = 1280;
    constexpr int N      = 256;
    constexpr int K      = 512;
    constexpr int kBatch = 2;
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}

View File

@@ -0,0 +1,500 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
/// Element-wise input fusion for the A/B operands: e = scale * (a0 + a1).
struct AddScale
{
    /// @param e   Destination element; receives the scaled sum.
    /// @param a0  First input element (converted to float before adding).
    /// @param a1  Second input element.
    template <typename E, typename A0, typename A1>
    CK_TILE_HOST_DEVICE constexpr void operator()(E& e, const A0& a0, const A1& a1) const
    {
        // Accumulate in float, then convert back to the destination type
        // explicitly — consistent with MultiplyMultiply/ElementWiseAddAdd
        // below, instead of relying on an implicit float -> E conversion.
        const float sum_f =
            ck_tile::type_convert<float>(a0) + ck_tile::type_convert<float>(a1);
        e = ck_tile::type_convert<E>(scale * sum_f);
    }

    float scale = 1.f; // uniform scale applied to the sum (float literal, no narrowing)
};
/// Epilogue fusion: e = c * d0 * d1, with the product evaluated in float.
struct MultiplyMultiply
{
    template <typename E, typename C, typename D0, typename D1>
    CK_TILE_HOST_DEVICE void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
    {
        // Promote every operand to float so mixed-precision inputs multiply safely.
        const float c_f  = ck_tile::type_convert<float>(c);
        const float d0_f = ck_tile::type_convert<float>(d0);
        const float d1_f = ck_tile::type_convert<float>(d1);
        e = ck_tile::type_convert<E>(c_f * d0_f * d1_f);
    }
};
/// Epilogue fusion: e = c + d0 + d1, with the sum evaluated in float.
struct ElementWiseAddAdd
{
    template <typename E, typename C, typename D0, typename D1>
    CK_TILE_HOST_DEVICE void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
    {
        // Promote every operand to float so mixed-precision inputs add safely.
        const float c_f  = ck_tile::type_convert<float>(c);
        const float d0_f = ck_tile::type_convert<float>(d0);
        const float d1_f = ck_tile::type_convert<float>(d1);
        e = ck_tile::type_convert<E>(c_f + d0_f + d1_f);
    }
};
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Derive relative/absolute comparison thresholds for the GEMM result.
// The narrowest participating input type bounds the achievable accuracy;
// split-K accumulation across k_batch partial results adds further slack.
// Returns ck_tile::make_tuple(rtol, atol), each the max of both sources.
template <typename A0DataType,
          typename B0DataType,
          typename AccDataType,
          typename EDataType,
          typename D0DataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // Narrowest of A0/B0, then narrowest of that and D0.
    using ComputeTypeAB =
        std::conditional_t<sizeof(A0DataType) < sizeof(B0DataType), A0DataType, B0DataType>;
    using ComputeType =
        std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;

    const auto k_per_split = ck_tile::integer_divide_ceil(K, kbatch);

    // Thresholds due to the per-split GEMM accumulation.
    const auto gemm_rtol =
        ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(k_per_split);
    const auto gemm_atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
        max_accumulated_value / kbatch, k_per_split);

    // Thresholds due to combining kbatch partial results (split-K).
    const auto splitk_rtol =
        ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
    const auto splitk_atol = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
        max_accumulated_value, kbatch);

    // Use whichever source demands the looser tolerance.
    return ck_tile::make_tuple(std::max(gemm_rtol, splitk_rtol),
                               std::max(gemm_atol, splitk_atol));
}
// Typed-test fixture for the multi-ABD GEMM kernel:
//   E = cd_op(a_op(A0, A1) @ b_op(B0, B1), D0, D1)
// The test tuple supplies 7 layouts, 8 element types, the three element-wise
// functors and a flag selecting the CShuffle epilogue.
//
// Fixes applied in review:
//  - as_ptr_buf / bs_ptr_buf were sized with DsDataType::size() instead of
//    AsDataType::size() / BsDataType::size().
//  - a1 tensor was filled via FillUniformDistribution<A0DataType> instead of
//    A1DataType.
template <typename Tuple>
class TestCkTileGemmMultiABD : public ::testing::Test
{
    protected:
    // Typed-test tuple unpacking.
    using A0Layout          = std::tuple_element_t<0, Tuple>;
    using A1Layout          = std::tuple_element_t<1, Tuple>;
    using B0Layout          = std::tuple_element_t<2, Tuple>;
    using B1Layout          = std::tuple_element_t<3, Tuple>;
    using D0Layout          = std::tuple_element_t<4, Tuple>;
    using D1Layout          = std::tuple_element_t<5, Tuple>;
    using ELayout           = std::tuple_element_t<6, Tuple>;
    using A0DataType        = std::tuple_element_t<7, Tuple>;
    using A1DataType        = std::tuple_element_t<8, Tuple>;
    using B0DataType        = std::tuple_element_t<9, Tuple>;
    using B1DataType        = std::tuple_element_t<10, Tuple>;
    using D0DataType        = std::tuple_element_t<11, Tuple>;
    using D1DataType        = std::tuple_element_t<12, Tuple>;
    using AccDataType       = std::tuple_element_t<13, Tuple>;
    using EDataType         = std::tuple_element_t<14, Tuple>;
    using AElementWiseFn    = std::tuple_element_t<15, Tuple>;
    using BElementWiseFn    = std::tuple_element_t<16, Tuple>;
    using CDElementWiseFn   = std::tuple_element_t<17, Tuple>;
    using UseCshuffleEpilog = std::tuple_element_t<18, Tuple>;

    // Multi-operand bundles consumed by the kernel and the reference.
    using AsLayout   = ck_tile::tuple<A0Layout, A1Layout>;
    using AsDataType = ck_tile::tuple<A0DataType, A1DataType>;
    using BsLayout   = ck_tile::tuple<B0Layout, B1Layout>;
    using BsDataType = ck_tile::tuple<B0DataType, B1DataType>;
    using DsLayout   = ck_tile::tuple<D0Layout, D1Layout>;
    using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;

    /// Instantiate and launch the multi-ABD GEMM kernel for the given host
    /// arguments. Selects hot-loop/tail-number variants at runtime and
    /// dispatches set vs. atomic_add stores depending on k_batch.
    /// Throws std::runtime_error when the argument set is unsupported.
    template <typename AsDataType,
              typename BsDataType,
              typename DsDataType,
              typename AccDataType,
              typename EDataType,
              typename AsLayout,
              typename BsLayout,
              typename DsLayout,
              typename ELayout,
              typename AElementWise    = ck_tile::element_wise::PassThrough,
              typename BElementWise    = ck_tile::element_wise::PassThrough,
              typename CDElementWiseFn = ck_tile::element_wise::PassThrough>
    void invoke_gemm_multi_abd(const ck_tile::GemmMultiABDHostArgs<AsDataType::size(),
                                                                   BsDataType::size(),
                                                                   DsDataType::size()>& args,
                               const ck_tile::stream_config& s)
    {
        // Block tile sizes.
        constexpr ck_tile::index_t M_Tile = 256;
        constexpr ck_tile::index_t N_Tile = 256;
        constexpr ck_tile::index_t K_Tile = 32;

        // Warps per block.
        constexpr ck_tile::index_t M_Warp = 2;
        constexpr ck_tile::index_t N_Warp = 2;
        constexpr ck_tile::index_t K_Warp = 1;

        // Per-warp tile sizes.
        constexpr ck_tile::index_t M_Warp_Tile = 32;
        constexpr ck_tile::index_t N_Warp_Tile = 32;
        constexpr ck_tile::index_t K_Warp_Tile = 16;

        constexpr bool DoubleSmemBuffer = false;

        // This configuration uses no padding and no C transpose.
        constexpr bool kPadM = false;
        constexpr bool kPadN = false;
        constexpr bool kPadK = false;

        constexpr bool TransposeC = false;

        constexpr int kBlockPerCu = 1;

        constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
        constexpr ck_tile::index_t TilePartitionerM01      = 4;

        using GemmShape =
            ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                                   ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                   ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
        using TilePartitioner =
            ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                                                       TilePartitionerGroupNum,
                                                       TilePartitionerM01>;

        using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, AsLayout, BsLayout, ELayout>;
        using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
                                                                     kPadN,
                                                                     kPadK,
                                                                     DoubleSmemBuffer,
                                                                     AsLayout,
                                                                     BsLayout,
                                                                     ELayout,
                                                                     TransposeC>;
        using GemmPipelineProblem =
            ck_tile::GemmPipelineProblem<AsDataType, BsDataType, AccDataType, GemmShape, Traits>;

        using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;

        // Round the per-split K extent up to a whole number of K tiles, then
        // derive hot-loop presence and the tail-iteration count from it.
        const ck_tile::index_t k_grain     = args.k_batch * K_Tile;
        const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * K_Tile;
        const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

        float ave_time{0};

        // Instantiate and launch the kernel for one (hot-loop, tail, store-op)
        // combination.
        const auto Run =
            [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
                constexpr bool has_hot_loop_v   = has_hot_loop_.value;
                constexpr auto tail_number_v    = tail_number_.value;
                constexpr auto scheduler        = ck_tile::GemmPipelineScheduler::Intrawave;
                constexpr auto memory_operation = memory_operation_.value;

                using UniversalGemmProblem =
                    ck_tile::UniversalGemmPipelineProblem<AsDataType,
                                                          BsDataType,
                                                          AccDataType,
                                                          GemmShape,
                                                          GemmUniversalTraits,
                                                          scheduler,
                                                          has_hot_loop_v,
                                                          tail_number_v,
                                                          AElementWise,
                                                          BElementWise>;

                using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;

                using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
                    ck_tile::DefaultGemm2DEpilogueProblem<AsDataType,
                                                          BsDataType,
                                                          DsDataType,
                                                          AccDataType,
                                                          EDataType,
                                                          DsLayout,
                                                          ELayout,
                                                          CDElementWiseFn,
                                                          TilePartitioner::MPerBlock,
                                                          TilePartitioner::NPerBlock,
                                                          kPadM,
                                                          kPadN,
                                                          M_Warp_Tile,
                                                          N_Warp_Tile,
                                                          K_Warp_Tile,
                                                          UniversalGemmProblem::TransposeC,
                                                          true,
                                                          memory_operation>>;
                using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue<
                    ck_tile::CShuffleEpilogueProblem<AsDataType,
                                                     BsDataType,
                                                     DsDataType,
                                                     AccDataType,
                                                     EDataType,
                                                     DsLayout,
                                                     ELayout,
                                                     CDElementWiseFn,
                                                     TilePartitioner::MPerBlock,
                                                     TilePartitioner::NPerBlock,
                                                     M_Warp,
                                                     N_Warp,
                                                     M_Warp_Tile,
                                                     N_Warp_Tile,
                                                     K_Warp_Tile,
                                                     UniversalGemmProblem::TransposeC,
                                                     memory_operation>>;
                // Epilogue choice comes from the typed-test tuple.
                using GemmEpilogue = std::
                    conditional_t<UseCshuffleEpilog::value, CShuffleGemmEpilogue, DefaultGemmEpilogue>;

                using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
                auto kargs   = Kernel::MakeKernelArgs(args);

                const dim3 grids  = Kernel::GridSize(args.M, args.N, args.k_batch);
                const dim3 blocks = Kernel::BlockSize();

                if(!Kernel::IsSupportedArgument(kargs))
                {
                    throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
                }

                if(s.log_level_ > 0)
                {
                    std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                              << "shape: " << GemmShape::GetName() << '\n'
                              << "problem: " << GemmPipelineProblem::GetName() << '\n'
                              << "pipeline: " << GemmPipeline::GetName() << '\n'
                              << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                              << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
                              << "}" << std::endl;
                }

                ave_time = ck_tile::launch_kernel(
                    s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
                return ave_time;
            };

        // k_batch == 1 writes results directly; split-K accumulates partial
        // results with atomic adds.
        const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
            if(args.k_batch == 1)
            {
                std::cout << "Run without SplitK" << std::endl;
                Run(has_hot_loop_,
                    tail_number_,
                    ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                               ck_tile::memory_operation_enum::set>{});
            }
            else
            {
                std::cout << "Run using SplitK" << std::endl;
                Run(has_hot_loop_,
                    tail_number_,
                    ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                               ck_tile::memory_operation_enum::atomic_add>{});
            }
        };

        BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    }

    public:
    /// Run one multi-ABD GEMM problem on the device and validate the result
    /// against the host reference implementation.
    /// Strides default to 0, meaning "dense" (derived from the layout).
    /// Returns true when the device result matches the reference within the
    /// computed tolerances.
    bool Run(const int M,
             const int N,
             const int K,
             const int k_batch,
             int StrideA0 = 0,
             int StrideA1 = 0,
             int StrideB0 = 0,
             int StrideB1 = 0,
             int StrideD0 = 0,
             int StrideD1 = 0,
             int StrideE  = 0)
    {
        using namespace ck_tile::literals;

        // Row-major tensors stride over rows; column-major over columns.
        auto f_host_tensor_descriptor =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if constexpr(std::is_same_v<decltype(layout),
                                            ck_tile::tensor_layout::gemm::RowMajor>)
                {
                    return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
                }
                else
                {
                    return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
                }
            };

        // A stride of 0 selects the dense stride for the given layout.
        auto f_get_default_stride =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(stride == 0)
                {
                    if constexpr(std::is_same_v<decltype(layout),
                                                ck_tile::tensor_layout::gemm::RowMajor>)
                    {
                        return col;
                    }
                    else
                    {
                        return row;
                    }
                }
                else
                    return stride;
            };

        StrideA0 = f_get_default_stride(M, K, StrideA0, A0Layout{});
        StrideA1 = f_get_default_stride(M, K, StrideA1, A1Layout{});
        StrideB0 = f_get_default_stride(K, N, StrideB0, B0Layout{});
        StrideB1 = f_get_default_stride(K, N, StrideB1, B1Layout{});
        StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{});
        StrideD1 = f_get_default_stride(M, N, StrideD1, D1Layout{});
        StrideE  = f_get_default_stride(M, N, StrideE, ELayout{});

        // Host-side operand tensors.
        ck_tile::HostTensor<A0DataType> a0_m_k_tensor(
            f_host_tensor_descriptor(M, K, StrideA0, A0Layout{}));
        ck_tile::HostTensor<A1DataType> a1_m_k_tensor(
            f_host_tensor_descriptor(M, K, StrideA1, A1Layout{}));
        ck_tile::HostTensor<B0DataType> b0_k_n_tensors(
            f_host_tensor_descriptor(K, N, StrideB0, B0Layout{}));
        ck_tile::HostTensor<B1DataType> b1_k_n_tensors(
            f_host_tensor_descriptor(K, N, StrideB1, B1Layout{}));
        ck_tile::HostTensor<D0DataType> d0_m_n_tensors(
            f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
        ck_tile::HostTensor<D1DataType> d1_m_n_tensors(
            f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
        ck_tile::HostTensor<EDataType> e_m_n_device_result(
            f_host_tensor_descriptor(M, N, StrideE, ELayout{}));

        ck_tile::FillUniformDistribution<A0DataType>{-1.f, 1.f}(a0_m_k_tensor);
        // Fixed: fill A1 with its own element type (was A0DataType).
        ck_tile::FillUniformDistribution<A1DataType>{-1.f, 1.f}(a1_m_k_tensor);
        ck_tile::FillUniformDistribution<B0DataType>{-1.f, 1.f}(b0_k_n_tensors);
        ck_tile::FillUniformDistribution<B1DataType>{-1.f, 1.f}(b1_k_n_tensors);
        ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n_tensors);
        ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n_tensors);

        // Device buffers and uploads.
        ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tensor.get_element_space_size_in_bytes());
        ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tensor.get_element_space_size_in_bytes());
        ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes());
        ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes());
        ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes());
        ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes());
        ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());

        a0_m_k_dev_buf.ToDevice(a0_m_k_tensor.mData.data());
        a1_m_k_dev_buf.ToDevice(a1_m_k_tensor.mData.data());
        b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data());
        b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data());
        d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data());
        d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data());
        e_m_n_dev_buf.SetZero();
        e_m_n_device_result.SetZero();

        // Fixed: size each pointer array with its own operand count
        // (as/bs previously used DsDataType::size()).
        std::array<const void*, AsDataType::size()> as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(),
                                                                  a1_m_k_dev_buf.GetDeviceBuffer()};
        std::array<const void*, BsDataType::size()> bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(),
                                                                  b1_k_n_dev_buf.GetDeviceBuffer()};
        std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
                                                                  d1_m_n_dev_buf.GetDeviceBuffer()};

        std::array<ck_tile::index_t, AsDataType::size()> strideAs = {StrideA0, StrideA1};
        std::array<ck_tile::index_t, BsDataType::size()> strideBs = {StrideB0, StrideB1};
        std::array<ck_tile::index_t, DsDataType::size()> strideDs = {StrideD0, StrideD1};

        ck_tile::GemmMultiABDHostArgs<AsDataType::size(), BsDataType::size(), DsDataType::size()>
            args({as_ptr_buf,
                  bs_ptr_buf,
                  ds_ptr_buf,
                  e_m_n_dev_buf.GetDeviceBuffer(),
                  k_batch,
                  M,
                  N,
                  K,
                  strideAs,
                  strideBs,
                  strideDs,
                  StrideE});

        invoke_gemm_multi_abd<AsDataType,
                              BsDataType,
                              DsDataType,
                              AccDataType,
                              EDataType,
                              AsLayout,
                              BsLayout,
                              DsLayout,
                              ELayout,
                              AElementWiseFn,
                              BElementWiseFn,
                              CDElementWiseFn>(args, ck_tile::stream_config{nullptr, false});

        std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
                  << " StrideA0 =" << StrideA0 << " StrideA1 =" << StrideA1
                  << " StrideB0 =" << StrideB0 << " StrideB1 =" << StrideB1
                  << " StrideE =" << StrideE << " StrideD0 =" << StrideD0
                  << " StrideD1 =" << StrideD1 << std::endl;

        e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());

        bool pass = true;

        // Host reference: element-wise pre-ops plus GEMM plus CD epilogue.
        ck_tile::HostTensor<A0DataType> a_m_k_host_ref_element_result(
            f_host_tensor_descriptor(M, K, StrideA0, A0Layout{}));
        ck_tile::HostTensor<B0DataType> b_k_n_host_ref_element_result(
            f_host_tensor_descriptor(K, N, StrideB0, B0Layout{}));
        ck_tile::HostTensor<EDataType> e_m_n_host_ref(
            f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
        a_m_k_host_ref_element_result.SetZero();
        b_k_n_host_ref_element_result.SetZero();
        e_m_n_host_ref.SetZero();

        ck_tile::reference_gemm_multiple_abd<AsDataType,
                                             BsDataType,
                                             DsDataType,
                                             AccDataType,
                                             EDataType,
                                             AElementWiseFn,
                                             BElementWiseFn,
                                             CDElementWiseFn>({a0_m_k_tensor, a1_m_k_tensor},
                                                              {b0_k_n_tensors, b1_k_n_tensors},
                                                              {d0_m_n_tensors, d1_m_n_tensors},
                                                              a_m_k_host_ref_element_result,
                                                              b_k_n_host_ref_element_result,
                                                              e_m_n_host_ref);

        // Tolerances scale with the largest reference value and with k_batch.
        const float max_accumulated_value =
            *std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end());
        const auto rtol_atol =
            calculate_rtol_atol<A0DataType, B0DataType, AccDataType, EDataType, D0DataType>(
                K, k_batch, max_accumulated_value);

        pass = ck_tile::check_err(e_m_n_device_result,
                                  e_m_n_host_ref,
                                  "Error: Incorrect results!",
                                  rtol_atol.at(ck_tile::number<0>{}),
                                  rtol_atol.at(ck_tile::number<1>{}));

        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                  << std::endl;

        return pass;
    }
};