mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK TILE] GEMM and Batched GEMM SplitK support (#1724)
* [CK TILE] Add split K support in GEMM * Updates * Fixes * rebase * fix * Fix * fixes * support for batched gemm
This commit is contained in:
@@ -93,6 +93,7 @@ struct GemmKernel
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t KBatch;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
|
||||
@@ -105,28 +106,72 @@ struct GemmKernel
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C};
|
||||
hostArgs.stride_C,
|
||||
hostArgs.k_batch};
|
||||
}
|
||||
// CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
|
||||
// const void* b_ptr,
|
||||
// void* c_ptr,
|
||||
// index_t M,
|
||||
// index_t N,
|
||||
// index_t K,
|
||||
// index_t stride_A,
|
||||
// index_t stride_B,
|
||||
// index_t stride_C)
|
||||
// {
|
||||
// return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
|
||||
// }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const GemmKernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = kargs.KBatch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * KRead;
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * KRead * kargs.stride_A;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead * kargs.stride_B;
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead;
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.KBatch - 1))
|
||||
{
|
||||
splitted_k = KRead;
|
||||
}
|
||||
else
|
||||
{
|
||||
splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
|
||||
{
|
||||
constexpr bool is_output_c_reg_transposed =
|
||||
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
|
||||
if constexpr(!((GemmPipeline::VectorSizeC % 2 == 0 &&
|
||||
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
|
||||
is_output_c_reg_transposed) ||
|
||||
!(std::is_same_v<CDataType, fp16_t> || std::is_same_v<CDataType, bf16_t>)))
|
||||
{
|
||||
if(kargs.KBatch != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
|
||||
@@ -198,17 +243,19 @@ struct GemmKernel
|
||||
return true;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const GemmKernelArgs& kargs) const
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const GemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::VectorSizeA>{},
|
||||
number<1>{});
|
||||
@@ -217,7 +264,7 @@ struct GemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(1, kargs.stride_A),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
@@ -229,7 +276,7 @@ struct GemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, kargs.K),
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(1, kargs.stride_B),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
@@ -238,7 +285,7 @@ struct GemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, kargs.K),
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::VectorSizeB>{},
|
||||
number<1>{});
|
||||
@@ -248,7 +295,7 @@ struct GemmKernel
|
||||
const auto& c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
@@ -257,7 +304,7 @@ struct GemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
@@ -270,7 +317,7 @@ struct GemmKernel
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views) const
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
@@ -330,8 +377,8 @@ struct GemmKernel
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) const
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& a_block_window = make_tile_window(
|
||||
@@ -363,23 +410,27 @@ struct GemmKernel
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
* @tparam DstInMemOp Destination memory operation (default: set).
|
||||
*/
|
||||
CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const GemmKernelArgs& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n) const
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr,
|
||||
const GemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
;
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
@@ -389,18 +440,43 @@ struct GemmKernel
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I2);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile);
|
||||
|
||||
constexpr bool is_output_c_reg_transposed =
|
||||
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
|
||||
if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
|
||||
(GemmPipeline::VectorSizeC % 2 == 0 &&
|
||||
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
|
||||
is_output_c_reg_transposed))
|
||||
{
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
|
||||
c_block_window, c_block_tile);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
|
||||
{
|
||||
const auto [i_m, i_n] = TilePartitioner{}();
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
if(kargs.KBatch == 1)
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm<memory_operation_enum::atomic_add>(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user