mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Multiple-ABD GEMM example (#2788)
* Multi ABD - initial commit * Clang-format fix * block gemm, unify the name of CDataType * Apply changes to mem-pipeline * Rollback prefix for DType and Layout * Gemm Kernel Basic, rename * WMMA config * Grouped GEMM * Clang-format * Dropout, name * Review v2 * Move element_wise fn to unary, remove old ones fn * clang-format * Fix issue review * WP operator adjust to universal gemm * v2 prepare * Remove unused comment * Remove vectorsize * Rollback * Adjust pipeline for abd * Shuffle argument * CI-fail fix quant * Fix ag_br pipeline * Failing tests * Typo * Single argument support
This commit is contained in:
@@ -15,6 +15,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
|
||||
* Added support for Stream-K version of mixed fp8/bf16 GEMM
|
||||
* Added support for Multiple D GEMM
|
||||
* Added support for Multiple ABD GEMM
|
||||
* Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types
|
||||
* Added support for FP16 2:4 structured sparsity to universal GEMM.
|
||||
* Added support for Split K for grouped convolution backward data.
|
||||
|
||||
1
example/ck_tile/22_gemm_multi_abd/CMakeLists.txt
Normal file
1
example/ck_tile/22_gemm_multi_abd/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_executable(tile_example_gemm_multi_abd_fp16 EXCLUDE_FROM_ALL gemm_multi_abd_fp16.cpp)
|
||||
35
example/ck_tile/22_gemm_multi_abd/README.md
Normal file
35
example/ck_tile/22_gemm_multi_abd/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# Multiple ABD GEMM
|
||||
|
||||
This folder contains example for Multiple ABD GEMM using ck_tile tile-programming implementation.
|
||||
|
||||
## build
|
||||
```
|
||||
#in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
#you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or \
|
||||
leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
#The basic pipeline method on the gemm calculation
|
||||
make tile_example_gemm_multi_abd_fp16 -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_multi_abd_fp16`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m M dimensions - (Default: 3840)
|
||||
-n N dimensions - (Default: 4096)
|
||||
-k K dimensions - (Default: 4096)
|
||||
-as_layout Tensor A layout (default:R)
|
||||
-bs_layout Tensor B layout (default:C)
|
||||
-ds_layout Tensor D layout (default:R)
|
||||
-e_layout Tensor E layout (default:R)
|
||||
-stride_as Tensor A strides - (Default: 0)
|
||||
-stride_bs Tensor B strides - (Default: 0)
|
||||
-stride_e Tensor C strides - (Default: 0)
|
||||
-stride_ds Tensor D strides - (Default: 0)
|
||||
-v          0. No validation, 1. Validation on GPU. (Default: 1)
|
||||
-warmup     Number of iterations before benchmark the kernel. (Default: 50)
|
||||
-repeat Number of iterations to benchmark the kernel. (Default: 100)
|
||||
-kbatch kbatch for SplitK. (Default: 1)
|
||||
```
|
||||
184
example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp
Normal file
184
example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp
Normal file
@@ -0,0 +1,184 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_multi_abd_fp16.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_config& s) -> float
|
||||
{
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, AsLayout, BsLayout, ELayout>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
ELayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<AsDataType, BsDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
AElementWise,
|
||||
BElementWise>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
|
||||
<< grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
|
||||
<< blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_gemm_multi_abd_fp16_example.inc"
|
||||
|
||||
/// Example entry point: run the fp16 Multiple-ABD GEMM example.
/// Picks the WMMA tile configuration when the build defines CK_TILE_USE_WMMA,
/// otherwise the default compute-v3 configuration.
/// run_multiple_abd_gemm_example returns truthy on success, so the result is
/// negated to yield a 0 process exit code on success.
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
    return !run_multiple_abd_gemm_example<GemmConfigV3_Wmma>(argc, argv);
#else
    return !run_multiple_abd_gemm_example<GemmConfigV3>(argc, argv);
#endif
}
|
||||
186
example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp
Normal file
186
example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp
Normal file
@@ -0,0 +1,186 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
|
||||
#endif
|
||||
|
||||
using A0DataType = ck_tile::half_t;
|
||||
using A1DataType = ck_tile::half_t;
|
||||
|
||||
using B0DataType = ck_tile::half_t;
|
||||
using B1DataType = ck_tile::half_t;
|
||||
|
||||
using D0DataType = ck_tile::half_t;
|
||||
using D1DataType = ck_tile::half_t;
|
||||
|
||||
using EDataType = ck_tile::half_t;
|
||||
|
||||
using AsDataType = ck_tile::tuple<A0DataType, A1DataType>;
|
||||
using BsDataType = ck_tile::tuple<B0DataType, B1DataType>;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
|
||||
using AccDataType = float;
|
||||
|
||||
/// Tile configuration paired with the "Mem" pipeline (CK_TILE_PIPELINE_MEMORY)
/// and the Interwave scheduler.
struct GemmConfigMemory
{
    // Memory friendly for Interwave scheduler
    static constexpr ck_tile::index_t M_Tile = 128; // block tile extent in M
    static constexpr ck_tile::index_t N_Tile = 32;  // block tile extent in N
    static constexpr ck_tile::index_t K_Tile = 64;  // block tile extent in K

    static constexpr ck_tile::index_t M_Warp = 4; // warps along M per block
    static constexpr ck_tile::index_t N_Warp = 1; // warps along N per block
    static constexpr ck_tile::index_t K_Warp = 1; // warps along K per block

    static constexpr ck_tile::index_t M_Warp_Tile = 32; // per-warp tile in M
    static constexpr ck_tile::index_t N_Warp_Tile = 32; // per-warp tile in N
    static constexpr ck_tile::index_t K_Warp_Tile = 8;  // per-warp tile in K

    static constexpr bool DoubleSmemBuffer = false; // single LDS buffer
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
|
||||
|
||||
/// Tile configuration paired with the compute-v3 pipeline
/// (CK_TILE_PIPELINE_COMPUTE_V3) and the Intrawave scheduler.
struct GemmConfigV3
{
    // Compute friendly for Intrawave scheduler
    static constexpr ck_tile::index_t M_Tile = 256; // block tile extent in M
    static constexpr ck_tile::index_t N_Tile = 256; // block tile extent in N
    static constexpr ck_tile::index_t K_Tile = 64;  // block tile extent in K

    static constexpr ck_tile::index_t M_Warp = 2; // warps along M per block
    static constexpr ck_tile::index_t N_Warp = 2; // warps along N per block
    static constexpr ck_tile::index_t K_Warp = 1; // warps along K per block

    static constexpr ck_tile::index_t M_Warp_Tile = 32; // per-warp tile in M
    static constexpr ck_tile::index_t N_Warp_Tile = 32; // per-warp tile in N
    static constexpr ck_tile::index_t K_Warp_Tile = 16; // per-warp tile in K

    static constexpr bool DoubleSmemBuffer = false; // single LDS buffer
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
|
||||
|
||||
/// Tile configuration paired with the compute-v4 pipeline
/// (CK_TILE_PIPELINE_COMPUTE_V4): double LDS buffering, Intrawave scheduler.
struct GemmConfigV4
{
    // Compute friendly for Intrawave scheduler
    // Using the ping pong reader in the lds level
    static constexpr ck_tile::index_t M_Tile = 256; // block tile extent in M
    static constexpr ck_tile::index_t N_Tile = 256; // block tile extent in N
    static constexpr ck_tile::index_t K_Tile = 32;  // block tile extent in K

    static constexpr ck_tile::index_t M_Warp = 2; // warps along M per block
    static constexpr ck_tile::index_t N_Warp = 2; // warps along N per block
    static constexpr ck_tile::index_t K_Warp = 1; // warps along K per block

    static constexpr ck_tile::index_t M_Warp_Tile = 32; // per-warp tile in M
    static constexpr ck_tile::index_t N_Warp_Tile = 32; // per-warp tile in N
    static constexpr ck_tile::index_t K_Warp_Tile = 16; // per-warp tile in K

    static constexpr bool DoubleSmemBuffer = true; // ping-pong LDS buffers
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
|
||||
|
||||
/// Compute-v3 tile configuration sized for WMMA warp tiles (16x16x16);
/// selected by main() when CK_TILE_USE_WMMA is defined.
struct GemmConfigV3_Wmma
{
    // Compute friendly for Intrawave scheduler
    static constexpr ck_tile::index_t M_Tile = 128; // block tile extent in M
    static constexpr ck_tile::index_t N_Tile = 128; // block tile extent in N
    static constexpr ck_tile::index_t K_Tile = 64;  // block tile extent in K

    static constexpr ck_tile::index_t M_Warp = 2; // warps along M per block
    static constexpr ck_tile::index_t N_Warp = 2; // warps along N per block
    static constexpr ck_tile::index_t K_Warp = 1; // warps along K per block

    static constexpr ck_tile::index_t M_Warp_Tile = 16; // per-warp tile in M
    static constexpr ck_tile::index_t N_Warp_Tile = 16; // per-warp tile in N
    static constexpr ck_tile::index_t K_Warp_Tile = 16; // per-warp tile in K

    static constexpr bool DoubleSmemBuffer = false; // single LDS buffer
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
|
||||
|
||||
/// Maps a pipeline id (one of the CK_TILE_PIPELINE_* macros) to the concrete
/// pipeline types used by gemm_multi_abd. Only the three ids defined above
/// are specialized; any other id is a compile error.
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;

/// Memory-bound variant: AgBgCr "Mem" pipeline.
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};

/// Compute-bound variant: AgBgCr compute-v3 pipeline.
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};

/// Compute-bound variant with double LDS buffering: AgBgCr compute-v4 pipeline.
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
|
||||
|
||||
/// @brief Parse the example's command-line options.
///
/// Note: validation is toggled with "-v" and the warmup default is 50
/// iterations (the README should agree with these values).
///
/// @return tuple of (parse-succeeded flag, populated argument parser).
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3840", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("k", "4096", "k dimension")
        .insert("as_layout", "R", "As tensor data layout - Row by default")
        .insert("bs_layout", "C", "Bs tensor data layout - Col by default")
        .insert("ds_layout", "R", "Ds tensor data layout - Row by default")
        .insert("e_layout", "R", "E tensor data layout - Row by default")
        .insert("stride_as", "0", "Tensor A stride")
        .insert("stride_bs", "0", "Tensor B stride")
        .insert("stride_ds", "0", "Tensor Ds stride")
        .insert("stride_e", "0", "Tensor E stride")
        .insert("v", "1", "0. No validation, 1. Validation on GPU")
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("kbatch", "1", "kbatch for SplitK");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
|
||||
using gemm_multi_abd_kargs =
|
||||
ck_tile::GemmMultiABDHostArgs<AsDataType::size(), BsDataType::size(), DsDataType::size()>;
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename AElementWise,
|
||||
typename BElementWise,
|
||||
typename CDEElementWise>
|
||||
float gemm_multi_abd(const gemm_multi_abd_kargs& kargs, const ck_tile::stream_config& s);
|
||||
@@ -0,0 +1,311 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm_multi_abd(const std::array<const void*, AsDataType::size()>& as_m_k_dev_buf,
|
||||
const std::array<const void*, BsDataType::size()>& bs_k_n_dev_buf,
|
||||
const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
|
||||
void* e_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
const std::array<ck_tile::index_t, AsDataType::size()>& StrideAs,
|
||||
const std::array<ck_tile::index_t, BsDataType::size()>& StrideBs,
|
||||
const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
|
||||
ck_tile::index_t StrideE,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
int k_batch)
|
||||
{
|
||||
gemm_multi_abd_kargs gemm_descs({as_m_k_dev_buf,
|
||||
bs_k_n_dev_buf,
|
||||
ds_m_n_dev_buf,
|
||||
e_m_n_dev_buf,
|
||||
k_batch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideDs,
|
||||
StrideE});
|
||||
|
||||
float ave_time = gemm_multi_abd<GemmConfig,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AElementWise,
|
||||
BElementWise,
|
||||
CDEElementWise>(
|
||||
gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Gemm Multiple-ABD"};
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
flop += std::size_t(2) * M * N * K;
|
||||
|
||||
num_btype +=
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm Multiple-ABD kernel with:\n";
|
||||
std::cout << "M =" << M << " N =" << N << " K =" << K << "\n";
|
||||
std::cout << "StrideA = " << StrideAs[0] << " StrideB = " << StrideBs[0]
|
||||
<< " StrideE = " << StrideE << "\n";
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< "\n";
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename A0Layout,
|
||||
typename A1Layout,
|
||||
typename B0Layout,
|
||||
typename B1Layout,
|
||||
typename D0Layout,
|
||||
typename D1Layout,
|
||||
typename ELayout>
|
||||
int run_gemm_multi_abd_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const A0Layout a0_layout = A0Layout{},
|
||||
const A1Layout a1_layout = A1Layout{},
|
||||
const B0Layout b0_layout = B0Layout{},
|
||||
const B1Layout b1_layout = B1Layout{},
|
||||
const D0Layout d0_layout = D0Layout{},
|
||||
const D1Layout d1_layout = D1Layout{},
|
||||
const ELayout e_layout = ELayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
using AElementWiseFn = ck_tile::element_wise::AddScale;
|
||||
using BElementWiseFn = ck_tile::element_wise::AddScale;
|
||||
using CDEElementWiseFn = ck_tile::element_wise::MultiDMultiply;
|
||||
using AsLayout = ck_tile::tuple<A0Layout, A1Layout>;
|
||||
using BsLayout = ck_tile::tuple<B0Layout, B1Layout>;
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t StrideA = arg_parser.get_int("stride_as");
|
||||
ck_tile::index_t StrideB = arg_parser.get_int("stride_bs");
|
||||
ck_tile::index_t StrideD = arg_parser.get_int("stride_ds");
|
||||
ck_tile::index_t StrideE = arg_parser.get_int("stride_e");
|
||||
|
||||
ck_tile::index_t StrideA0 = StrideA;
|
||||
ck_tile::index_t StrideA1 = StrideA;
|
||||
|
||||
ck_tile::index_t StrideB0 = StrideB;
|
||||
ck_tile::index_t StrideB1 = StrideB;
|
||||
|
||||
ck_tile::index_t StrideD0 = StrideD;
|
||||
ck_tile::index_t StrideD1 = StrideD;
|
||||
|
||||
const int n_warmup = arg_parser.get_int("warmup");
|
||||
const int n_repeat = arg_parser.get_int("repeat");
|
||||
const int k_batch = arg_parser.get_int("kbatch");
|
||||
|
||||
StrideA0 = get_default_stride(M, N, StrideA0, is_row_major(a1_layout));
|
||||
StrideA1 = get_default_stride(M, N, StrideA1, is_row_major(a1_layout));
|
||||
|
||||
StrideB0 = get_default_stride(K, N, StrideB0, is_row_major(b0_layout));
|
||||
StrideB1 = get_default_stride(K, N, StrideB1, is_row_major(b1_layout));
|
||||
|
||||
StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout));
|
||||
StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout));
|
||||
|
||||
StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout));
|
||||
|
||||
ck_tile::HostTensor<A0DataType> a0_m_k_tesnor(
|
||||
host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout)));
|
||||
ck_tile::HostTensor<A1DataType> a1_m_k_tesnor(
|
||||
host_tensor_descriptor(M, K, StrideA1, is_row_major(a1_layout)));
|
||||
|
||||
ck_tile::HostTensor<B0DataType> b0_k_n_tensors(
|
||||
host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout)));
|
||||
ck_tile::HostTensor<B1DataType> b1_k_n_tensors(
|
||||
host_tensor_descriptor(K, N, StrideB1, is_row_major(b1_layout)));
|
||||
|
||||
ck_tile::HostTensor<D0DataType> d0_m_n_tensors(
|
||||
host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout)));
|
||||
ck_tile::HostTensor<D1DataType> d1_m_n_tensors(
|
||||
host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout)));
|
||||
|
||||
ck_tile::HostTensor<EDataType> e_m_n_device_result(
|
||||
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
|
||||
|
||||
ck_tile::FillUniformDistribution<A0DataType>{-1.f, 1.f}(a0_m_k_tesnor);
|
||||
ck_tile::FillUniformDistribution<A1DataType>{-1.f, 1.f}(a1_m_k_tesnor);
|
||||
|
||||
ck_tile::FillUniformDistribution<B0DataType>{-1.f, 1.f}(b0_k_n_tensors);
|
||||
ck_tile::FillUniformDistribution<B1DataType>{-1.f, 1.f}(b1_k_n_tensors);
|
||||
|
||||
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n_tensors);
|
||||
ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n_tensors);
|
||||
|
||||
ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tesnor.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tesnor.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
|
||||
|
||||
a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data());
|
||||
a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data());
|
||||
|
||||
b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data());
|
||||
b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data());
|
||||
|
||||
d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data());
|
||||
d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data());
|
||||
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_device_result.SetZero();
|
||||
|
||||
std::array<const void*, DsDataType::size()> as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(),
|
||||
a1_m_k_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<const void*, DsDataType::size()> bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(),
|
||||
b1_k_n_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
|
||||
d1_m_n_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck_tile::index_t, AsDataType::size()> strideAs = {StrideA0, StrideA1};
|
||||
std::array<ck_tile::index_t, BsDataType::size()> strideBs = {StrideB0, StrideB1};
|
||||
std::array<ck_tile::index_t, DsDataType::size()> strideDs = {StrideD0, StrideD1};
|
||||
|
||||
invoke_gemm_multi_abd<GemmConfig,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AElementWiseFn,
|
||||
BElementWiseFn,
|
||||
CDEElementWiseFn>(as_ptr_buf,
|
||||
bs_ptr_buf,
|
||||
ds_ptr_buf,
|
||||
e_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
strideAs,
|
||||
strideBs,
|
||||
strideDs,
|
||||
StrideE,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
k_batch);
|
||||
|
||||
e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
|
||||
|
||||
ck_tile::HostTensor<A0DataType> a_m_k_host_ref_element_result(
|
||||
host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout)));
|
||||
ck_tile::HostTensor<B0DataType> b_k_n_host_ref_element_result(
|
||||
host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout)));
|
||||
ck_tile::HostTensor<EDataType> e_m_n_host_ref(
|
||||
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
|
||||
a_m_k_host_ref_element_result.SetZero();
|
||||
b_k_n_host_ref_element_result.SetZero();
|
||||
e_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm_multiple_abd<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AElementWiseFn,
|
||||
BElementWiseFn,
|
||||
CDEElementWiseFn>({a0_m_k_tesnor, a1_m_k_tesnor},
|
||||
{b0_k_n_tensors, b1_k_n_tensors},
|
||||
{d0_m_n_tensors, d1_m_n_tensors},
|
||||
a_m_k_host_ref_element_result,
|
||||
b_k_n_host_ref_element_result,
|
||||
e_m_n_host_ref);
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("v"))
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end());
|
||||
|
||||
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
|
||||
|
||||
pass &= ck_tile::check_err(e_m_n_device_result,
|
||||
e_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< std::endl;
|
||||
std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
/// @brief Dispatch the example to concrete tensor layout types.
///
/// Only the A:RowMajor / B:ColumnMajor combination is implemented; any other
/// requested combination throws std::runtime_error.
///
/// @return -1 when argument parsing fails; otherwise the result of
///         run_gemm_multi_abd_example_with_layouts (1 pass / 0 fail).
template <typename GemmConfig>
int run_multiple_abd_gemm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
    {
        return -1;
    }

    const std::string as_layout = arg_parser.get_str("as_layout");
    const std::string bs_layout = arg_parser.get_str("bs_layout");

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    // A0/A1 row-major, B0/B1 column-major, D0/D1 and E row-major.
    if(as_layout == "R" && bs_layout == "C")
    {
        return run_gemm_multi_abd_example_with_layouts<GemmConfig>(
            argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}, Row{}, Row{});
    }
    else
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    }
}
|
||||
38
example/ck_tile/22_gemm_multi_abd/utils.hpp
Normal file
38
example/ck_tile/22_gemm_multi_abd/utils.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
/// @brief Derive relative/absolute error thresholds for result verification.
///
/// @param K                     GEMM K dimension (accumulation length).
/// @param kbatch                Split-K factor used by the kernel.
/// @param max_accumulated_value Largest value found in the reference output.
/// @return ck_tile::tuple {rtol, atol}: the larger of the per-accumulation
///         and the split-K-accumulation thresholds.
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // The effective compute type is the narrowest participating input type
    // (among A0, B0, D0), selected by sizeof comparison.
    using ComputeTypeAB =
        std::conditional_t<sizeof(A0DataType) < sizeof(B0DataType), A0DataType, B0DataType>;

    using ComputeType =
        std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
    // Calculate thresholds for a single split's accumulation chain
    const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
        ck_tile::integer_divide_ceil(K, kbatch));

    const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));

    // Calculate error due to split_k accumulation (partials added in EDataType)
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);

    const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
        max_accumulated_value, kbatch);

    // Use higher threshold
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
|
||||
@@ -21,6 +21,7 @@ add_subdirectory(18_flatmm)
|
||||
add_subdirectory(19_gemm_multi_d)
|
||||
add_subdirectory(20_grouped_convolution)
|
||||
add_subdirectory(21_elementwise)
|
||||
add_subdirectory(22_gemm_multi_abd)
|
||||
add_subdirectory(35_batched_transpose)
|
||||
add_subdirectory(38_block_scale_gemm)
|
||||
add_subdirectory(39_copy)
|
||||
|
||||
@@ -26,6 +26,29 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
|
||||
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Load tile with elementwise function
|
||||
*
|
||||
* @note This function is a modification of the existing load function.
|
||||
* It has been extended with two additional parameters: it takes a tuple as input
|
||||
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
|
||||
* is additionally applied during a single read.
|
||||
*/
|
||||
template <typename TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
// TODO: Tile windows should works with unknow number of params
|
||||
// Load element_wise API works only when the input typle is a tuple-tyupe
|
||||
return tile_window[number<0>{}].load(
|
||||
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
|
||||
@@ -120,6 +120,116 @@ struct tile_window_with_static_distribution
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Load tile with elementwise function
|
||||
*
|
||||
* @note Load tile with elementwise — during value loading, an
|
||||
* elementwise function is executed for each A0, A1, … AN.
|
||||
* The values A0, A1, … AN are read by the same thread. In this way, we
|
||||
* reduce the amount of information loaded into the registers.
|
||||
* The same thread, during vectorized reading, accesses the same set of
|
||||
* data from A0, A1, A2, … AN.
|
||||
*/
|
||||
template <typename TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
load(dst_tensor,
|
||||
tile_window,
|
||||
elementwise,
|
||||
number<i_access_unsupport_>{},
|
||||
bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
typename TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
|
||||
const TileWindow_& tile_window,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
constexpr auto sizeOfTuple = TileWindow_::size();
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord =
|
||||
tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord =
|
||||
tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// read from bottom tensor
|
||||
const auto idx_vec_value = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return tile_window[number<jj>{}]
|
||||
.get_bottom_tensor_view()
|
||||
.template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
},
|
||||
number<sizeOfTuple>{});
|
||||
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
ck_tile::apply(
|
||||
[&](auto&&... t) {
|
||||
elementwise(dst_tensor.get_thread_buffer().template at<d>(),
|
||||
t.template get_as<
|
||||
typename Base::DataType>()[j / Traits::PackedSize]...);
|
||||
},
|
||||
idx_vec_value);
|
||||
});
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
@@ -857,6 +967,39 @@ CK_TILE_DEVICE void move_tile_window(
|
||||
window.move(step);
|
||||
}
|
||||
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t NumCoord>
|
||||
CK_TILE_DEVICE void move_tile_window(
|
||||
tuple<tile_window_with_static_distribution<TensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>>& window,
|
||||
const typename tile_window_with_static_distribution<TensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>::BottomTensorIndex& step)
|
||||
{
|
||||
using T = tuple<tile_window_with_static_distribution<TensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>>;
|
||||
|
||||
static constexpr auto N = T::size();
|
||||
static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
|
||||
}
|
||||
|
||||
template <typename TileWindowWithStaticDistributionType,
|
||||
typename StepType,
|
||||
typename std::enable_if_t<
|
||||
is_detected<is_tuple, TileWindowWithStaticDistributionType>::value>* = nullptr>
|
||||
CK_TILE_DEVICE void move_tile_window(TileWindowWithStaticDistributionType& window, StepType& step)
|
||||
{
|
||||
static constexpr auto N = TileWindowWithStaticDistributionType::size();
|
||||
static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This class provides description of tile windowed view on the device memory.
|
||||
*
|
||||
|
||||
@@ -261,6 +261,81 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDElementOp,
|
||||
typename ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>,
|
||||
typename BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>,
|
||||
typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
|
||||
CK_TILE_HOST void
|
||||
reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k,
|
||||
const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n,
|
||||
const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
|
||||
HostTensor<ADataType>& a_m_k,
|
||||
HostTensor<BDataType>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const CDElementOp& acc_element_op = {})
|
||||
{
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto as_m_k_tuple =
|
||||
generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number<AsDataType::size()>{});
|
||||
|
||||
auto bs_k_n_tuple =
|
||||
generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number<BsDataType::size()>{});
|
||||
|
||||
auto ds_m_n_tuple =
|
||||
generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number<DsDataType::size()>{});
|
||||
|
||||
// Apply elementwise function to A
|
||||
auto a_elementwise_fn = [&](auto i, auto j) {
|
||||
ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(a_elementwise_fn, M, K)(std::thread::hardware_concurrency());
|
||||
|
||||
// Apply elementwise function to B
|
||||
auto b_elementwise_fn = [&](auto i, auto j) {
|
||||
ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(b_elementwise_fn, K, N)(std::thread::hardware_concurrency());
|
||||
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0;
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_m_k(m, k);
|
||||
BDataType v_b = b_k_n(k, n);
|
||||
v_acc +=
|
||||
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
CDataType v_c = 0;
|
||||
|
||||
ck_tile::apply(
|
||||
[&](auto&&... t) {
|
||||
acc_element_op(v_c,
|
||||
ck_tile::type_convert<float>(v_acc),
|
||||
ck_tile::type_convert<float>(t(m, n))...);
|
||||
},
|
||||
ds_m_n_tuple);
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
|
||||
@@ -392,6 +392,23 @@ struct PassThrough
|
||||
}
|
||||
};
|
||||
|
||||
struct AddScale
|
||||
{
|
||||
template <typename E, typename... As>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const As&... as) const
|
||||
{
|
||||
// Start with the base value c
|
||||
float result = ck_tile::type_convert<float>(0.0f);
|
||||
|
||||
// Add by each D parameter using fold expression
|
||||
((result += ck_tile::type_convert<float>(as)), ...);
|
||||
|
||||
a = ck_tile::type_convert<E>(scale * result);
|
||||
}
|
||||
|
||||
float scale = 1.0;
|
||||
};
|
||||
|
||||
struct MultiDMultiply
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
|
||||
@@ -28,8 +28,8 @@ struct GetDataType<T>
|
||||
using type = typename T::DataType; // Use T::ScaleN::DataType
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename DsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
@@ -51,8 +51,8 @@ template <typename ADataType_,
|
||||
bool TiledMMAPermuteN_ = false>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
@@ -83,12 +83,27 @@ template <typename Problem_, typename Policy_ = void>
|
||||
struct CShuffleEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
|
||||
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
|
||||
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
|
||||
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<AsDataType>,
|
||||
remove_cvref_t<tuple<AsDataType>>>;
|
||||
|
||||
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<BsDataType>,
|
||||
remove_cvref_t<tuple<BsDataType>>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
|
||||
@@ -28,8 +28,8 @@ struct Default2DEpilogueProblem
|
||||
static constexpr index_t NumDTensor = 0;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename DsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
@@ -53,8 +53,8 @@ struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataTyp
|
||||
UseRawStore_,
|
||||
MemoryOperation_>
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using CLayout = remove_cvref_t<CLayout_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
@@ -157,14 +157,28 @@ struct Default2DEpilogue
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
|
||||
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<AsDataType>,
|
||||
remove_cvref_t<tuple<AsDataType>>>;
|
||||
|
||||
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<BsDataType>,
|
||||
remove_cvref_t<tuple<BsDataType>>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
|
||||
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
|
||||
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
|
||||
|
||||
@@ -90,10 +90,10 @@ struct BatchedGemmKernel
|
||||
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, CDataType>::value,
|
||||
"C/ELayout and C/EDataType must be scalars.");
|
||||
"C/CLayout and C/EDataType must be scalars.");
|
||||
|
||||
struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<>
|
||||
{
|
||||
|
||||
@@ -89,7 +89,7 @@ struct GemmKernel
|
||||
/// @brief Specify the layout configurations for A, B, E and D
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, E and D
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
@@ -106,10 +106,10 @@ struct GemmKernel
|
||||
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, ELayout>::value &&
|
||||
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, EDataType>::value,
|
||||
"C/ELayout and C/EDataType must be scalars.");
|
||||
"C/CLayout and C/EDataType must be scalars.");
|
||||
|
||||
static constexpr index_t NumATensor = 1;
|
||||
static constexpr index_t NumBTensor = 1;
|
||||
|
||||
193
include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp
Normal file
193
include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp
Normal file
@@ -0,0 +1,193 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The MultiABD GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref GemmKernelMultiABD "GemmKernelMultiABD" when creating
|
||||
/// kernel arguments object. It contain all necessary information required to build proper
|
||||
/// kernel argument and launch kernel on GPU. This structure defines the GEMM problem
|
||||
/// configuration by stating all required information like M,N,K sizes and respective strides.
|
||||
/// NumATensor describes the number of A tensors. The minimum number of tensors is 1(required).
|
||||
/// NumBTensor describes the number of B tensors. The minimum number of tensors is 1(required).
|
||||
/// NumDTensor describes the number of D tensors. The minimum number of tensors is 0(not
|
||||
/// required).
|
||||
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
|
||||
struct GemmMultiABDHostArgs
|
||||
{
|
||||
CK_TILE_HOST GemmMultiABDHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
|
||||
const std::array<const void*, NumBTensor>& bs_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
const std::array<index_t, NumATensor>& stride_As_,
|
||||
const std::array<index_t, NumBTensor>& stride_Bs_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
: as_ptr(as_ptr_),
|
||||
bs_ptr(bs_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
e_ptr(e_ptr_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
stride_As(stride_As_),
|
||||
stride_Bs(stride_Bs_),
|
||||
stride_Ds(stride_Ds_),
|
||||
stride_E(stride_E_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const std::array<const void*, NumATensor> as_ptr;
|
||||
const std::array<const void*, NumBTensor> bs_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
union
|
||||
{
|
||||
void* e_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
const std::array<index_t, NumATensor> stride_As;
|
||||
const std::array<index_t, NumBTensor> stride_Bs;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
union
|
||||
{
|
||||
index_t stride_E;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GemmKernelMultiABD
|
||||
{
|
||||
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
||||
/// functions.
|
||||
using UniversalGemmKernel =
|
||||
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
|
||||
/// @brief Specify the layout configurations for A, B, E and D
|
||||
using AsLayout = remove_cvref_t<typename GemmPipeline::AsLayout>;
|
||||
using BsLayout = remove_cvref_t<typename GemmPipeline::BsLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, E and D
|
||||
using AsDataType = remove_cvref_t<typename GemmPipeline::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename GemmPipeline::BsDataType>;
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
|
||||
/// @brief ALayout and ADataType are expected to be a tuple, not a scalar.
|
||||
static_assert(is_detected<is_tuple, AsLayout>::value &&
|
||||
is_detected<is_tuple, AsDataType>::value,
|
||||
"ALayout and ADataType must be a tuple.");
|
||||
|
||||
/// @brief BLayout and BDataType are expected to be a tuple, not a scalar.
|
||||
static_assert(is_detected<is_tuple, BsLayout>::value &&
|
||||
is_detected<is_tuple, BsDataType>::value,
|
||||
"BLayout and BDataType must be a tuple.");
|
||||
|
||||
/// @brief CLayout and EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, EDataType>::value,
|
||||
"CLayout and EDataType must be a scalar.");
|
||||
|
||||
/// @brief DsLayout and DsDataType are expected to be tuple, not a scalar.
|
||||
static_assert(is_detected<is_tuple, DsLayout>::value &&
|
||||
is_detected<is_tuple, DsDataType>::value &&
|
||||
DsLayout::size() == DsDataType::size() && DsLayout::size() > 0,
|
||||
"DsLayout and DsDataType must be tuples and must have the same size.");
|
||||
|
||||
/// @brief The sizes of NumATensor, NumBTensor and NumDTensor is set by the user."
|
||||
static constexpr index_t NumATensor = AsDataType::size();
|
||||
static constexpr index_t NumBTensor = BsDataType::size();
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
CK_TILE_HOST static auto GetName() -> const std::string
|
||||
{
|
||||
return UniversalGemmKernel::GetName();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::GridSize(M, N, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::BlockSize();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
MakeKernelArgs(const GemmMultiABDHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs) ->
|
||||
typename UniversalGemmKernel::KernelArgs
|
||||
{
|
||||
/// @brief Universal GEMM requires array objects and corresponding stride information for
|
||||
/// matrices A, B, and D.
|
||||
return UniversalGemmKernel::MakeKernelArgs(
|
||||
UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>(hostArgs.as_ptr,
|
||||
hostArgs.bs_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.k_batch,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_As,
|
||||
hostArgs.stride_Bs,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto
|
||||
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
|
||||
{
|
||||
// Currently MultiABD kernel doesn't support k_batch > 1
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
|
||||
{
|
||||
UniversalGemmKernel{}.template operator()(kargs);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -95,7 +95,7 @@ struct GemmKernelMultiD
|
||||
/// @brief Specify the layout configurations for A, B, E and D
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, E and D
|
||||
@@ -114,10 +114,10 @@ struct GemmKernelMultiD
|
||||
!is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars.");
|
||||
|
||||
/// @brief ELayout and EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, ELayout>::value &&
|
||||
/// @brief CLayout and EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, EDataType>::value,
|
||||
"ELayout and EDataType must be scalars.");
|
||||
"CLayout and EDataType must be scalars.");
|
||||
|
||||
/// @brief DsLayout and DsDataType are expected to be tuple, not a scalar.
|
||||
static_assert(is_detected<is_tuple, DsLayout>::value &&
|
||||
|
||||
@@ -120,10 +120,10 @@ struct GroupedGemmKernel
|
||||
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, CDataType>::value,
|
||||
"C/ELayout and C/EDataType must be scalars.");
|
||||
"C/CLayout and C/EDataType must be scalars.");
|
||||
|
||||
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
|
||||
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
@@ -364,12 +364,8 @@ struct GroupedGemmKernel
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
|
||||
b_block_window[Base::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}.template
|
||||
|
||||
@@ -157,23 +157,23 @@ struct UniversalGemmKernel
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
|
||||
static constexpr bool ADataTypeIsTuple =
|
||||
is_detected<is_tuple, typename GemmPipeline::ADataType>::value;
|
||||
is_detected<is_tuple, typename GemmPipeline::AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple =
|
||||
is_detected<is_tuple, typename GemmPipeline::BDataType>::value;
|
||||
is_detected<is_tuple, typename GemmPipeline::BsDataType>::value;
|
||||
static constexpr bool DDataTypeIsTuple =
|
||||
is_detected<is_tuple, typename EpiloguePipeline::DsDataType>::value;
|
||||
static constexpr bool ALayoutIsTuple =
|
||||
is_detected<is_tuple, typename GemmPipeline::ALayout>::value;
|
||||
is_detected<is_tuple, typename GemmPipeline::AsLayout>::value;
|
||||
static constexpr bool BLayoutIsTuple =
|
||||
is_detected<is_tuple, typename GemmPipeline::BLayout>::value;
|
||||
is_detected<is_tuple, typename GemmPipeline::BsLayout>::value;
|
||||
static constexpr bool DLayoutIsTuple =
|
||||
is_detected<is_tuple, typename EpiloguePipeline::DsLayout>::value;
|
||||
|
||||
using AsLayout = std::conditional_t<ALayoutIsTuple,
|
||||
remove_cvref_t<typename GemmPipeline::ALayout>,
|
||||
remove_cvref_t<typename GemmPipeline::AsLayout>,
|
||||
remove_cvref_t<tuple<typename GemmPipeline::ALayout>>>;
|
||||
using BsLayout = std::conditional_t<BLayoutIsTuple,
|
||||
remove_cvref_t<typename GemmPipeline::BLayout>,
|
||||
remove_cvref_t<typename GemmPipeline::BsLayout>,
|
||||
remove_cvref_t<tuple<typename GemmPipeline::BLayout>>>;
|
||||
|
||||
using DsLayout = std::conditional_t<DLayoutIsTuple,
|
||||
@@ -181,11 +181,11 @@ struct UniversalGemmKernel
|
||||
remove_cvref_t<tuple<typename EpiloguePipeline::DsLayout>>>;
|
||||
|
||||
using AsDataType = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<typename GemmPipeline::ADataType>,
|
||||
remove_cvref_t<typename GemmPipeline::AsDataType>,
|
||||
remove_cvref_t<tuple<typename GemmPipeline::ADataType>>>;
|
||||
|
||||
using BsDataType = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<typename GemmPipeline::BDataType>,
|
||||
remove_cvref_t<typename GemmPipeline::BsDataType>,
|
||||
remove_cvref_t<tuple<typename GemmPipeline::BDataType>>>;
|
||||
|
||||
using DsDataType =
|
||||
@@ -193,9 +193,12 @@ struct UniversalGemmKernel
|
||||
remove_cvref_t<typename EpiloguePipeline::DsDataType>,
|
||||
remove_cvref_t<tuple<typename EpiloguePipeline::DsDataType>>>;
|
||||
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
// Get the persistent kernel if the pipeline has it available
|
||||
@@ -483,7 +486,7 @@ struct UniversalGemmKernel
|
||||
bool DTesnorIsValid = {true};
|
||||
static_for<0, NumDTensor, 1>{}([&](auto index) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if(std::is_same_v<DiLayout, ELayout> == false)
|
||||
if(std::is_same_v<DiLayout, CLayout> == false)
|
||||
{
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
@@ -529,7 +532,7 @@ struct UniversalGemmKernel
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
@@ -724,7 +727,7 @@ struct UniversalGemmKernel
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
@@ -818,7 +821,7 @@ struct UniversalGemmKernel
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
@@ -975,8 +978,8 @@ struct UniversalGemmKernel
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile =
|
||||
GemmPipeline{}(as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
@@ -1031,8 +1034,13 @@ struct UniversalGemmKernel
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}(
|
||||
as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
|
||||
AElementWise{},
|
||||
bs_block_window,
|
||||
BElementWise{},
|
||||
num_loop,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
@@ -11,12 +11,17 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy>
|
||||
struct GemmPipelineAgBgCrImplBase
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
@@ -57,6 +62,13 @@ struct GemmPipelineAgBgCrImplBase
|
||||
store_tile(lds_tile_window, block_tile_tmp);
|
||||
}
|
||||
|
||||
template <typename DstTileWindow, typename SrcBlockTile>
|
||||
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
|
||||
const SrcBlockTile& src_block_tile) const
|
||||
{
|
||||
store_tile(lds_tile_window, src_block_tile);
|
||||
}
|
||||
|
||||
template <typename DstBlockTile, typename SrcTileWindow, bool LoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
|
||||
const SrcTileWindow& lds_tile_window,
|
||||
@@ -88,23 +100,100 @@ struct GemmPipelineAgBgCrImplBase
|
||||
return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
|
||||
}
|
||||
|
||||
template <typename DramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
dram_block_window_tmp[number<idx>{}].get_window_origin() + offset,
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<DramBlockWindowTmp::size()>{});
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename DramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename DramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
dram_block_window_tmp[number<idx>{}].get_window_origin() + offset,
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<DramBlockWindowTmp::size()>{});
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename DramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const ALdsTensorView& a_lds_block_view,
|
||||
const ALdsLoadTileDistr&,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
a_dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_lds_shape = []() {
|
||||
@@ -138,16 +227,8 @@ struct GemmPipelineAgBgCrImplBase
|
||||
const BLdsLoadTileDistr&,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
|
||||
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
b_dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
// A DRAM tile window for load
|
||||
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
|
||||
|
||||
// TODO: Do we really need those two tile windows???
|
||||
// They're exactly same...
|
||||
|
||||
@@ -107,14 +107,23 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
@@ -386,17 +395,25 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -449,17 +466,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
@@ -470,45 +476,61 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
// global read 1
|
||||
|
||||
elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -520,38 +542,42 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -574,27 +600,26 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
@@ -602,13 +627,16 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
@@ -628,9 +656,13 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
* @note This is used by the persistent gemm kernel variants that don't determine
|
||||
* hot loop and tail number on the host side, e.g. grouped gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
@@ -639,7 +671,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
@@ -658,20 +690,97 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
* @note This is used by the kernel variants that are able to determine
|
||||
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Quant operator(), single input: This function runs the pipeline by wrapping it with
|
||||
* the tail handler.
|
||||
*
|
||||
* @note This is used by the persistent gemm kernel variants that don't determine
|
||||
* hot loop and tail number on the host side, e.g. grouped gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_number,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Quant operator(), single input: This function runs the pipeline using compile-time
|
||||
* known hot loop and tail number.
|
||||
* @param num_loop The number of loop iterations. This is determined at runtime due to e.g.
|
||||
* SplitK.
|
||||
* @note This is used by the kernel variants that are able to determine
|
||||
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -97,11 +97,24 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
using Base = BaseGemmPipelineAgBgCrCompV4<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
@@ -109,10 +122,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -244,18 +253,26 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -279,29 +296,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
////////////// global window & register /////////////////
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// A register tile for global load
|
||||
constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution();
|
||||
constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution();
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr));
|
||||
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr));
|
||||
ABlockTile a_global_load_tile;
|
||||
BBlockTile b_global_load_tile;
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
@@ -312,8 +306,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
|
||||
// global prefetch 0
|
||||
// global read 0
|
||||
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
////////////// LDS desc, window & register /////////////////
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
|
||||
@@ -343,34 +336,75 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// Generating a tuple with tile_windows for values A0, A1, ... AN
|
||||
auto a_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
// Generating a tuple with tile_windows for values B0, B1, ... BN
|
||||
auto b_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
// global read 1
|
||||
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
block_sync_lds();
|
||||
|
||||
constexpr auto ALdsTileDistr =
|
||||
@@ -423,27 +457,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
if(HasHotLoop)
|
||||
{
|
||||
@@ -461,31 +500,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window0, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window0, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
// gemm
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
HotLoopScheduler();
|
||||
@@ -501,32 +541,34 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window1, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window1, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
// gemm
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
HotLoopScheduler();
|
||||
@@ -548,23 +590,23 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
|
||||
}
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
}
|
||||
@@ -606,13 +648,17 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
public:
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0,
|
||||
@@ -628,27 +674,34 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
p_smem_1);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
@@ -658,7 +711,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
@@ -670,5 +723,69 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0,
|
||||
void* p_smem_1) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_number,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -41,15 +41,24 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
using Base = BaseGemmPipelineAgBgCrCompV5<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
@@ -121,17 +130,25 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -209,14 +226,16 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
BGemmTile b_tile_0, b_tile_1;
|
||||
|
||||
// Register tile for A and B.
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using ABlockTileDistr =
|
||||
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
using BBlockTileDistr =
|
||||
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
ABlockTile a_global_load_tile;
|
||||
BBlockTile b_global_load_tile;
|
||||
ABlockTile elementwise_As_res;
|
||||
BBlockTile elementwise_Bs_res;
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
@@ -248,33 +267,45 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
// define ping, pong steps here as lambda functions.
|
||||
auto MemoryOpsStep = [&](auto idx) {
|
||||
// Memory read half here.
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each
|
||||
// A0, A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each
|
||||
// B0, B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
if(idx == 0)
|
||||
@@ -351,13 +382,17 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
public:
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0) const
|
||||
@@ -371,21 +406,62 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
p_smem_0);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -157,14 +157,23 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
@@ -236,17 +245,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -310,8 +327,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using ABlockTileDistr =
|
||||
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
using BBlockTileDistr =
|
||||
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
@@ -334,10 +353,21 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
Base::GlobalPrefetch(
|
||||
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
@@ -348,32 +378,35 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}));
|
||||
}
|
||||
|
||||
// Global prefetch [1, PrefetchStages]
|
||||
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
a_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
b_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
});
|
||||
|
||||
// main body
|
||||
@@ -397,14 +430,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
transpose_tile2d(
|
||||
a_shuffle_tmp,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
a_element_func);
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
@@ -413,22 +445,23 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
transpose_tile2d(
|
||||
b_shuffle_tmp,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
b_element_func);
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
a_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
b_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
});
|
||||
|
||||
i += PrefetchStages;
|
||||
@@ -450,26 +483,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window,
|
||||
a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_element_func);
|
||||
a_block_tiles.get(number<prefetch_idx>{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window,
|
||||
b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_element_func);
|
||||
b_block_tiles.get(number<prefetch_idx>{}));
|
||||
}
|
||||
});
|
||||
|
||||
@@ -526,17 +557,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -600,8 +639,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using ABlockTileDistr =
|
||||
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
using BBlockTileDistr =
|
||||
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
@@ -623,10 +664,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
Base::GlobalPrefetch(
|
||||
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
@@ -637,32 +690,35 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}));
|
||||
}
|
||||
|
||||
// Global prefetch [1, PrefetchStages]
|
||||
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
a_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
b_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
});
|
||||
|
||||
// main body
|
||||
@@ -687,14 +743,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
transpose_tile2d(
|
||||
a_shuffle_tmp,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
a_element_func);
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
@@ -703,22 +758,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
transpose_tile2d(
|
||||
b_shuffle_tmp,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
b_element_func);
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
a_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
b_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
});
|
||||
|
||||
i += PrefetchStages;
|
||||
@@ -740,26 +797,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window,
|
||||
a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_element_func);
|
||||
a_block_tiles.get(number<prefetch_idx>{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window,
|
||||
b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_element_func);
|
||||
b_block_tiles.get(number<prefetch_idx>{}));
|
||||
}
|
||||
});
|
||||
|
||||
@@ -813,13 +868,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
@@ -833,9 +891,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
@@ -844,7 +906,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
@@ -856,20 +918,82 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_number,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -15,14 +15,23 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
@@ -81,17 +90,25 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
@@ -133,22 +150,30 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
auto as_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
auto bs_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<BsLayout::size()>{});
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window = make_tile_window(
|
||||
@@ -182,13 +207,22 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
{
|
||||
// move to 1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
@@ -198,13 +232,12 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write 0
|
||||
@@ -212,13 +245,12 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -226,8 +258,8 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// global read i + 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -237,22 +269,20 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, a_block_tile);
|
||||
store_tile(a_copy_lds_window,
|
||||
tile_elementwise_in(a_element_func, a_shuffle_tmp_loop));
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write i + 1
|
||||
@@ -260,14 +290,12 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, b_block_tile);
|
||||
store_tile(b_copy_lds_window,
|
||||
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
@@ -284,20 +312,40 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType & a) { return a; },
|
||||
[](auto& e, const ADataType & a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType & b) { return b; },
|
||||
[](auto& e, const BDataType & b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -15,30 +15,66 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV2
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Problem::VectorSizeA;
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Problem::VectorSizeB;
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegV2",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize));
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize));
|
||||
// clang-format on
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
@@ -56,17 +92,31 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
BPackedSize;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
@@ -98,32 +148,40 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
auto as_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
a_copy_dram_window.get_tile_distribution());
|
||||
as_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
auto bs_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<BsLayout::size()>{});
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
bs_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
@@ -153,28 +211,30 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
{
|
||||
// move to 1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
// global read 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
|
||||
// LDS write 0
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
// global read 1
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 2;
|
||||
@@ -189,20 +249,18 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
// global read i + 2
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
|
||||
// LDS write i + 1
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
// global read i + 2
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
iCounter--;
|
||||
|
||||
@@ -218,11 +276,9 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -241,12 +297,28 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType & a) { return a; },
|
||||
[](auto& e, const ADataType & a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType & b) { return b; },
|
||||
[](auto& e, const BDataType & b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -5,16 +5,19 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename EDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
typename ComputeDataType_ = AsDataType_,
|
||||
typename AElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
@@ -22,18 +25,49 @@ struct GemmPipelineProblemBase
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>; // actually AccDataType
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
|
||||
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Traits::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Traits::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
using AElementWise = remove_cvref_t<AElementWise_>;
|
||||
using BElementWise = remove_cvref_t<BElementWise_>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Traits::AsLayout>;
|
||||
using BsLayout = remove_cvref_t<typename Traits::BsLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr bool ComputeDataTypeIsTuple = is_detected<is_tuple, ComputeDataType_>::value;
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
|
||||
static constexpr bool ALayoutIsTuple = is_detected<is_tuple, AsLayout>::value;
|
||||
static constexpr bool BLayoutIsTuple = is_detected<is_tuple, BsLayout>::value;
|
||||
|
||||
using ComputeDataTypeTuple = std::conditional_t<ComputeDataTypeIsTuple,
|
||||
remove_cvref_t<ComputeDataType_>,
|
||||
remove_cvref_t<tuple<ComputeDataType_>>>;
|
||||
using AsLayoutTuple = std::
|
||||
conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
|
||||
using BsLayoutTuple = std::
|
||||
conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
|
||||
|
||||
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<AsDataType>,
|
||||
remove_cvref_t<tuple<AsDataType>>>;
|
||||
|
||||
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<BsDataType>,
|
||||
remove_cvref_t<tuple<BsDataType>>>;
|
||||
|
||||
using ComputeDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, ComputeDataTypeTuple>>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayoutTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayoutTuple>>;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
@@ -66,7 +100,7 @@ struct GemmPipelineProblemBase
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
if constexpr(std::is_same_v<AsLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
|
||||
@@ -84,7 +118,7 @@ struct GemmPipelineProblemBase
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<BsLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
|
||||
@@ -125,7 +159,7 @@ struct GemmPipelineProblemBase
|
||||
{
|
||||
return VectorSizeA_;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
else if constexpr(std::is_same_v<AsLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentA();
|
||||
}
|
||||
@@ -140,7 +174,7 @@ struct GemmPipelineProblemBase
|
||||
{
|
||||
return VectorSizeB_;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
else if constexpr(std::is_same_v<BsLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentB();
|
||||
}
|
||||
@@ -161,35 +195,40 @@ struct GemmPipelineProblemBase
|
||||
}();
|
||||
};
|
||||
|
||||
// Alias for GemmPipelineProblem
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename EDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
typename AElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename ComputeDataType_ = AsDataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
using GemmPipelineProblem = GemmPipelineProblemBase<AsDataType_,
|
||||
BsDataType_,
|
||||
EDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_,
|
||||
AElementWise_,
|
||||
BElementWise_,
|
||||
FixedVectorSize_,
|
||||
VectorSizeA_,
|
||||
VectorSizeB_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename EDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
typename AElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename ComputeDataType_ = AsDataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
@@ -197,18 +236,48 @@ struct UniversalGemmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>; // actually AccDataType
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
|
||||
using AElementWise = remove_cvref_t<AElementWise_>;
|
||||
using BElementWise = remove_cvref_t<BElementWise_>;
|
||||
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Traits::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Traits::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Traits::AsLayout>;
|
||||
using BsLayout = remove_cvref_t<typename Traits::BsLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr bool ComputeDataTypeIsTuple = is_detected<is_tuple, ComputeDataType_>::value;
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
|
||||
static constexpr bool ALayoutIsTuple = is_detected<is_tuple, AsLayout>::value;
|
||||
static constexpr bool BLayoutIsTuple = is_detected<is_tuple, BsLayout>::value;
|
||||
|
||||
using ComputeDataTypeTuple = std::conditional_t<ComputeDataTypeIsTuple,
|
||||
remove_cvref_t<ComputeDataType_>,
|
||||
remove_cvref_t<tuple<ComputeDataType_>>>;
|
||||
using AsLayoutTuple = std::
|
||||
conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
|
||||
using BsLayoutTuple = std::
|
||||
conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
|
||||
|
||||
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<AsDataType>,
|
||||
remove_cvref_t<tuple<AsDataType>>>;
|
||||
|
||||
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<BsDataType>,
|
||||
remove_cvref_t<tuple<BsDataType>>>;
|
||||
|
||||
using ComputeDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, ComputeDataTypeTuple>>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayoutTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayoutTuple>>;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
|
||||
@@ -356,11 +356,14 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
@@ -382,11 +385,14 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
@@ -482,8 +488,6 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
@@ -491,6 +495,8 @@ struct UniversalGemmBasePolicy
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using ALayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
|
||||
// Tile: MPerBlock X KPerBlock
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -518,8 +524,6 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
@@ -527,6 +531,8 @@ struct UniversalGemmBasePolicy
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
// Tile: KPerBlock X NPerBlock
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -554,7 +560,8 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ALayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
@@ -574,7 +581,8 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
|
||||
@@ -10,8 +10,8 @@ namespace ck_tile {
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename AsLayout_,
|
||||
typename BsLayout_,
|
||||
typename CLayout_,
|
||||
index_t NumWaveGroups_ = 1>
|
||||
struct TileGemmTraits
|
||||
@@ -23,9 +23,9 @@ struct TileGemmTraits
|
||||
// TODO this can't be hardcoded here! Should be in policy!
|
||||
static constexpr int _VectorSize = 16;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
using AsLayout = AsLayout_;
|
||||
using BsLayout = BsLayout_;
|
||||
using CLayout = CLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
@@ -36,8 +36,8 @@ template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool DoubleSmemBuffer_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename AsLayout_,
|
||||
typename BsLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false,
|
||||
@@ -52,9 +52,9 @@ struct TileGemmUniversalTraits
|
||||
static constexpr int _VectorSize = 16;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
using AsLayout = AsLayout_;
|
||||
using BsLayout = BsLayout_;
|
||||
using CLayout = CLayout_;
|
||||
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
|
||||
@@ -67,8 +67,8 @@ template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool DoubleSmemBuffer_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename AsLayout_,
|
||||
typename BsLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false>
|
||||
@@ -76,8 +76,8 @@ using PersistentTileGemmUniversalTraits = TileGemmUniversalTraits<kPadM_,
|
||||
kPadN_,
|
||||
kPadK_,
|
||||
DoubleSmemBuffer_,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
AsLayout_,
|
||||
BsLayout_,
|
||||
CLayout_,
|
||||
TransposeC_,
|
||||
UseStructuredSparsity_,
|
||||
|
||||
@@ -37,15 +37,24 @@ template <typename Problem, typename PipelinePolicy = UniversalWeightPreshuffleP
|
||||
struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockWeightPreshuffle =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockWeightPreshuffle<Problem>())>;
|
||||
@@ -188,7 +197,12 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
@@ -455,7 +469,33 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
[[maybe_unused]] const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
[[maybe_unused]] const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp[number<0>{}],
|
||||
[](const ADataType & a) { return a; },
|
||||
b_flat_dram_block_window_tmp[number<0>{}],
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
@@ -463,7 +503,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType & a) { return a; },
|
||||
[](auto& e, const ADataType & a) { e = a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
|
||||
@@ -53,14 +53,23 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
{
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockWeightPreshuffle =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockWeightPreshuffle<Problem>())>;
|
||||
@@ -502,7 +511,10 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
template <TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction>
|
||||
typename AElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
@@ -1001,8 +1013,37 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
// called from universal gemm kernel
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
[[maybe_unused]] const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
[[maybe_unused]] const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
return operator()<TailNum>(
|
||||
a_dram_block_window_tmp[number<0>{}],
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp[number<0>{}],
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
}
|
||||
|
||||
// called from general gemm kernel
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
@@ -1019,9 +1060,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
}
|
||||
|
||||
// called from grouped gemm kernel
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem_0,
|
||||
|
||||
@@ -44,6 +44,10 @@ struct TileGemmQuantTraits
|
||||
using AQLayout = AQLayout_;
|
||||
using BQLayout = BQLayout_;
|
||||
|
||||
// TODO: It should be replaced to single value
|
||||
using AsLayout = ALayout_;
|
||||
using BsLayout = BLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
|
||||
@@ -5,6 +5,7 @@ add_subdirectory(batched_gemm)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(grouped_gemm_preshuffle)
|
||||
add_subdirectory(gemm_multi_d)
|
||||
add_subdirectory(gemm_multi_abd)
|
||||
add_subdirectory(gemm_streamk)
|
||||
add_subdirectory(data_type)
|
||||
add_subdirectory(container)
|
||||
|
||||
12
test/ck_tile/gemm_multi_abd/CMakeLists.txt
Normal file
12
test/ck_tile/gemm_multi_abd/CMakeLists.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp)
|
||||
add_gtest_executable(test_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp)
|
||||
target_compile_definitions(test_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_definitions(test_gemm_multi_abd_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
40
test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp
Normal file
40
test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp
Normal file
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_multi_abd_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// Has cshuffle epilogue enabled
|
||||
// A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>,
|
||||
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes);
|
||||
|
||||
#include "test_gemm_multi_abd_ut_cases_cshuffle.inc"
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_multi_abd_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// Has cshuffle epilogue disabled
|
||||
// A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>,
|
||||
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>,
|
||||
std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes);
|
||||
|
||||
#include "test_gemm_multi_abd_ut_cases_default2d.inc"
|
||||
@@ -0,0 +1,211 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256)
|
||||
{
|
||||
constexpr int M = 768;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x256x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x512x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x256x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x256x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x768x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x1280x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x1280x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_768x512x256)
|
||||
{
|
||||
constexpr int M = 768;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_1280x512x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_1280x256x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x512x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x256x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x512x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x256x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x768x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x1280x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x1280x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_768x512x256)
|
||||
{
|
||||
constexpr int M = 768;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_1280x512x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_1280x256x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 256;
|
||||
constexpr int kBatch = 1;
|
||||
|
||||
EXPECT_EQ(this->Run(M, N, K, kBatch), true);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x512x512)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x512x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x256x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x512x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x256x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x768x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x1280x256)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x1280x256)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 1280;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_768x512x256)
|
||||
{
|
||||
constexpr int M = 768;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_1280x512x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_1280x256x256)
|
||||
{
|
||||
constexpr int M = 1280;
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
constexpr int kBatch = 2;
|
||||
|
||||
EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
|
||||
}
|
||||
New file: test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp (500 lines)
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
struct AddScale
|
||||
{
|
||||
template <typename E, typename A0, typename A1>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const A0& a0, const A1& a1) const
|
||||
{
|
||||
a = scale * (ck_tile::type_convert<float>(a0) + ck_tile::type_convert<float>(a1));
|
||||
}
|
||||
|
||||
float scale = 1.0;
|
||||
};
|
||||
|
||||
struct MultiplyMultiply
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
|
||||
{
|
||||
const float x0_f = ck_tile::type_convert<float>(c) * ck_tile::type_convert<float>(d0) *
|
||||
ck_tile::type_convert<float>(d1);
|
||||
|
||||
e = ck_tile::type_convert<E>(x0_f);
|
||||
}
|
||||
};
|
||||
|
||||
struct ElementWiseAddAdd
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
|
||||
{
|
||||
const float x0_f = ck_tile::type_convert<float>(c) + ck_tile::type_convert<float>(d0) +
|
||||
ck_tile::type_convert<float>(d1);
|
||||
|
||||
e = ck_tile::type_convert<E>(x0_f);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename A0DataType,
|
||||
typename B0DataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename D0DataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeTypeAB =
|
||||
std::conditional_t<sizeof(A0DataType) < sizeof(B0DataType), A0DataType, B0DataType>;
|
||||
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
|
||||
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmMultiABD : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using A0Layout = std::tuple_element_t<0, Tuple>;
|
||||
using A1Layout = std::tuple_element_t<1, Tuple>;
|
||||
using B0Layout = std::tuple_element_t<2, Tuple>;
|
||||
using B1Layout = std::tuple_element_t<3, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<4, Tuple>;
|
||||
using D1Layout = std::tuple_element_t<5, Tuple>;
|
||||
using ELayout = std::tuple_element_t<6, Tuple>;
|
||||
using A0DataType = std::tuple_element_t<7, Tuple>;
|
||||
using A1DataType = std::tuple_element_t<8, Tuple>;
|
||||
using B0DataType = std::tuple_element_t<9, Tuple>;
|
||||
using B1DataType = std::tuple_element_t<10, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<11, Tuple>;
|
||||
using D1DataType = std::tuple_element_t<12, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<13, Tuple>;
|
||||
using EDataType = std::tuple_element_t<14, Tuple>;
|
||||
using AElementWiseFn = std::tuple_element_t<15, Tuple>;
|
||||
using BElementWiseFn = std::tuple_element_t<16, Tuple>;
|
||||
using CDElementWiseFn = std::tuple_element_t<17, Tuple>;
|
||||
using UseCshuffleEpilog = std::tuple_element_t<18, Tuple>;
|
||||
|
||||
using AsLayout = ck_tile::tuple<A0Layout, A1Layout>;
|
||||
using AsDataType = ck_tile::tuple<A0DataType, A1DataType>;
|
||||
using BsLayout = ck_tile::tuple<B0Layout, B1Layout>;
|
||||
using BsDataType = ck_tile::tuple<B0DataType, B1DataType>;
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
|
||||
template <typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename CDElementWiseFn = ck_tile::element_wise::PassThrough>
|
||||
void invoke_gemm_multi_abd(const ck_tile::GemmMultiABDHostArgs<AsDataType::size(),
|
||||
BsDataType::size(),
|
||||
DsDataType::size()>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, AsLayout, BsLayout, ELayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
ELayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<AsDataType, BsDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
AElementWise,
|
||||
BElementWise>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDElementWiseFn,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
kPadM,
|
||||
kPadN,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
true,
|
||||
memory_operation>>;
|
||||
|
||||
using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDElementWiseFn,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using GemmEpilogue = std::
|
||||
conditional_t<UseCshuffleEpilog::value, CShuffleGemmEpilogue, DefaultGemmEpilogue>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
std::cout << "Run without SplitK" << std::endl;
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Run using SplitK" << std::endl;
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
public:
|
||||
bool Run(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const int k_batch,
|
||||
int StrideA0 = 0,
|
||||
int StrideA1 = 0,
|
||||
int StrideB0 = 0,
|
||||
int StrideB1 = 0,
|
||||
int StrideD0 = 0,
|
||||
int StrideD1 = 0,
|
||||
int StrideE = 0)
|
||||
{
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout),
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
StrideA0 = f_get_default_stride(M, K, StrideA0, A0Layout{});
|
||||
StrideA1 = f_get_default_stride(M, K, StrideA1, A1Layout{});
|
||||
|
||||
StrideB0 = f_get_default_stride(K, N, StrideB0, B0Layout{});
|
||||
StrideB1 = f_get_default_stride(K, N, StrideB1, B1Layout{});
|
||||
|
||||
StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{});
|
||||
StrideD1 = f_get_default_stride(M, N, StrideD1, D1Layout{});
|
||||
|
||||
StrideE = f_get_default_stride(M, N, StrideE, ELayout{});
|
||||
|
||||
ck_tile::HostTensor<A0DataType> a0_m_k_tesnor(
|
||||
f_host_tensor_descriptor(M, K, StrideA0, A0Layout{}));
|
||||
ck_tile::HostTensor<A1DataType> a1_m_k_tesnor(
|
||||
f_host_tensor_descriptor(M, K, StrideA1, A1Layout{}));
|
||||
|
||||
ck_tile::HostTensor<B0DataType> b0_k_n_tensors(
|
||||
f_host_tensor_descriptor(K, N, StrideB0, B0Layout{}));
|
||||
ck_tile::HostTensor<B1DataType> b1_k_n_tensors(
|
||||
f_host_tensor_descriptor(K, N, StrideB1, B1Layout{}));
|
||||
|
||||
ck_tile::HostTensor<D0DataType> d0_m_n_tensors(
|
||||
f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
|
||||
ck_tile::HostTensor<D1DataType> d1_m_n_tensors(
|
||||
f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
|
||||
|
||||
ck_tile::HostTensor<EDataType> e_m_n_device_result(
|
||||
f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
ck_tile::FillUniformDistribution<A0DataType>{-1.f, 1.f}(a0_m_k_tesnor);
|
||||
ck_tile::FillUniformDistribution<A0DataType>{-1.f, 1.f}(a1_m_k_tesnor);
|
||||
|
||||
ck_tile::FillUniformDistribution<B0DataType>{-1.f, 1.f}(b0_k_n_tensors);
|
||||
ck_tile::FillUniformDistribution<B1DataType>{-1.f, 1.f}(b1_k_n_tensors);
|
||||
|
||||
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n_tensors);
|
||||
ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n_tensors);
|
||||
|
||||
ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tesnor.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tesnor.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
|
||||
|
||||
a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data());
|
||||
a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data());
|
||||
|
||||
b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data());
|
||||
b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data());
|
||||
|
||||
d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data());
|
||||
d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data());
|
||||
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_device_result.SetZero();
|
||||
|
||||
std::array<const void*, DsDataType::size()> as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(),
|
||||
a1_m_k_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<const void*, DsDataType::size()> bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(),
|
||||
b1_k_n_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
|
||||
d1_m_n_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck_tile::index_t, AsDataType::size()> strideAs = {StrideA0, StrideA1};
|
||||
std::array<ck_tile::index_t, BsDataType::size()> strideBs = {StrideB0, StrideB1};
|
||||
std::array<ck_tile::index_t, DsDataType::size()> strideDs = {StrideD0, StrideD1};
|
||||
|
||||
ck_tile::GemmMultiABDHostArgs<AsDataType::size(), BsDataType::size(), DsDataType::size()>
|
||||
args({as_ptr_buf,
|
||||
bs_ptr_buf,
|
||||
ds_ptr_buf,
|
||||
e_m_n_dev_buf.GetDeviceBuffer(),
|
||||
k_batch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
strideAs,
|
||||
strideBs,
|
||||
strideDs,
|
||||
StrideE});
|
||||
|
||||
invoke_gemm_multi_abd<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AElementWiseFn,
|
||||
BElementWiseFn,
|
||||
CDElementWiseFn>(args, ck_tile::stream_config{nullptr, false});
|
||||
|
||||
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA0 =" << StrideA0 << " StrideA1 =" << StrideA1
|
||||
<< " StrideB0 =" << StrideB0 << " StrideB1 =" << StrideB1
|
||||
<< " StrideE =" << StrideE << " StrideD0 =" << StrideD0
|
||||
<< " StrideD1 =" << StrideD1 << std::endl;
|
||||
|
||||
e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
|
||||
bool pass = true;
|
||||
|
||||
ck_tile::HostTensor<A0DataType> a_m_k_host_ref_element_result(
|
||||
f_host_tensor_descriptor(M, K, StrideA0, A0Layout{}));
|
||||
ck_tile::HostTensor<B0DataType> b_k_n_host_ref_element_result(
|
||||
f_host_tensor_descriptor(K, N, StrideB0, B0Layout{}));
|
||||
ck_tile::HostTensor<EDataType> e_m_n_host_ref(
|
||||
f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
a_m_k_host_ref_element_result.SetZero();
|
||||
b_k_n_host_ref_element_result.SetZero();
|
||||
e_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm_multiple_abd<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AElementWiseFn,
|
||||
BElementWiseFn,
|
||||
CDElementWiseFn>({a0_m_k_tesnor, a1_m_k_tesnor},
|
||||
{b0_k_n_tensors, b1_k_n_tensors},
|
||||
{d0_m_n_tensors, d1_m_n_tensors},
|
||||
a_m_k_host_ref_element_result,
|
||||
b_k_n_host_ref_element_result,
|
||||
e_m_n_host_ref);
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<A0DataType, B0DataType, AccDataType, EDataType, D0DataType>(
|
||||
K, k_batch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(e_m_n_device_result,
|
||||
e_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user