[CK_TILE] Introduces a new GEMM API that splits the existing basic GEMM class into multiple specialized classes. (#2520)

* Init commit new API

* apply clang-format

* PreShuffle preapring

* Apply Preshuffle condition to universal_gemm

* Fix: convert size_t to index_t

* Review changes

* Mode 100755 -> 100644

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

[ROCm/composable_kernel commit: b507d889c1]
This commit is contained in:
Mateusz Ozga
2025-07-24 20:39:56 +02:00
committed by GitHub
parent 523bfd1f91
commit 0c0fd440ca
28 changed files with 2094 additions and 1519 deletions

View File

@@ -24,7 +24,7 @@ template <typename GemmConfig,
typename CLayout,
bool Persistent,
typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if constexpr(Persistent)

View File

@@ -475,4 +475,4 @@ template <typename ADataType,
typename CLayout,
bool Persistent = false,
typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s);
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -25,7 +25,7 @@ template <typename GemmConfig,
typename ELayout,
bool Persistent,
typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
@@ -74,119 +74,120 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
UniversalGemmProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
UniversalGemmProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)

View File

@@ -158,7 +158,7 @@ template <typename GemmConfig,
typename CLayout,
bool Persistent,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s);
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
template <typename GemmConfig,
typename ADataType,
@@ -185,18 +185,16 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_repeat,
bool persistent)
{
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args = {a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
{},
c_m_n_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
stride_C};
float ave_time;
if(persistent)

View File

@@ -25,7 +25,7 @@ template <typename GemmConfig,
typename ELayout,
bool Persistent,
typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
@@ -74,120 +74,121 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
UniversalGemmProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
UniversalGemmProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)

View File

@@ -50,21 +50,20 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
ck_tile::BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_E = stride_C;
args.batch_stride_A = batch_stride_A;
args.batch_stride_B = batch_stride_B;
args.batch_stride_E = batch_stride_C;
args.batch_count = batch_count;
ck_tile::BatchedGemmHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count};
float ave_time = batched_gemm<ADataType,
BDataType,

View File

@@ -54,7 +54,7 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
using grouped_gemm_kargs = ck_tile::GemmHostArgs</*NumDTensor = 0*/>;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
{

View File

@@ -83,18 +83,18 @@ float invoke_gemm(int n_warmup,
const bool splitk = args[0].k_batch > 1;
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr,
arg.b_ptr,
{},
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.stride_A,
arg.stride_B,
{},
arg.stride_E,
arg.k_batch});
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr},
{arg.b_ptr},
{/*arg.ds_ptr*/},
arg.e_ptr,
arg.M,
arg.N,
arg.K,
{arg.stride_A},
{arg.stride_B},
{/*arg.stride_Ds*/},
arg.stride_E,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
@@ -240,7 +240,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back(
{p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]});
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
}
invoke_gemm<ADataType,

View File

@@ -157,7 +157,7 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);

View File

@@ -64,7 +64,7 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
using gemm_multi_d_kargs = ck_tile::GemmHostArgs<DsDataType::size()>;
using gemm_multi_d_kargs = ck_tile::GemmMultiDHostArgs<DsDataType::size()>;
template <typename ADataType,
typename BDataType,

View File

@@ -262,6 +262,8 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
return flag;
}
CK_TILE_HOST_DEVICE static constexpr bool IsTuple() { return true; }
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
// clang-format off
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const & { TP_COM_(); return impl::getv<I>(*this); }

View File

@@ -28,6 +28,8 @@
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"

View File

