[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder

* hacky poc

* fix tile engine

* kill the lambda

* maybe fix test build

* more fixes

* clang-format

* save temp

* clang-format

* mostly fix examples

* clang-format

* remove dead code

* more cleanup

* fix fmha bwd build (default epilogue set/add appears to be broken)

* fix default epilogue tests but not correctness

* clang-format

* fix bquant

* clang-format

* cleanup dead code

* rearrange make windows for readability

* restore changes to IsSupportedArgument

* fix smoke-builder

* clang-format

* fixup rename class

* build fixes

* clang-format

* fix builder

* fixup

* remove set from builder tests

* fix test

* clang-format

* re-refactor the kernels

* clang-format

* fix header license

* remove memory operation from conv bwd test

* clang-format

* clang-format example,include

* clang-format test

* build fixes

* clang-format

* solve compilation error

* fix the CI

* solve compilation error

* clang format

* solve merge conflict

* solve merge conflict

* solve the gfx11 error

* solve test error

* moar build fixes

* remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Max Podkorytov
2026-01-04 03:28:14 -08:00
committed by GitHub
parent ec23be0b9d
commit e339101e9c
68 changed files with 4198 additions and 4298 deletions

View File

@@ -99,62 +99,47 @@ class TestCkTileBatchedGemm : public ::testing::Test
scheduler>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
};
if(args.k_batch == 1)
if(!Kernel::IsSupportedArgument(kargs))
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
else
if(s.log_level_ > 0)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
ck_tile::ignore = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
public:

View File

@@ -120,8 +120,8 @@ using SimpleCShuffleEpilogueProblem =
MPerXdl,
NPerXdl,
KPerXdl,
false, // isCTransposed,
memory_operation_enum::set>;
false // isCTransposed
>;
template <typename Problem, index_t M, index_t N>
auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None)

View File

@@ -182,74 +182,58 @@ class TestCkTileGemmPipeline : public ::testing::Test
using GemmPipeline =
typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
1, /*kNumWaveGroups_*/
false, /*FixedVectorSize_*/
1, /*VectorSizeC_*/
false, /*TiledMMAPermuteN_*/
1, /*BlockedXDLN_PerWarp_*/
DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
1, /*kNumWaveGroups_*/
false, /*FixedVectorSize_*/
1, /*VectorSizeC_*/
false, /*TiledMMAPermuteN_*/
1, /*BlockedXDLN_PerWarp_*/
DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
<< grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
<< blocks.y << ", " << blocks.z << "}" << std::endl;
}
ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
};
if(args.k_batch == 1)
const dim3 blocks = Kernel::BlockSize();
dim3 grids;
if constexpr(Persistent)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y
<< ", " << blocks.z << "}" << std::endl;
}
ck_tile::ignore = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
public:

View File

@@ -356,8 +356,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set>>;
transpose_c>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,
@@ -641,7 +640,6 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
Base::N_Warp_Tile,
Base::K_Warp_Tile,
false, // transpose_c
ck_tile::memory_operation_enum::set,
1,
false,
1,
@@ -949,7 +947,6 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set,
1,
false,
1,
@@ -1174,8 +1171,7 @@ class TestCkTileGemmRowColQuant
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set>>;
transpose_c>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,
@@ -1389,8 +1385,7 @@ class TestCkTileGemmTensorQuant
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set>>;
transpose_c>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,

View File

@@ -186,88 +186,69 @@ class TestCkTileGemmMultiABD : public ::testing::Test
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<AsDataType,
BsDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDElementWiseFn,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
kPadM,
kPadN,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
true>>;
using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<AsDataType,
BsDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDElementWiseFn,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
kPadM,
kPadN,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
true,
memory_operation>>;
using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AsDataType,
BsDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDElementWiseFn,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AsDataType,
BsDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDElementWiseFn,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using GemmEpilogue =
std::conditional_t<UseCshuffleEpilog::value, CShuffleGemmEpilogue, DefaultGemmEpilogue>;
using GemmEpilogue = std::
conditional_t<UseCshuffleEpilog::value, CShuffleGemmEpilogue, DefaultGemmEpilogue>;
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
};
if(args.k_batch == 1)
if(!Kernel::IsSupportedArgument(kargs))
{
std::cout << "Run without SplitK" << std::endl;
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
else
if(s.log_level_ > 0)
{
std::cout << "Run using SplitK" << std::endl;
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
ck_tile::ignore = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
public:

View File

@@ -170,88 +170,69 @@ class TestCkTileGemmMultiD : public ::testing::Test
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
kPadM,
kPadN,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
true>>;
using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
kPadM,
kPadN,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
true,
memory_operation>>;
using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using GemmEpilogue =
std::conditional_t<UseCshuffleEpilog::value, CShuffleGemmEpilogue, DefaultGemmEpilogue>;
using GemmEpilogue = std::
conditional_t<UseCshuffleEpilog::value, CShuffleGemmEpilogue, DefaultGemmEpilogue>;
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
};
if(args.k_batch == 1)
if(!Kernel::IsSupportedArgument(kargs))
{
std::cout << "Run without SplitK" << std::endl;
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
else
if(s.log_level_ > 0)
{
std::cout << "Run using SplitK" << std::endl;
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
ck_tile::ignore = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
public:

View File

@@ -105,71 +105,60 @@ class TestCkTileStreamK : public ::testing::Test
NumWaveGroup,
preshuffle>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
// For initial testing, we will just test with one pipeline.
// More extensive testing is coming later and will test other pipelines.
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
// For initial testing, we will just test with one pipeline.
// More extensive testing is coming later and will test other pipelines.
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
auto kargs = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
if(!Kernel::IsSupportedArgument(kargs))
{
EXPECT_TRUE(false);
}
if(!Kernel::IsSupportedArgument(kargs))
{
EXPECT_TRUE(false);
}
dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner);
dim3 block_dims = Kernel::BlockSize();
dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner);
dim3 block_dims = Kernel::BlockSize();
ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grid_dims, block_dims, 0, kargs));
ck_tile::ignore = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grid_dims, block_dims, 0, kargs));
return kargs.tile_partitioner.estimate_num_wgs_per_tile();
};
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// Since we are doing stream K, in the case of
// atomics, multiple workgroups may write to the same
// output tile in the C tensor, so we must atomic add
// the results (not set)
ck_tile::memory_operation_enum::atomic_add>{});
return kargs.tile_partitioner.estimate_num_wgs_per_tile();
}
public:

