mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
[CK TILE] GEMM and Batched GEMM SplitK support (#1724)
* [CK TILE] Add split K support in GEMM
* Updates
* Fixes
* rebase
* fix
* Fix
* fixes
* support for batched gemm
[ROCm/composable_kernel commit: af66494880]
This commit is contained in:
@@ -54,8 +54,7 @@ using CDataType = Types::CDataType;
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("b", "1", "batch size")
|
||||
.insert("m", "3840", "m dimension")
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
@@ -68,7 +67,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -64,9 +64,9 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t batch_size = arg_parser.get_int("b");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_size,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
#endif
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
#endif
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
@@ -106,17 +108,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(args.p_a,
|
||||
args.p_b,
|
||||
args.p_c,
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
args.stride_A,
|
||||
args.stride_B,
|
||||
args.stride_C);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
|
||||
@@ -70,20 +70,25 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
|
||||
using CodegenGemmPipeline =
|
||||
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
|
||||
@@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -17,6 +17,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::index_t batch_stride_B,
|
||||
ck_tile::index_t batch_stride_C,
|
||||
ck_tile::index_t batch_count,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
@@ -24,6 +25,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
|
||||
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
|
||||
ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user