@@ -9,35 +9,41 @@
namespace ck_tile {
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs</*NumDTensor = 0*/>
/// @brief The Batched GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref BatchedGemmKernel "BatchedGemmKernel" when creating kernel
/// arguments object. It contain all necessary information required to build proper kernel
/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by
/// stating all required information like M,N,K sizes and respective strides.
struct BatchedGemmHostArgs : public ck_tile::UniversalGemmHostArgs<>
{
CK_TILE_HOST BatchedGemmHostArgs() = default;
CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
ck_tile::index_t k_batch_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_,
ck_tile::index_t batch_stride_A_,
ck_tile::index_t batch_stride_B_,
ck_tile::index_t batch_stride_C_,
ck_tile::index_t batch_count_)
: GemmHostArgs(a_ptr_,
b_ptr_,
{},
c_ptr_,
k_batch_,
M_,
N_,
K_,
stride_A_,
stride_B_,
{},
stride_C_),
CK_TILE_HOST explicit BatchedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
ck_tile::index_t k_batch_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_,
ck_tile::index_t batch_stride_A_,
ck_tile::index_t batch_stride_B_,
ck_tile::index_t batch_stride_C_,
ck_tile::index_t batch_count_)
: UniversalGemmHostArgs<>({a_ptr_},
{b_ptr_},
{/*ds_ptr*/},
c_ptr_,
k_batch_,
M_,
N_,
K_,
{stride_A_},
{stride_B_},
{/*stride_Ds_*/},
stride_C_),
batch_stride_A(batch_stride_A_),
batch_stride_B(batch_stride_B_),
batch_stride_E(batch_stride_C_),
@@ -52,36 +58,43 @@ struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs</*NumDTensor = 0*/>
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
struct BatchedGemmKernel
{
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
/// functions.
using UniversalGemmKernel =
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using GemmKernelArgs = typename ck_tile::GemmKernelArgs<>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ADataType = typename Base::ADataType;
using BDataType = typename Base::BDataType;
using CDataType = typename Base::EDataType;
/// @brief Specify the layout configurations for A, B, E and D
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using TilePartitioner = typename Base::TilePartitioner;
using GemmPipeline = typename Base::GemmPipeline;
using EpiloguePipeline = typename Base::EpiloguePipeline;
using ALayout = typename Base::ALayout;
using BLayout = typename Base::BLayout;
using CLayout = typename Base::ELayout;
/// @brief Specify the data type configurations for A, B, E and D
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
"ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
struct BatchedGemmKernelArgs : GemmKernelArgs
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, CDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<>
{
index_t batch_stride_A;
index_t batch_stride_B;
@@ -91,27 +104,41 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using KernelArgs = BatchedGemmKernelArgs;
__host__ static constexpr auto
GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
[[nodiscard]] CK_TILE_HOST static auto GetName() -> const std::string
{
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}
CK_TILE_HOST static constexpr auto
GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count) -> dim3
{
return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
}
__host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
{
return dim3(UniversalGemmKernel::KernelBlockSize);
}
CK_TILE_HOST static constexpr BatchedGemmKernelArgs
MakeKernelArgs(const BatchedGemmHostArgs& hostArgs)
{
return BatchedGemmKernelArgs{{hostArgs.a_ptr,
hostArgs.b_ptr,
{},
return BatchedGemmKernelArgs{{hostArgs.as_ptr,
hostArgs.bs_ptr,
hostArgs.ds_ptr,
hostArgs.e_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.stride_A,
hostArgs.stride_B,
{},
hostArgs.stride_As,
hostArgs.stride_Bs,
hostArgs.stride_Ds,
hostArgs.stride_E,
hostArgs.k_batch},
hostArgs.batch_stride_A,
@@ -125,6 +152,12 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST static auto
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
{
return UniversalGemmKernel::IsSupportedArgument(kargs);
}
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
{
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
@@ -134,18 +167,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.y);
const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
// options
const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A +
splitk_batch_offset.a_k_split_offset;
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + batch_offset_A +
splitk_batch_offset.as_k_split_offset[0];
const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B +
splitk_batch_offset.b_k_split_offset;
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + batch_offset_B +
splitk_batch_offset.bs_k_split_offset[0];
const auto batch_stride_E = __builtin_amdgcn_readfirstlane(kargs.batch_stride_E);
const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_E);
@@ -154,7 +187,8 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
UniversalGemmKernel::RunGemm(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
};

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,185 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
/// @brief The MultiD GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref GemmKernelMultiD "GemmKernelMultiD" when creating kernel
/// arguments object. It contain all necessary information required to build proper kernel
/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by
/// stating all required information like M,N,K sizes and respective strides. NumDTensor
/// describes the number of D tensors.
template <index_t NumDTensor = 1>
struct GemmMultiDHostArgs
{
CK_TILE_HOST GemmMultiDHostArgs() = default;
CK_TILE_HOST GemmMultiDHostArgs(const void* a_ptr_,
const void* b_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
ds_ptr(ds_ptr_),
e_ptr(e_ptr_),
M(M_),
N(N_),
K(K_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_Ds(stride_Ds_),
stride_E(stride_E_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
union
{
void* e_ptr;
void* c_ptr;
};
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
const std::array<index_t, NumDTensor> stride_Ds;
union
{
index_t stride_E;
index_t stride_C;
};
index_t k_batch;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernelMultiD
{
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
/// functions.
using UniversalGemmKernel =
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
/// @brief Specify the layout configurations for A, B, E and D
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
/// @brief Specify the data type configurations for A, B, E and D
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, ALayout>::value &&
!is_detected<is_tuple, ADataType>::value,
"ALayout and ADataType must be scalars.");
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, BLayout>::value &&
!is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars.");
/// @brief ELayout and EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, ELayout>::value &&
!is_detected<is_tuple, EDataType>::value,
"ELayout and EDataType must be scalars.");
/// @brief DsLayout and DsDataType are expected to be tuple, not a scalar.
static_assert(is_detected<is_tuple, DsLayout>::value &&
is_detected<is_tuple, DsDataType>::value &&
DsLayout::size() == DsDataType::size() && DsLayout::size() > 0,
"DsLayout and DsDataType must be tuples and must have the same size.");
/// @brief The sizes of NumATensor and NumBTensor have always been 1; the size of D is set by
/// the user."
static constexpr index_t NumATensor = 1;
static constexpr index_t NumBTensor = 1;
static constexpr index_t NumDTensor = DsDataType::size();
CK_TILE_HOST static auto GetName() -> const std::string
{
return UniversalGemmKernel::GetName();
}
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
{
return UniversalGemmKernel::GridSize(M, N, KBatch);
}
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
return UniversalGemmKernel::MaxOccupancyGridSize(s);
}
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
{
return UniversalGemmKernel::BlockSize();
}
CK_TILE_HOST static constexpr auto
MakeKernelArgs(const GemmMultiDHostArgs<NumDTensor>& hostArgs) ->
typename UniversalGemmKernel::KernelArgs
{
/// @brief Universal GEMM requires array objects and corresponding stride information for
/// matrices A, B, and D.
return UniversalGemmKernel::MakeKernelArgs(
UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>({hostArgs.a_ptr},
{hostArgs.b_ptr},
hostArgs.ds_ptr,
hostArgs.e_ptr,
hostArgs.k_batch,
hostArgs.M,
hostArgs.N,
hostArgs.K,
{hostArgs.stride_A},
{hostArgs.stride_B},
hostArgs.stride_Ds,
hostArgs.stride_E));
}
CK_TILE_HOST static auto
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
{
return UniversalGemmKernel::IsSupportedArgument(kargs);
}
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
{
UniversalGemmKernel{}.template operator()(kargs);
}
};
} // namespace ck_tile

View File

@@ -16,37 +16,116 @@
namespace ck_tile {
/// @brief The Grouped GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel
/// arguments object. It contain all necessary information required to build proper kernel
/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by
/// stating all required information like M,N,K sizes and respective strides.
struct GroupedGemmHostArgs
{
CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
index_t stride_E_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
e_ptr(e_ptr_),
M(M_),
N(N_),
K(K_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_E(stride_E_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
union
{
void* e_ptr;
void* c_ptr;
};
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
union
{
index_t stride_E;
index_t stride_C;
};
index_t k_batch;
};
struct GemmTransKernelArg
{
GemmKernelArgs<> group_karg;
UniversalGemmKernelArgs<> group_karg;
ck_tile::index_t block_start;
ck_tile::index_t block_end;
GemmTransKernelArg() = delete;
GemmTransKernelArg(GemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
{
}
GemmTransKernelArg(GemmKernelArgs<>&& karg) : group_karg{karg}, block_start{0}, block_end{0} {}
GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg)
: group_karg{karg}, block_start{0}, block_end{0}
{
}
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
struct GroupedGemmKernel
{
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
/// functions.
using Base = UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
//// @brief Specify the layout configurations for A, B, C/E
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
/// @brief Specify the data type configurations for A, B, C/E
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
"ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, CDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
@@ -65,8 +144,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// clang-format on
}
CK_TILE_HOST static auto
GetWorkSpaceSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs) -> std::size_t
CK_TILE_HOST static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
-> std::size_t
{
return gemm_descs.size() * sizeof(GemmTransKernelArg);
}
@@ -95,8 +174,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
CK_TILE_HOST static auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
{
index_t grid_size = 0;
for(const auto& it_desc : gemm_descs)
@@ -107,8 +185,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto
MakeKargs(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
-> std::vector<GemmTransKernelArg>
{
std::vector<GemmTransKernelArg> gemm_kernel_args_;
@@ -138,18 +215,19 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
grid_size += grid_size_grp;
auto karg = GemmKernelArgs<>{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
{},
type_convert<CDataType*>(gemm_descs[i].e_ptr),
M,
N,
K,
stride_a,
stride_b,
{},
stride_e,
gemm_descs[i].k_batch};
auto karg =
UniversalGemmKernelArgs<>{{type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
{type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
{/*ds_ptr*/},
type_convert<CDataType*>(gemm_descs[i].e_ptr),
M,
N,
K,
{stride_a},
{stride_b},
{/*stride_ds*/},
stride_e,
gemm_descs[i].k_batch};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
}
@@ -181,7 +259,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
Run(kargs.group_karg, block_idx_2d, block_idx_z);
}
CK_TILE_DEVICE void Run(const GemmKernelArgs<>& kargs,
CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs,
const tuple<index_t, index_t>& block_idx_2d,
const index_t block_idx_z) const
{
@@ -192,10 +270,10 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) +
splitk_batch_offset.as_k_split_offset[0];
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) +
splitk_batch_offset.bs_k_split_offset[0];
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
// allocate LDS
@@ -208,7 +286,15 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
}
else
{
this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
Base::RunGemm({a_ptr},
{b_ptr},
{/*ds_ptr*/},
c_ptr,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
@@ -224,7 +310,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k
* batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
@@ -234,7 +321,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const BDataType* b_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const GemmKernelArgs<>& kargs,
const UniversalGemmKernelArgs<>& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
@@ -242,7 +329,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, {}, c_ptr, kargs, splitk_batch_offset);
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
@@ -258,8 +345,12 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template

File diff suppressed because it is too large Load Diff

View File

@@ -242,21 +242,20 @@ class TestCkTileBatchedGemm : public ::testing::Test
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = 1;
args.M = M;
args.N = N;
args.K = K;
args.stride_A = StrideA;
args.stride_B = StrideB;
args.stride_E = StrideC;
args.batch_stride_A = BatchStrideA;
args.batch_stride_B = BatchStrideB;
args.batch_stride_E = BatchStrideC;
args.batch_count = BatchCount;
ck_tile::BatchedGemmHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
1,
M,
N,
K,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
BatchCount};
invoke_batched_gemm<ALayout, BLayout, CLayout>(args,
ck_tile::stream_config{nullptr, false});

View File

@@ -25,7 +25,7 @@ template <typename GemmConfig,
typename CLayout,
bool Persistent,
typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if constexpr(Persistent)

View File

@@ -158,7 +158,7 @@ template <typename GemmConfig,
typename CLayout,
bool Persistent,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s);
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
template <typename GemmConfig,
typename ADataType,
@@ -185,18 +185,16 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_repeat,
bool persistent)
{
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args = {a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
{},
c_m_n_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
stride_C};
float ave_time;
if(persistent)

View File

@@ -411,4 +411,4 @@ template <typename ADataType,
typename CLayout,
bool Persistent = false,
typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s);
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -14,7 +14,7 @@ template <typename GemmConfig,
typename ELayout,
bool Persistent,
typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
@@ -63,119 +63,120 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
float ave_time{0};
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw ArgumentsNotSupportedException(
"Wrong! Arguments not supported! Skipping gemm!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw ArgumentsNotSupportedException(
"Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)

View File

@@ -91,8 +91,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
// TODO: expose tile size through test t-param ?
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
void invoke_gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args,
const ck_tile::stream_config& s)
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 256;
@@ -324,9 +323,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
return stride;
};
std::size_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
std::size_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
std::size_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
@@ -345,17 +344,16 @@ class TestCkTileGemmPipeline : public ::testing::Test
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_E = stride_C;
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
stride_C};
invoke_gemm<PadM, PadN, PadK, Preshuffle>(args, ck_tile::stream_config{nullptr, false});

View File

@@ -10,7 +10,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
struct ElementWiseAddAdd
@@ -95,7 +95,7 @@ class TestCkTileGemmMultiD : public ::testing::Test
typename DsLayout,
typename ELayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
void invoke_gemm_multi_d(const ck_tile::GemmHostArgs<DsDataType::size()>& args,
void invoke_gemm_multi_d(const ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args,
const ck_tile::stream_config& s)
{
constexpr ck_tile::index_t M_Tile = 256;
@@ -189,7 +189,7 @@ class TestCkTileGemmMultiD : public ::testing::Test
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
@@ -345,18 +345,18 @@ class TestCkTileGemmMultiD : public ::testing::Test
d1_m_n_dev_buf.GetDeviceBuffer()};
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {StrideD0, StrideD1};
ck_tile::GemmHostArgs<DsDataType::size()> args({a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
ds_ptr_buf,
e_m_n_dev_buf.GetDeviceBuffer(),
k_batch,
M,
N,
K,
StrideA,
StrideB,
stridesDs,
StrideE});
ck_tile::GemmMultiDHostArgs<DsDataType::size()> args({a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
ds_ptr_buf,
e_m_n_dev_buf.GetDeviceBuffer(),
k_batch,
M,
N,
K,
StrideA,
StrideB,
stridesDs,
StrideE});
invoke_gemm_multi_d<ADataType,
BDataType,

View File

@@ -86,8 +86,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
// TODO: expose tile size through test t-param ?
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
void invoke_gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args,
const ck_tile::stream_config& s)
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// TODO: This should be parameterized in tests
// constexpr ck_tile::index_t M_Tile = 128;
@@ -314,9 +313,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
return stride;
};
std::size_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
std::size_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
std::size_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
@@ -346,17 +345,16 @@ class TestCkTileGemmPipeline : public ::testing::Test
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_E = stride_C;
ck_tile::GemmHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
stride_C};
invoke_gemm<PadM, PadN, PadK, Preshuffle>(args, ck_tile::stream_config{nullptr, false});

View File

@@ -51,7 +51,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
static const ck_tile::index_t K_Warp_Tile = 16;
};
using grouped_gemm_kargs = ck_tile::GemmHostArgs</*NumDTensor = 0*/>;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
@@ -437,7 +437,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back(
{p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]});
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
}
ck_tile::DeviceMem gemm_workspace;
@@ -451,18 +451,18 @@ class TestCkTileGroupedGemm : public ::testing::Test
const bool splitk = gemm_descs[0].k_batch > 1;
for(const auto& arg : gemm_descs)
{
kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr,
arg.b_ptr,
{},
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.stride_A,
arg.stride_B,
{},
arg.stride_E,
arg.k_batch});
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr},
{arg.b_ptr},
{/*arg.ds_ptr*/},
arg.e_ptr,
arg.M,
arg.N,
arg.K,
{arg.stride_A},
{arg.stride_B},
{/*arg.stride_Ds*/},
arg.stride_E,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, false, 1};
ck_tile::hip_check_error(

View File

@@ -233,7 +233,7 @@ struct GemmKernel {{
static constexpr bool kPadN = {pad_n};
static constexpr bool kPadK = {pad_k};
static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
static constexpr bool permuteA = false;
static constexpr bool permuteB = false;
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
@@ -335,7 +335,7 @@ struct GemmKernel {{
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, stream.rotating_count_, size_a_buffer, size_b_buffer);
kargs.as_ptr[0], kargs.bs_ptr[0], stream.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {{
@@ -680,7 +680,7 @@ struct GemmDispatcher {
// Use a static local variable
static std::unordered_map<
std::string,
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>>
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>>
kernel_map;
return kernel_map;
}
@@ -705,7 +705,7 @@ struct GemmDispatcher {
warp_tile_n,
warp_tile_k,
) = tile[j]
content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """
content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """
content += f"""
if(structured_sparsity){{ // SMFMA"""
sparse = (
@@ -746,7 +746,7 @@ struct GemmDispatcher {
content += """ }
template <typename Kernel>
static std::tuple<std::string, float> run_kernel(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream)
static std::tuple<std::string, float> run_kernel(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream)
{
std::string name = Kernel::get_name();
float avg_time = Kernel::launch(args, stream);

View File

@@ -22,7 +22,7 @@ class GemmProfiler
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
{
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
@@ -89,10 +89,9 @@ class GemmProfiler
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::GemmHostArgs<> gemm_args = {
ck_tile::GemmHostArgs gemm_args = {
a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
{}, // ds_ptr
c_m_n_dev_buf.GetDeviceBuffer(),
gemm_problem.split_k_,
gemm_problem.m_,
@@ -100,7 +99,6 @@ class GemmProfiler
gemm_problem.k_,
gemm_problem.stride_a_,
gemm_problem.stride_b_,
{}, // stride_Ds
gemm_problem.stride_c_,
};