mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Grouped GEMM tile loop (#2146)
* Add trait to use a persistent kernel and split the entrypoints in grouped gemm * Some helper functions for persistent kernel case * Get max occupancy grid using device properties * Implement tile loop in main entry point to grouped gemm * Enable GridSize() on device * Handle offset tile index using real current block index * Add persistent kernel choice to grouped gemm example * Use a for-loop for iterating over the group * Reduce VGPR spills by early-exit * Enable persistent kernel choice in grouped_gemm example * Add persistent kernel option to grouped_gemm test * Fix formatting with remod.py * Remove GridUpdateBlocks as blocks are now iteratively computed * Add comment about VGPR spilling * Fix formatting * Use CK_TILE_HOST instead of __host__ * Enable all Row/Col combinations in grouped gemm unit test * Add some KBatch=2 cases to grouped gemm tests * Fix SplitK for grouped gemm * Enable pipeline hotloop/tailnumber selection in-kernel for grouped gemm * Add type traits * Split examples to regular and tileloop * Formatting * Use hipExtStreamGetCUMask to get current active CUs for the given stream * Align test and example kernel config, and disable validation for splitk repeats * Remove debug options from CMakeLists.txt * Separate the code paths for persistent/non-persistent in test * Fix formatting * Address review comments --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
|
||||
|
||||
add_executable(tile_example_grouped_gemm_tileloop EXCLUDE_FROM_ALL grouped_gemm_tileloop.cpp)
|
||||
|
||||
@@ -16,15 +16,10 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_gemm.hpp"
|
||||
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_)
|
||||
void* kargs_ptr)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
@@ -114,70 +109,76 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
|
||||
gemm_descs.size()));
|
||||
return ave_time;
|
||||
};
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
@@ -317,4 +318,5 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
constexpr bool Persistent = false;
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example<Persistent>(argc, argv); }
|
||||
|
||||
@@ -70,14 +70,25 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.");
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs);
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_);
|
||||
void* kargs_ptr);
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk = false);
|
||||
|
||||
174
example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp
Normal file
174
example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp
Normal file
@@ -0,0 +1,174 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_gemm.hpp"
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 4;
|
||||
constexpr ck_tile::index_t N_Warp = 1;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
#endif
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = true;
|
||||
#endif
|
||||
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(!splitk)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
constexpr bool Persistent = true;
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example<Persistent>(argc, argv); }
|
||||
@@ -30,20 +30,60 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ALayout, typename BLayout, typename CLayout, bool Persistent>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int group_count,
|
||||
const std::vector<grouped_gemm_kargs>& args)
|
||||
{
|
||||
|
||||
// Workspace memory allocated to hold the gemm descriptions.
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(args));
|
||||
|
||||
float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
float ave_time = 0;
|
||||
if constexpr(!Persistent)
|
||||
{
|
||||
// Regular version of grouped gemm
|
||||
ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host. Instead, we can just pass the pointer
|
||||
// to the kernel and let the workgroups figure out which tiles to work on.
|
||||
// This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
std::vector<ck_tile::GemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = args[0].k_batch > 1;
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.c_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_C,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<ALayout, BLayout, CLayout>(
|
||||
stream, group_count, kargs_ptr, splitk);
|
||||
}
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
|
||||
@@ -66,7 +106,7 @@ float invoke_gemm(int n_warmup,
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <bool Persistent, typename ALayout, typename BLayout, typename CLayout>
|
||||
int run_grouped_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
@@ -87,6 +127,15 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
|
||||
if(kbatch > 1 && validate && warmup + repeat > 1)
|
||||
{
|
||||
std::cout << "WARNING: Data validation enabled with SplitK and more than"
|
||||
<< "1 warmup/repeat. Disabling validation." << std::endl;
|
||||
validate = false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
@@ -102,7 +151,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(256 + 64 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
@@ -150,8 +199,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
<< " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
@@ -169,13 +218,11 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
// TODO Add support for kbatch > 1 in grouped gemm
|
||||
static constexpr ck_tile::index_t k_batch = 1;
|
||||
gemm_descs.push_back(
|
||||
{p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
}
|
||||
|
||||
invoke_gemm<ALayout, BLayout, CLayout>(warmup, repeat, group_count, gemm_descs);
|
||||
invoke_gemm<ALayout, BLayout, CLayout, Persistent>(warmup, repeat, group_count, gemm_descs);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
@@ -183,7 +230,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
}
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("validate"))
|
||||
if(validate)
|
||||
{
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
@@ -194,7 +241,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value);
|
||||
const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
@@ -211,6 +258,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <bool Persistent>
|
||||
int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -227,12 +275,20 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
|
||||
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
// else if(a_layout == "R" && b_layout == "R")
|
||||
// {
|
||||
// return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
|
||||
// }
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
|
||||
@@ -127,4 +127,15 @@ struct is_any_of<CompareTo, FirstType, Rest...>
|
||||
{
|
||||
};
|
||||
|
||||
// Helper to check if a type is a specialization of a given template
|
||||
template <typename Test, template <typename...> class RefTemplate>
|
||||
struct is_specialization_of : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <template <typename...> class RefTemplate, typename... Args>
|
||||
struct is_specialization_of<RefTemplate<Args...>, RefTemplate> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
45
include/ck_tile/host/stream_utils.hpp
Normal file
45
include/ck_tile/host/stream_utils.hpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime_api.h>
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
static inline index_t get_available_compute_units(const stream_config& s)
|
||||
{
|
||||
constexpr static uint32_t MAX_MASK_DWORDS = 64;
|
||||
|
||||
// assume at most 64*32 = 2048 CUs
|
||||
uint32_t cu_mask[MAX_MASK_DWORDS]{};
|
||||
|
||||
auto count_set_bits = [](uint32_t dword) {
|
||||
index_t count = 0;
|
||||
while(dword != 0)
|
||||
{
|
||||
if(dword & 0x1)
|
||||
{
|
||||
count++;
|
||||
}
|
||||
dword = dword >> 1;
|
||||
}
|
||||
return count;
|
||||
};
|
||||
|
||||
HIP_CHECK_ERROR(hipExtStreamGetCUMask(s.stream_id_, MAX_MASK_DWORDS, &cu_mask[0]));
|
||||
|
||||
index_t num_cu = 0;
|
||||
for(uint32_t i = 0; i < MAX_MASK_DWORDS; i++)
|
||||
{
|
||||
num_cu += count_set_bits(cu_mask[i]);
|
||||
}
|
||||
|
||||
return num_cu;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -195,6 +195,22 @@ struct OffsettedTile1DPartitioner
|
||||
const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
|
||||
return make_tuple(iM, iN);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief The function subtracts the block's start (offset) from a given block index.
|
||||
* @param [in] block_start Workgroup offset.
|
||||
* @param [in] M Gemm's M dimension.
|
||||
* @param [in] N Gemm's N dimension.
|
||||
* @param [in] block_idx Current block index of the workgroup.
|
||||
* @return Returns a `tuple` [Im, In] with shifted index.
|
||||
*/
|
||||
[[nodiscard]] CK_TILE_DEVICE static auto
|
||||
GetOffsetedTileIndex(index_t block_start, index_t M, index_t N, index_t block_idx) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
{
|
||||
const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(block_idx - block_start);
|
||||
return make_tuple(iM, iN);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -230,7 +246,7 @@ struct GemmSpatiallyLocalTilePartitioner
|
||||
* @param N GEMM's N dimension.
|
||||
* @return index_t A total number of workgroups.
|
||||
*/
|
||||
CK_TILE_HOST static auto
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
|
||||
{
|
||||
const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
|
||||
|
||||
@@ -5,10 +5,15 @@
|
||||
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/literals.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct GemmTransKernelArg
|
||||
@@ -22,6 +27,8 @@ struct GemmTransKernelArg
|
||||
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
|
||||
{
|
||||
}
|
||||
|
||||
GemmTransKernelArg(GemmKernelArgs&& karg) : group_karg{karg}, block_start{0}, block_end{0} {}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
@@ -40,8 +47,10 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
|
||||
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -51,19 +60,42 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
|
||||
(UsePersistentKernel ? "Persistent" : "NonPersistent"));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
__host__ static auto GetWorkSpaceSize(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
CK_TILE_HOST static auto GetWorkSpaceSize(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
-> std::size_t
|
||||
{
|
||||
return gemm_descs.size() * sizeof(GemmTransKernelArg);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
|
||||
CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
|
||||
{
|
||||
return group_count * sizeof(GemmTransKernelArg);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto GridSize(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
|
||||
|
||||
/**
|
||||
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
|
||||
* @return The maximum occupancy grid size.
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
*/
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
|
||||
const auto kernel = kentry<KernelBlockSize, 1, Kernel, ConstantPointer, index_t>;
|
||||
int occupancy;
|
||||
HIP_CHECK_ERROR(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
|
||||
const int grid_size = get_available_compute_units(s) * occupancy;
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
{
|
||||
index_t grid_size = 0;
|
||||
for(const auto& it_desc : gemm_descs)
|
||||
@@ -121,39 +153,165 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return gemm_kernel_args_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const std::vector<GemmTransKernelArg>& kargs)
|
||||
{
|
||||
for(const auto& karg : kargs)
|
||||
{
|
||||
if(!Base::IsSupportedArgument(karg.group_karg))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const
|
||||
CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs,
|
||||
const tuple<index_t, index_t>& block_idx_2d,
|
||||
const index_t block_idx_z) const
|
||||
{
|
||||
const auto [iM, iN] = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
kargs.block_start, kargs.group_karg.M, kargs.group_karg.N);
|
||||
Run(kargs.group_karg, block_idx_2d, block_idx_z);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void Run(const GemmKernelArgs& kargs,
|
||||
const tuple<index_t, index_t>& block_idx_2d,
|
||||
const index_t block_idx_z) const
|
||||
{
|
||||
const auto [iM, iN] = block_idx_2d;
|
||||
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs.group_karg, blockIdx.z);
|
||||
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
|
||||
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.group_karg.a_ptr);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.group_karg.b_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.group_karg.c_ptr);
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
this->RunGemm(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr, kargs.group_karg, splitk_batch_offset, i_m, i_n);
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
RunGemmWithPipelineSelection(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
index_t group_count) const
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note The GEMM pipeline is selected in-kernel based on the number of K-loops
|
||||
* and the tail-number. This is needed for the persistent tile-loop when
|
||||
* we didn't have access to the K dimension on the host.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemmWithPipelineSelection(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GemmKernelArgs& kargs,
|
||||
const typename Base::SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
const index_t block_id = ck_tile::get_block_1d_id();
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
|
||||
cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(Base::I1);
|
||||
|
||||
// Get hot-loop and tail configuration
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunEpilogue = [&](auto& c_block_tile) {
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I2);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, smem_ptr_0);
|
||||
};
|
||||
|
||||
if constexpr(is_specialization_of<GemmPipeline, GemmPipelineAgBgCrCompV3>::value)
|
||||
{
|
||||
// Run the specific implementation with hotloop+tailnum config
|
||||
using PipelineImpl =
|
||||
typename GemmPipeline::template PipelineImpl<GemmPipeline::Scheduler>;
|
||||
const auto PassThrough = [](const auto& a) { return a; };
|
||||
if(has_hot_loop && tail_num == TailNumber::Full)
|
||||
{
|
||||
const auto& c_block_tile =
|
||||
PipelineImpl{}.template operator()<true, TailNumber::Full>(a_block_window,
|
||||
PassThrough,
|
||||
b_block_window,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
smem_ptr_0);
|
||||
RunEpilogue(c_block_tile);
|
||||
}
|
||||
else if(has_hot_loop && tail_num == TailNumber::Odd)
|
||||
{
|
||||
const auto& c_block_tile =
|
||||
PipelineImpl{}.template operator()<true, TailNumber::Odd>(a_block_window,
|
||||
PassThrough,
|
||||
b_block_window,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
smem_ptr_0);
|
||||
RunEpilogue(c_block_tile);
|
||||
}
|
||||
else if(has_hot_loop && tail_num == TailNumber::Even)
|
||||
{
|
||||
const auto& c_block_tile =
|
||||
PipelineImpl{}.template operator()<true, TailNumber::Even>(a_block_window,
|
||||
PassThrough,
|
||||
b_block_window,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
smem_ptr_0);
|
||||
RunEpilogue(c_block_tile);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ignore = a_block_window;
|
||||
ignore = b_block_window;
|
||||
static_assert(false, "GemmPipeline specialization not supported!");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr,
|
||||
index_t block_id,
|
||||
index_t group_count) const
|
||||
{
|
||||
index_t left = 0;
|
||||
index_t right = group_count;
|
||||
index_t group_id = index_t((left + right) >> 1);
|
||||
@@ -173,7 +331,61 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
group_id = index_t((left + right) >> 1);
|
||||
}
|
||||
|
||||
Run(gemm_desc_ptr[group_id]);
|
||||
return group_id;
|
||||
}
|
||||
|
||||
// For non-persistent kernels
|
||||
template <bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
|
||||
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
index_t group_count) const
|
||||
{
|
||||
const index_t block_id = ck_tile::get_block_1d_id();
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
|
||||
cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
|
||||
const auto& kargs = gemm_desc_ptr[group_id];
|
||||
const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
|
||||
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
0,
|
||||
kargs.group_karg.M,
|
||||
kargs.group_karg.N,
|
||||
(block_id - kargs.block_start) % grid_size_2d);
|
||||
Run(kargs, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
|
||||
}
|
||||
|
||||
// For persistent kernels
|
||||
template <bool U = UsePersistentKernel,
|
||||
typename = std::enable_if_t<U>,
|
||||
typename = void> // extra template parameter to avoid redefinition
|
||||
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
const index_t group_count) const
|
||||
{
|
||||
const index_t grid_size = ck_tile::get_grid_size();
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
|
||||
cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
|
||||
index_t cum_grid_size = 0;
|
||||
for(index_t group_id = 0; group_id < group_count; ++group_id)
|
||||
{
|
||||
const auto& kargs = gemm_desc_ptr[group_id].group_karg;
|
||||
const auto& k_batch = kargs.k_batch;
|
||||
const auto block_start = cum_grid_size;
|
||||
cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
|
||||
while(block_id < cum_grid_size)
|
||||
{
|
||||
const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
|
||||
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
|
||||
block_id = block_id + grid_size; // advance to next block
|
||||
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
|
||||
if(block_id >= cum_grid_size)
|
||||
{
|
||||
break; // exit the loop if all blocks are processed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -20,18 +20,19 @@ namespace ck_tile {
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAgBgCrCompV3
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(BlockHasHotloop(num_loop))
|
||||
{
|
||||
@@ -104,6 +105,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
using Base::PrefetchStages;
|
||||
using Base::UsePersistentKernel;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
|
||||
0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
Executable file → Normal file
0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
Executable file → Normal file
@@ -38,7 +38,8 @@ template <bool kPadM_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false>
|
||||
bool UseStructuredSparsity_ = false,
|
||||
bool UsePersistentKernel_ = false>
|
||||
struct TileGemmUniversalTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
@@ -53,6 +54,27 @@ struct TileGemmUniversalTraits
|
||||
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool DoubleSmemBuffer_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false>
|
||||
using PersistentTileGemmUniversalTraits = TileGemmUniversalTraits<kPadM_,
|
||||
kPadN_,
|
||||
kPadK_,
|
||||
DoubleSmemBuffer_,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
CLayout_,
|
||||
TransposeC_,
|
||||
UseStructuredSparsity_,
|
||||
true>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
@@ -8,19 +8,27 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
|
||||
// std::tuple< Row, Row, Row, F16, F16, F32, F16>,
|
||||
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
|
||||
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, False>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, False>,
|
||||
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, False>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, False>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
TYPED_TEST(TestCkTileGroupedGemm, Basic)
|
||||
{
|
||||
const int group_count = 8;
|
||||
const int kbatch = 1;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
@@ -14,12 +15,37 @@ TYPED_TEST(TestCkTileGroupedGemm, Basic)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(256 + 64 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_Cs.push_back(Ns[i]);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, group_count);
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGroupedGemm, SplitK)
|
||||
{
|
||||
const int group_count = 8;
|
||||
const int kbatch = 2;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
std::vector<int> stride_As;
|
||||
std::vector<int> stride_Bs;
|
||||
std::vector<int> stride_Cs;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_Cs.push_back(Ns[i]);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
|
||||
}
|
||||
|
||||
@@ -24,6 +24,10 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
// Get the persistent value from ck_tile::bool_constant
|
||||
using PersistentType = std::tuple_element_t<7, Tuple>;
|
||||
static constexpr bool Persistent = PersistentType::value;
|
||||
|
||||
struct GroupedGemKernelParam
|
||||
{
|
||||
static const bool kPadM = false;
|
||||
@@ -31,9 +35,9 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
static const bool kPadK = false;
|
||||
|
||||
static const int kBlockPerCu = 1;
|
||||
static const ck_tile::index_t M_Tile = 128;
|
||||
static const ck_tile::index_t N_Tile = 128;
|
||||
static const ck_tile::index_t K_Tile = 32;
|
||||
static const ck_tile::index_t M_Tile = 256;
|
||||
static const ck_tile::index_t N_Tile = 256;
|
||||
static const ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static const ck_tile::index_t M_Warp = 2;
|
||||
static const ck_tile::index_t N_Warp = 2;
|
||||
@@ -41,7 +45,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 32;
|
||||
static const ck_tile::index_t N_Warp_Tile = 32;
|
||||
static const ck_tile::index_t K_Warp_Tile = 8;
|
||||
static const ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
|
||||
@@ -53,7 +57,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_)
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
constexpr bool TransposeC = false;
|
||||
@@ -138,11 +142,12 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
|
||||
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_,
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
@@ -163,7 +168,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
return ave_time;
|
||||
};
|
||||
@@ -171,6 +176,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
{
|
||||
std::cout << "Run without SplitK" << std::endl;
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
@@ -178,6 +184,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Run using SplitK" << std::endl;
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
@@ -213,6 +220,135 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
{
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
|
||||
GroupedGemKernelParam::N_Tile,
|
||||
GroupedGemKernelParam::K_Tile>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::K_Warp>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::PersistentTileGemmUniversalTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
};
|
||||
|
||||
if(splitk)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
}
|
||||
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
@@ -220,6 +356,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
std::vector<int>& stride_As,
|
||||
std::vector<int>& stride_Bs,
|
||||
std::vector<int>& stride_Cs,
|
||||
const int kbatch = 1,
|
||||
const int group_count = 16)
|
||||
{
|
||||
using namespace ck_tile::literals;
|
||||
@@ -294,10 +431,10 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
std::cout << "gemm[" << i << "]"
|
||||
<< " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << " KBatch: " << kbatch << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
@@ -315,18 +452,51 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
// TODO add support for kbatch > 1
|
||||
static constexpr ck_tile::index_t k_batch = 1;
|
||||
gemm_descs.push_back(
|
||||
{p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
|
||||
|
||||
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
gemm_descs, ck_tile::stream_config{nullptr, false}, gemm_workspace.GetDeviceBuffer());
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
// Generate kernel arguments
|
||||
std::vector<ck_tile::GemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = gemm_descs[0].k_batch > 1;
|
||||
for(const auto& arg : gemm_descs)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.c_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_C,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, false, 1};
|
||||
ck_tile::hip_check_error(
|
||||
hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
invoke_grouped_gemm_persistent<ALayout, BLayout, CLayout>(
|
||||
stream, group_count, kargs_ptr, splitk);
|
||||
}
|
||||
else
|
||||
{
|
||||
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
gemm_descs,
|
||||
ck_tile::stream_config{nullptr, false, 1},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
// Copy results back to host for validation
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
|
||||
@@ -340,7 +510,14 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
c_m_n_host_ref.SetZero();
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user