View File

@@ -180,68 +180,52 @@ class TestCkTileGemmPipeline : public ::testing::Test
using GemmPipeline =
typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
<< grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
<< blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
};
if(args.k_batch == 1)
dim3 grids;
if constexpr(Persistent)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y
<< ", " << blocks.z << "}" << std::endl;
}
ck_tile::ignore = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
public:

View File

@@ -42,8 +42,7 @@ template <typename PrecType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
memory_operation_enum MemOp = memory_operation_enum::set,
index_t NDimSpatial = 2>
index_t NDimSpatial = 2>
struct BuildKernel
{
using GemmShape = TileGemmShape<
@@ -123,7 +122,6 @@ struct BuildKernel
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvTraits::FixedGemmParams::TransposeC,
MemOp,
ConvConfig::NumWaveGroups,
ConvTraits::FixedGemmParams::FixedVectorSize,
ConvTraits::VectorSizeC>;
@@ -212,26 +210,6 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, InvalidKBatchLessThanOne)
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs));
}
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreaterThanOne)
{
using Kernel = typename BuildKernel<half_t,
TestConvConfig,
tensor_layout::convolution::NHWGC,
tensor_layout::convolution::GKYXC,
tensor_layout::convolution::NHWGK,
memory_operation_enum::atomic_add>::type;
// k_batch = 1 should fail with atomic_add
auto host_args_kbatch_1 = create_2d_host_args(1);
auto kargs_1 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_1);
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_1));
// k_batch = 2 should pass
auto host_args_kbatch_2 = create_2d_host_args(2);
auto kargs_2 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_2);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2));
}
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation)
{
using Kernel = typename BuildKernel<half_t,

View File

@@ -109,78 +109,59 @@ class TestCkTileGroupedGemm : public ::testing::Test
scheduler>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName()
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
<< blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GroupedGemKernelParam::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
if(gemm_descs[0].k_batch == 1)
if(s.log_level_ > 0)
{
std::cout << "Run without SplitK" << std::endl;
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
std::cout << "Run using SplitK" << std::endl;
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ck_tile::ignore =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GroupedGemKernelParam::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
}
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
void* kargs_ptr)
{
constexpr bool TransposeC = false;
constexpr bool DoubleSmemBuffer = false;
@@ -212,50 +193,47 @@ class TestCkTileGroupedGemm : public ::testing::Test
CLayout,
TransposeC>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr auto memory_operation = memory_operation_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName()
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
<< blocks.z << "}" << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ck_tile::ignore =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
@@ -264,19 +242,6 @@ class TestCkTileGroupedGemm : public ::testing::Test
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
};
if(splitk)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
}
auto calculate_rtol_atol(const ck_tile::index_t K,
@@ -422,8 +387,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
{
// Generate kernel arguments
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
const bool splitk = gemm_descs[0].k_batch > 1;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
for(const auto& arg : gemm_descs)
{
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr},
@@ -448,10 +412,10 @@ class TestCkTileGroupedGemm : public ::testing::Test
stream.stream_id_));
#if CK_TILE_USE_WMMA
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Wmma, ALayout, BLayout, CLayout>(
stream, group_count, kargs_ptr, splitk);
stream, group_count, kargs_ptr);
#else
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
stream, group_count, kargs_ptr, splitk);
stream, group_count, kargs_ptr);
#endif
}
else

View File

