[CK_TILE] Switch to universal gemm for batched and grouped gemms (#1919)

* switch to universal gemm for batched and grouped gemms

* added reviewer comments

* fixed grouped gemm tests

[ROCm/composable_kernel commit: 0e91d32c61]
This commit is contained in:
jakpiase
2025-03-20 11:17:04 +01:00
committed by GitHub
parent 8be61cfc9d
commit 9bf1c41338
13 changed files with 853 additions and 359 deletions

View File

@@ -18,16 +18,42 @@
template <typename ALayout, typename BLayout, typename CLayout>
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
// This part comes from the Codegen
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 1;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr bool DoubleSmemBuffer = false;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
@@ -36,61 +62,232 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr ck_tile::index_t K_Warp_Tile = 16;
using CodegenGemmShape =
constexpr bool DoubleSmemBuffer = true;
#endif
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
auto kargs = Kernel::MakeKernelArgs(args);
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time{0};
if(!Kernel::IsSupportedArgument(kargs))
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
if(has_hot_loop)
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
std::ostringstream err;
err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got "
<< tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if(s.log_level_ > 0)
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{
if(tail_num == ck_tile::TailNumber::Four)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{
if(tail_num == ck_tile::TailNumber::Six)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
#endif
}
else
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
std::ostringstream err;
err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but "
"got "
<< tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -9,6 +9,30 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
template <typename DataType>
struct BatchedGemmTypeConfig;
@@ -32,19 +56,19 @@ using CDataType = Types::CDataType;
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "256", "m dimension")
.insert("n", "128", "n dimension")
.insert("k", "128", "k dimension")
arg_parser.insert("m", "512", "m dimension")
.insert("n", "1024", "n dimension")
.insert("k", "2048", "k dimension")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("batch_stride_a", "32768", "Batch A stride")
.insert("batch_stride_b", "16384", "Batch B stride")
.insert("batch_stride_c", "32768", "Batch C stride")
.insert("batch_count", "16", "Batch count")
.insert("batch_stride_a", "1048576", "Batch A stride")
.insert("batch_stride_b", "2097152", "Batch B stride")
.insert("batch_stride_c", "524288", "Batch C stride")
.insert("batch_count", "8", "Batch count")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")

View File

@@ -185,7 +185,6 @@ int run_batched_gemm_example_with_layouts(int argc,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;

View File

@@ -16,85 +16,9 @@
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
namespace {
struct GroupedGemmKernelParam
{
static const bool kPadM = false;
static const bool kPadN = false;
static const bool kPadK = false;
static const int kBlockPerCu = 1;
static const ck_tile::index_t M_Tile = 128;
static const ck_tile::index_t N_Tile = 128;
static const ck_tile::index_t K_Tile = 32;
static const ck_tile::index_t M_Warp = 2;
static const ck_tile::index_t N_Warp = 2;
static const ck_tile::index_t K_Warp = 1;
static const ck_tile::index_t M_Warp_Tile = 32;
static const ck_tile::index_t N_Warp_Tile = 32;
static const ck_tile::index_t K_Warp_Tile = 8;
};
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemmKernelParam::M_Tile,
GroupedGemmKernelParam::N_Tile,
GroupedGemmKernelParam::K_Tile>,
ck_tile::sequence<GroupedGemmKernelParam::M_Warp,
GroupedGemmKernelParam::N_Warp,
GroupedGemmKernelParam::K_Warp>,
ck_tile::sequence<GroupedGemmKernelParam::M_Warp_Tile,
GroupedGemmKernelParam::N_Warp_Tile,
GroupedGemmKernelParam::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemmKernelParam::kPadM,
GroupedGemmKernelParam::kPadN,
GroupedGemmKernelParam::kPadK,
ALayout,
BLayout,
CLayout>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits<ALayout, BLayout, CLayout>>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
template <typename ALayout, typename BLayout, typename CLayout>
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem<ALayout, BLayout, CLayout>::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemmKernelParam::M_Warp,
GroupedGemmKernelParam::N_Warp,
GroupedGemmKernelParam::M_Warp_Tile,
GroupedGemmKernelParam::N_Warp_Tile,
GroupedGemmKernelParam::K_Warp_Tile,
CodegenPipelineProblem<ALayout, BLayout, CLayout>::TransposeC>>;
template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
CodegenGemmPipeline<ALayout, BLayout, CLayout>,
GemmEpilogue<ALayout, BLayout, CLayout>>;
}; // namespace
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return ::Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
}
template <typename ALayout, typename BLayout, typename CLayout>
@@ -102,37 +26,265 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_)
{
using GroupedGemmKernel = ::Kernel<ALayout, BLayout, CLayout>;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs);
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 1;
constexpr ck_tile::index_t K_Warp = 1;
const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs);
constexpr dim3 blocks = GroupedGemmKernel::BlockSize();
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
ck_tile::hip_check_error(hipMemcpyWithStream(
p_workspace_,
arguments.data(),
arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg),
hipMemcpyHostToDevice,
s.stream_id_));
constexpr bool DoubleSmemBuffer = false;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
if(s.log_level_ > 0)
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = true;
#endif
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
const dim3 grids = Kernel::GridSize(gemm_descs);
constexpr dim3 blocks = Kernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
gemm_descs.size()));
return ave_time;
};
if(has_hot_loop)
{
std::cout << "Launching kernel: " << GroupedGemmKernel::GetName() << " with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
std::ostringstream err;
err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got "
<< tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{
if(tail_num == ck_tile::TailNumber::Four)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{
if(tail_num == ck_tile::TailNumber::Six)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
if(tail_num == ck_tile::TailNumber::Three)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
#endif
}
else
{
std::ostringstream err;
err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but "
<< "got " << tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
float ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GroupedGemmKernelParam::kBlockPerCu>(
GroupedGemmKernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
gemm_descs.size()));
return ave_time;
}

View File

@@ -9,6 +9,30 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
template <typename DataType>
struct GemmTypeConfig;
@@ -29,7 +53,7 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
auto create_args(int argc, char* argv[])
{
@@ -46,7 +70,7 @@ auto create_args(int argc, char* argv[])
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "16", "group count.");
.insert("group_count", "8", "group count.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -101,8 +101,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(256 + 64 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
@@ -169,7 +169,10 @@ int run_grouped_gemm_example_with_layouts(int argc,
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
// TODO Add support for kbatch > 1 in grouped gemm
static constexpr ck_tile::index_t k_batch = 1;
gemm_descs.push_back(
{p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
}
invoke_gemm<ALayout, BLayout, CLayout>(warmup, repeat, group_count, gemm_descs);

View File

@@ -46,7 +46,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
{
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using GemmKernelArgs = typename Base::GemmKernelArgs;
using GemmKernelArgs = typename ck_tile::GemmKernelArgs;
using ADataType = typename Base::ADataType;
using BDataType = typename Base::BDataType;
@@ -65,7 +65,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using P_ = GemmPipeline;
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on

View File

@@ -56,6 +56,20 @@ struct GemmHostArgs : public GemmProblem
index_t k_batch;
};
struct GemmKernelArgs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t k_batch;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
@@ -90,20 +104,6 @@ struct GemmKernel
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
struct GemmKernelArgs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t k_batch;
};
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
{
return GemmKernelArgs{hostArgs.a_ptr,

View File

@@ -11,24 +11,17 @@
namespace ck_tile {
struct GroupedGemmHostArgs : public ck_tile::GemmHostArgs
struct GemmTransKernelArg
{
CK_TILE_HOST GroupedGemmHostArgs() noexcept = default;
CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_)
: GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, KBatch, M_, N_, K_, stride_A_, stride_B_, stride_C_)
GemmKernelArgs group_karg;
ck_tile::index_t block_start;
ck_tile::index_t block_end;
GemmTransKernelArg() = default;
GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
{
}
private:
static constexpr index_t KBatch = 1;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
@@ -47,36 +40,22 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using GemmKernelArgs = typename Base::GemmKernelArgs;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
struct GemmTransKernelArg
{
GemmKernelArgs group_karg;
ck_tile::index_t block_start;
ck_tile::index_t block_end;
GemmTransKernelArg() = default;
GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
{
}
};
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}
__host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
__host__ static auto GetWorkSpaceSize(const std::vector<GemmHostArgs>& gemm_descs)
-> std::size_t
{
return gemm_descs.size() * sizeof(GemmTransKernelArg);
@@ -84,7 +63,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
__host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
__host__ static constexpr auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
__host__ static constexpr auto GridSize(const std::vector<GemmHostArgs>& gemm_descs)
{
index_t grid_size = 0;
for(const auto& it_desc : gemm_descs)
@@ -95,7 +74,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
CK_TILE_HOST static auto MakeKargs(const std::vector<GemmHostArgs>& gemm_descs)
-> std::vector<GemmTransKernelArg>
{
std::vector<GemmTransKernelArg> gemm_kernel_args_;

View File

@@ -3,7 +3,7 @@
TYPED_TEST(TestCkTileBatchedGemm, Basic)
{
constexpr int M = 256;
constexpr int N = 128;
constexpr int K = 128;
constexpr int N = 256;
constexpr int K = 512;
this->Run(M, N, K);
}

View File

@@ -28,17 +28,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args,
const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
@@ -46,72 +38,144 @@ class TestCkTileBatchedGemm : public ::testing::Test
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr ck_tile::index_t K_Warp_Tile = 16;
using CodegenGemmShape =
constexpr bool DoubleSmemBuffer = false;
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
float ave_time{0};
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
CodegenGemmPipeline::BlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;
using Kernel =
ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
auto kargs = Kernel::MakeKernelArgs(args);
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
if(s.log_level_ > 0)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
if(has_hot_loop)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "For compute pipeline tail number should always be Full, but have \""
<< tail_num << "\" which is not supported! PrefetchStages: "
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
else
{
std::ostringstream err;
err << "Num K loop must be larger than number of prefetech stages."
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
public:
void Run(const int M,
const int N,
const int K,
int StrideA = 128,
int StrideB = 128,
int StrideC = 128,
const int BatchStrideA = 32768,
const int BatchStrideB = 16384,
const int BatchStrideC = 32768,
const int BatchCount = 16)
int StrideA = 512,
int StrideB = 512,
int StrideC = 256,
const int BatchStrideA = 131072,
const int BatchStrideB = 131072,
const int BatchStrideC = 65536,
const int BatchCount = 8)
{
using namespace ck_tile::literals;

View File

@@ -2,7 +2,7 @@
TYPED_TEST(TestCkTileGroupedGemm, Basic)
{
const int group_count = 16;
const int group_count = 8;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
@@ -13,8 +13,8 @@ TYPED_TEST(TestCkTileGroupedGemm, Basic)
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(256 + 64 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);

View File

@@ -44,65 +44,10 @@ class TestCkTileGroupedGemm : public ::testing::Test
static const ck_tile::index_t K_Warp_Tile = 8;
};
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
GroupedGemKernelParam::N_Tile,
GroupedGemKernelParam::K_Tile>,
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::K_Warp>,
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
ALayout,
BLayout,
CLayout>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits<ALayout, BLayout, CLayout>>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
template <typename ALayout, typename BLayout, typename CLayout>
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
CodegenGemmPipeline<ALayout, BLayout, CLayout>::BlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
CodegenPipelineProblem<ALayout, BLayout, CLayout>::TransposeC>>;
template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
CodegenGemmPipeline<ALayout, BLayout, CLayout>,
GemmEpilogue<ALayout, BLayout, CLayout>>;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
}
template <typename ALayout, typename BLayout, typename CLayout>
@@ -110,35 +55,140 @@ class TestCkTileGroupedGemm : public ::testing::Test
const ck_tile::stream_config& s,
void* p_workspace_)
{
using GroupedGemmKernel = Kernel<ALayout, BLayout, CLayout>;
constexpr bool DoubleSmemBuffer = false;
constexpr bool TransposeC = false;
auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs);
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs);
constexpr dim3 blocks = GroupedGemmKernel::BlockSize();
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
GroupedGemKernelParam::N_Tile,
GroupedGemKernelParam::K_Tile>,
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::K_Warp>,
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
ck_tile::hip_check_error(hipMemcpyWithStream(
p_workspace_,
arguments.data(),
arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg),
hipMemcpyHostToDevice,
s.stream_id_));
using Traits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
ALayout,
BLayout,
CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
if(s.log_level_ > 0)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile;
const ck_tile::index_t K_split =
(gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
const dim3 grids = Kernel::GridSize(gemm_descs);
constexpr dim3 blocks = Kernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<blocks.x, GroupedGemKernelParam::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
gemm_descs.size()));
return ave_time;
};
if(has_hot_loop)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "For compute pipeline tail number should always be Full, but have \""
<< tail_num << "\" which is not supported! PrefetchStages: "
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
else
{
std::ostringstream err;
err << "Num K loop must be larger than number of prefetech stages."
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GroupedGemKernelParam::kBlockPerCu>(
GroupedGemmKernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
gemm_descs.size()));
}
public:
@@ -243,12 +293,14 @@ class TestCkTileGroupedGemm : public ::testing::Test
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
// TODO add support for kbatch > 1
static constexpr ck_tile::index_t k_batch = 1;
gemm_descs.push_back(
{p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
{p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
}
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(gemm_descs));
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
gemm_descs, ck_tile::stream_config{nullptr, false}, gemm_workspace.GetDeviceBuffer());