mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK TILE] GEMM and Batched GEMM SplitK support (#1724)
* [CK TILE] Add split K support in GEMM
* Updates
* Fixes
* rebase
* fix
* Fix
* fixes
* support for batched gemm
[ROCm/composable_kernel commit: af66494880]
This commit is contained in:
@@ -54,8 +54,7 @@ using CDataType = Types::CDataType;
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("b", "1", "batch size")
|
||||
.insert("m", "3840", "m dimension")
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
@@ -68,7 +67,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -64,9 +64,9 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t batch_size = arg_parser.get_int("b");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_size,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
#endif
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
@@ -78,7 +78,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
#endif
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
@@ -106,17 +108,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(args.p_a,
|
||||
args.p_b,
|
||||
args.p_c,
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
args.stride_A,
|
||||
args.stride_B,
|
||||
args.stride_C);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
|
||||
@@ -70,20 +70,25 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
|
||||
using CodegenGemmPipeline =
|
||||
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
|
||||
@@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -17,6 +17,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::index_t batch_stride_B,
|
||||
ck_tile::index_t batch_stride_C,
|
||||
ck_tile::index_t batch_count,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
@@ -24,6 +25,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
|
||||
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
|
||||
ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -56,6 +56,13 @@ struct CShuffleEpilogue
|
||||
// No additional shared memory needed
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed()
|
||||
{
|
||||
// TODO: At now CShuffle doesn't allow to vector store after permute.
|
||||
// It should be fixed and this function should return true.
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
|
||||
{
|
||||
@@ -111,7 +118,9 @@ struct CShuffleEpilogue
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmp, typename OAccTile>
|
||||
template <typename ODramWindowTmp,
|
||||
typename OAccTile,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
|
||||
{
|
||||
const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
|
||||
@@ -158,12 +167,26 @@ struct CShuffleEpilogue
|
||||
// Store the tile data to the permuted location
|
||||
if constexpr(kPadM || kPadN)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -35,21 +35,39 @@ struct Default2DEpilogue
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }
|
||||
|
||||
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
|
||||
// how do we fix this ?
|
||||
template <typename ODramWindowTmp, typename OAccTile>
|
||||
template <typename ODramWindowTmp,
|
||||
typename OAccTile,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile)
|
||||
{
|
||||
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -67,9 +67,10 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
using KernelArgs = BatchedGemmKernelArgs;
|
||||
|
||||
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t batch_count)
|
||||
__host__ static constexpr auto
|
||||
GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
|
||||
{
|
||||
return TilePartitioner::GridSize(M, N, batch_count);
|
||||
return TilePartitioner::GridSize(M, N, KBatch * batch_count);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
|
||||
@@ -85,7 +86,8 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C},
|
||||
hostArgs.stride_C,
|
||||
hostArgs.k_batch},
|
||||
hostArgs.batch_stride_A,
|
||||
hostArgs.batch_stride_B,
|
||||
hostArgs.batch_stride_C,
|
||||
@@ -100,22 +102,38 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
|
||||
{
|
||||
const auto [i_m, i_n] = TilePartitioner{}();
|
||||
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
|
||||
const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
|
||||
|
||||
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
|
||||
|
||||
// options
|
||||
const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
|
||||
const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A +
|
||||
splitk_batch_offset.a_k_split_offset;
|
||||
|
||||
const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
|
||||
const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B +
|
||||
splitk_batch_offset.b_k_split_offset;
|
||||
|
||||
const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
|
||||
const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;
|
||||
|
||||
this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
if(kargs.KBatch == 1)
|
||||
{
|
||||
this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
this->template RunGemm<memory_operation_enum::atomic_add>(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -93,6 +93,7 @@ struct GemmKernel
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t KBatch;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
|
||||
@@ -105,28 +106,72 @@ struct GemmKernel
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C};
|
||||
hostArgs.stride_C,
|
||||
hostArgs.k_batch};
|
||||
}
|
||||
// CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
|
||||
// const void* b_ptr,
|
||||
// void* c_ptr,
|
||||
// index_t M,
|
||||
// index_t N,
|
||||
// index_t K,
|
||||
// index_t stride_A,
|
||||
// index_t stride_B,
|
||||
// index_t stride_C)
|
||||
// {
|
||||
// return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
|
||||
// }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const GemmKernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = kargs.KBatch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * KRead;
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * KRead * kargs.stride_A;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead * kargs.stride_B;
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead;
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.KBatch - 1))
|
||||
{
|
||||
splitted_k = KRead;
|
||||
}
|
||||
else
|
||||
{
|
||||
splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
|
||||
{
|
||||
constexpr bool is_output_c_reg_transposed =
|
||||
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
|
||||
if constexpr(!((GemmPipeline::VectorSizeC % 2 == 0 &&
|
||||
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
|
||||
is_output_c_reg_transposed) ||
|
||||
!(std::is_same_v<CDataType, fp16_t> || std::is_same_v<CDataType, bf16_t>)))
|
||||
{
|
||||
if(kargs.KBatch != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
|
||||
@@ -198,17 +243,19 @@ struct GemmKernel
|
||||
return true;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const GemmKernelArgs& kargs) const
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const GemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::VectorSizeA>{},
|
||||
number<1>{});
|
||||
@@ -217,7 +264,7 @@ struct GemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(1, kargs.stride_A),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
@@ -229,7 +276,7 @@ struct GemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, kargs.K),
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(1, kargs.stride_B),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
@@ -238,7 +285,7 @@ struct GemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, kargs.K),
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::VectorSizeB>{},
|
||||
number<1>{});
|
||||
@@ -248,7 +295,7 @@ struct GemmKernel
|
||||
const auto& c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
@@ -257,7 +304,7 @@ struct GemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
@@ -270,7 +317,7 @@ struct GemmKernel
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views) const
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
@@ -330,8 +377,8 @@ struct GemmKernel
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) const
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& a_block_window = make_tile_window(
|
||||
@@ -363,23 +410,27 @@ struct GemmKernel
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
* @tparam DstInMemOp Destination memory operation (default: set).
|
||||
*/
|
||||
CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const GemmKernelArgs& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n) const
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr,
|
||||
const GemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
;
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
@@ -389,18 +440,43 @@ struct GemmKernel
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I2);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile);
|
||||
|
||||
constexpr bool is_output_c_reg_transposed =
|
||||
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
|
||||
if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
|
||||
(GemmPipeline::VectorSizeC % 2 == 0 &&
|
||||
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
|
||||
is_output_c_reg_transposed))
|
||||
{
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
|
||||
c_block_window, c_block_tile);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
|
||||
{
|
||||
const auto [i_m, i_n] = TilePartitioner{}();
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
if(kargs.KBatch == 1)
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm<memory_operation_enum::atomic_add>(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -82,6 +82,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
|
||||
@@ -132,6 +132,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
|
||||
@@ -53,6 +53,8 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
|
||||
@@ -13,6 +13,8 @@ namespace ck_tile {
|
||||
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
|
||||
#if 0
|
||||
// 2d
|
||||
template <typename Problem>
|
||||
@@ -114,8 +116,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
index_t smem_size = 0;
|
||||
smem_size += smem_size_a + smem_size_b;
|
||||
constexpr index_t smem_size = smem_size_a + smem_size_b;
|
||||
|
||||
return smem_size;
|
||||
}
|
||||
@@ -485,13 +486,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
|
||||
@@ -36,6 +36,8 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
|
||||
@@ -444,6 +444,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
|
||||
@@ -93,7 +93,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
@@ -186,6 +186,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = 1;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
|
||||
@@ -74,7 +74,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user