@@ -96,7 +96,7 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test
const ck_tile::stream_config& s,
void* kargs_ptr)
{
EXPECT_TRUE(gemm_descs[0].k_batch == 1);
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<Config::M_Tile_, Config::N_Tile_, Config::K_Tile_>,
ck_tile::sequence<Config::M_Warp_, Config::N_Warp_, Config::K_Warp_>,
@@ -134,74 +134,56 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>,
ck_tile::GemmPipelineAgBgCrCompV4<UniversalGemmProblem>>>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
MultiplyMultiply,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Config::M_Warp_,
Config::N_Warp_,
Config::M_Warp_Tile_,
Config::N_Warp_Tile_,
Config::K_Warp_Tile_,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
MultiplyMultiply,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Config::M_Warp_,
Config::N_Warp_,
Config::M_Warp_Tile_,
Config::N_Warp_Tile_,
Config::K_Warp_Tile_,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName()
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
<< blocks.z << "}" << std::endl;
}
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<Config::BlockPerCu_>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
if(gemm_descs[0].k_batch == 1)
if(s.log_level_ > 0)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
// EXPECT TO FAIL because splitk is not supported
EXPECT_FALSE(true);
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
ck_tile::ignore =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<Config::BlockPerCu_>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
}
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
void* kargs_ptr)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<Config::M_Tile_, Config::N_Tile_, Config::K_Tile_>,
@@ -218,78 +200,58 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test
BLayout,
ELayout>;
float ave_time{0};
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
Config::Scheduler_>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmPipeline = std::conditional_t<
Config::Pipeline_ == (PipelineType::Memory),
ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>,
std::conditional_t<Config::Pipeline_ == (PipelineType::CompV3),
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>,
ck_tile::GemmPipelineAgBgCrCompV4<UniversalGemmProblem>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
MultiplyMultiply,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Config::M_Warp_,
Config::N_Warp_,
Config::M_Warp_Tile_,
Config::N_Warp_Tile_,
Config::K_Warp_Tile_,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
Config::Scheduler_>;
using GemmPipeline = std::conditional_t<
Config::Pipeline_ == (PipelineType::Memory),
ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>,
std::conditional_t<Config::Pipeline_ == (PipelineType::CompV3),
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>,
ck_tile::GemmPipelineAgBgCrCompV4<UniversalGemmProblem>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
MultiplyMultiply,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Config::M_Warp_,
Config::N_Warp_,
Config::M_Warp_Tile_,
Config::N_Warp_Tile_,
Config::K_Warp_Tile_,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName()
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
<< blocks.z << "}" << std::endl;
}
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<Config::BlockPerCu_>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
};
if(!splitk)
if(s.log_level_ > 0)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ck_tile::ignore =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<Config::BlockPerCu_>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
}
public:
@@ -445,8 +407,7 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test
if constexpr(Config::Persistent_)
{
std::vector<ck_tile::GemmTransKernelArg<DsDataType::size()>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
const bool splitk = gemm_descs[0].k_batch > 1;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
for(const auto& arg : gemm_descs)
{
kargs.emplace_back(
@@ -471,7 +432,7 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test
hipMemcpyHostToDevice,
stream.stream_id_));
invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr, splitk);
invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr);
}
else
{

View File

@@ -127,59 +127,44 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
using GemmPipeline =
ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
if(gemm_descs[0].k_batch == 1)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
// EXPECT TO FAIL because splitk is not supported
EXPECT_FALSE(true);
}
ck_tile::ignore =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
}
private:
@@ -226,59 +211,45 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
ck_tile::GemmPipelineScheduler::Default>;
using GemmPipeline =
ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(gemm_descs[0].k_batch == 1)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
// EXPECT TO FAIL because splitk is not supported
EXPECT_FALSE(true);
}
ck_tile::ignore =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
}
struct BShuffleGemmConfig

View File

@@ -148,10 +148,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
@@ -217,8 +216,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
QuantGemmProblem::TransposeC,
memory_operation>>;
QuantGemmProblem::TransposeC>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
@@ -287,99 +285,92 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
false,
PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantType,
AQLayout,
BQLayout,
TransposeC,
DoubleSmemBuffer,
Persistent>;
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
false,
PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantType,
AQLayout,
BQLayout,
TransposeC,
DoubleSmemBuffer,
Persistent>;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped ||
QuantType == ck_tile::QuantType::BQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
TransposeC>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
TransposeC,
BDataType,
scheduler>>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr auto memory_operation = memory_operation_.value;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using GemmPipeline = std::conditional_t<
UseGroupedQuant,
std::conditional_t<
QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
QuantGemmProblem::TransposeC>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped ||
QuantType == ck_tile::QuantType::BQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
TransposeC>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
TransposeC,
BDataType,
scheduler>>;
using GemmPipeline = std::conditional_t<
UseGroupedQuant,
std::conditional_t<
QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
QuantGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName()
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
<< blocks.z << "}" << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ck_tile::ignore =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GroupedGemKernelParam::kBlockPerCu>(
Kernel{},
@@ -388,10 +379,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
};
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
template <typename Layout>