Universal gemm splitk using reduce (with multi-d) (#1341)

* init for reduce_threadwise multi_d

* add reduce_threadwise_multi_d

* add reduce_multi_d

* clean

* start adding another splitk device op

* add reduce template parameter to SplitKBatchOffset

* add reduce c matrix

* clean up code

* change example data type to bf16

* add bf16Ai8B example

* remove reduce template parameter

* add splitk atomic status to v4

* add multi-d parameters to example

* add multi-d parameters to device op

* add multi-d to reduce

* fix kbatch=1 bug

* change B layout to col in bf16Ai8B example

* remove float adding struct

* change multi-d interface

* change file and class name

* remove multi-d of bf16Ai8B example

* change IsReduce function to IsReduceAdd

* change example layout to RRR from RCR

* set ds stride according to layout

* reset parameter layout

* add gemm universal reduce instance

* add reduce factory

* add profile_gemm_universal_reduce

* add reduce to profiler

* fix reduce instance

* fix profiler reduce compiling bug

* format

* format library instance code

* add mem instance for reduce library

* fix call instance names

* add workspace for reduce in ckProfiler

* format

* add mnpadding to reduce library instance

* add fp16 instance to reduce of profiler

* update copyright year

* restore profiler cmake file

* add reduce text to instances

* add DsLayout and DsDataType to instances template parameter

* fixed gemm_reduce_multi_d

* add an example without multi_d

* Update common.hpp

* Update gtest.cmake

* Update gemm_xdl_splitk_reduce_bf16.cpp

* clean

* Update gtest.cmake

* format

* fix api

* format

* change default parameter to RRR

* add vector_len for multi_d

* format

* Update gtest.cmake

* fix bf16Ai8B elementwise op

* add ReduceDataType

* move ReduceDataType to end position

* format

* remove googletest git method address

* fix copyright year

* update init data

---------

Co-authored-by: root <jizhan@amd.com>
Co-authored-by: letaoqin <letaoqin@amd.com>
Co-authored-by: Jing Zhang <jizhan@meta.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Author: ltqin
Date: 2024-07-19 22:01:22 +08:00
Committed by: GitHub
Commit: c544eb4da0 (parent: 70a814f163)

48 changed files with 4746 additions and 12 deletions


@@ -0,0 +1,260 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/tuple_helper.hpp"
namespace ck {
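// Grid-level entry point: forwards the grid descriptors, element-wise ops and pointers
// to GridwiseReduction::Run.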
template <typename GridwiseReduction,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename DsGridDesc_M,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename OutElementwiseOperation,
typename DsGridPointer>
__global__ void
kernel_reduce_threadwise_multi_d(const InGridDesc_M_K in_grid_desc_m_k,
const DsGridDesc_M ds_grid_desc_m,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation out_elementwise_op,
const InDataType* const __restrict__ p_in_value_global,
const DsGridPointer p_ds_value_global,
OutDataType* const __restrict__ p_out_value_global)
{
GridwiseReduction::Run(in_grid_desc_m_k,
ds_grid_desc_m,
out_grid_desc_m,
in_elementwise_op,
out_elementwise_op,
p_in_value_global,
p_ds_value_global,
p_out_value_global);
}
template <typename InDataType,
typename DsDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename DsGridDesc_M,
typename OutGridDesc_M,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
index_t BlockSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize,
typename DsVectorSize>
struct GridwiseReduction_mk_to_m_threadwise_multi_d
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using ThreadBufferDimAccessOrder =
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThrough = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr index_t NumDTensor = DsDataType::Size();
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
using DsGridPointer = decltype(MakeDsGridPointer());
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const DsGridDesc_M& ds_grid_desc_m,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& out_elementwise_op,
const InDataType* const __restrict__ p_in_value_global,
const DsGridPointer p_ds_grid,
OutDataType* const __restrict__ p_out_value_global)
{
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
false>;
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
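// Per-thread register buffers: the MThreadSliceSize x KThreadSliceSize input tile
// and the per-row accumulators.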
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
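// 1-D thread mapping over M: each thread reduces MThreadSliceSize consecutive rows.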
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
index_t reducedLength = 0;
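// Sweep the K dimension, loading and reducing KThreadSliceSize elements per row each iteration.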
do
{
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
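// Stage the M-slice of each additional D tensor into registers.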
auto ds_thread_buf = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(DsGridPointer{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MThreadSliceSize, true>{};
},
Number<NumDTensor>{});
auto ds_global_buf = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[I], ds_grid_desc_m[I].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto ds_global_load = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(DsGridPointer{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<DataType,
DataType,
decltype(ds_grid_desc_m[I]),
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>, // SliceLengths
Sequence<0>, // DimAccessOrder
InSrcVectorDim, // SrcVectorDim
DsVectorSize{}[I],
1, // SrcScalarStrideInVector
true>{
ds_grid_desc_m[I], make_multi_index(thread_global_1d_id * MThreadSliceSize)};
},
Number<NumDTensor>{});
static_for<0, NumDTensor, 1>{}([&](auto I) {
ds_global_load(I).Run(ds_grid_desc_m[I],
ds_global_buf[I],
reduced_data_desc,
make_tuple(I0),
ds_thread_buf(I));
});
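// Combine each reduced value with the corresponding D elements through the output element-wise op.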
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true> out_value_buf;
// if constexpr(NumDTensor > 0)
{
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(accu_value_buf[I]),
generate_tie(
[&](auto Id) -> const auto& { return ds_thread_buf[Id][I]; },
Number<NumDTensor>{}));
unpack2(out_elementwise_op, tie(out_value_buf(I)), c_ds_buf_refs);
});
}
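// Write the final values to global memory; OutMemoryDataOperation selects Set or AtomicAdd.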
auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<OutDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThrough,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
OutMemoryDataOperation,
1,
false>(
out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThrough{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), out_value_buf, out_grid_desc_m, dst_global_buf);
}
};
} // namespace ck
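
When ReduceOperation is Add, the kernel above computes, per output row m, out[m] = out_elementwise_op(sum_k in[m][k], D0[m], ..., D{N-1}[m]), with OutMemoryDataOperation selecting whether the result is written or atomically accumulated. A minimal host-side sketch of that contract under those assumptions follows; the function name and argument layout are hypothetical illustration, not a CK API.

// Hypothetical host reference for the multi-D threadwise reduction, assuming
// ReduceOperation = Add and in_elementwise_op = PassThrough.
#include <cstddef>
#include <vector>

template <typename InT, typename AccT, typename OutT, typename OutOp>
void reference_reduce_add_multi_d(const std::vector<InT>& in,               // M x K, row-major
                                  const std::vector<std::vector<OutT>>& ds, // each D tensor has length M
                                  std::vector<OutT>& out,                   // length M
                                  std::size_t M,
                                  std::size_t K,
                                  OutOp out_op)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        AccT acc = AccT{0}; // identity value of the Add reduction
        for(std::size_t k = 0; k < K; ++k)
            acc += static_cast<AccT>(in[m * K + k]); // per-row K reduction
        std::vector<OutT> d_row(ds.size());
        for(std::size_t i = 0; i < ds.size(); ++i)
            d_row[i] = ds[i][m]; // M-slice of each D tensor
        out[m] = out_op(acc, d_row); // multi-D epilogue, e.g. acc + d_row[0] for a bias add
    }
}

In the split-K GEMM changes below, this reduce consumes the KBatch partial results that the GEMM grid writes into a workspace.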


@@ -42,7 +42,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
-karg.p_c_grid,
+karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
#else
@@ -73,7 +73,7 @@ __global__ void
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
-karg.p_c_grid,
+karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared_0,
p_shared_1,
karg);
@@ -531,21 +531,35 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
-index_t k_batch_)
+index_t k_batch_,
+bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
-p_c_grid{p_c_grid_}
+p_c_grid{p_c_grid_},
+is_reduce(is_reduce_)
{
}
+__host__ __device__ inline bool IsReduceAdd() const
+{
+return (Problem::KBatch > 1) && is_reduce;
+}
+__host__ __device__ inline bool IsAtomicAdd() const
+{
+return (Problem::KBatch > 1) && (!is_reduce);
+}
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
+bool is_reduce;
};
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
@@ -574,10 +588,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
}
+if(karg.IsReduceAdd())
+{
+c_reduce_offset = blockIdx.z * karg.M * karg.N;
+}
+else
+{
+c_reduce_offset = 0;
+}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
+index_t c_reduce_offset;
};
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -1080,16 +1104,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}
-if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
+if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
+is_same<remove_cvref_t<CDataType>, float>::value))
{
-if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+if(!karg.IsReduceAdd())
{
-std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
-<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
-}
-if(karg.KBatch > 1)
-{
-return false;
+if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+{
+std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
+<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
+}
+if(karg.KBatch > 1)
+{
+return false;
+}
}
}
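
Taken together: when is_reduce is set and KBatch > 1, each z-block of the GEMM grid writes its partial C into its own M x N slice of a workspace (c_reduce_offset = blockIdx.z * M * N), and the threadwise multi-D reduce kernel from the first file then sums those slices into the final C. A small CPU sketch of that workspace layout and the final reduction follows; the helper names are illustrative, not the device-op API.

// Illustrative CPU model of the split-K reduce workspace addressed by c_reduce_offset;
// in the real flow the summation is done on the GPU by kernel_reduce_threadwise_multi_d.
#include <cstddef>
#include <vector>

// Workspace holds KBatch partial C matrices back to back: slice k starts at k * M * N,
// matching c_reduce_offset = blockIdx.z * karg.M * karg.N in SplitKBatchOffset.
float* partial_c_slice(std::vector<float>& workspace, std::size_t k_batch, std::size_t M, std::size_t N)
{
    return workspace.data() + k_batch * M * N;
}

// Final pass: C[m * N + n] = sum over k_batch of the partial results
// (the "reduce" split-K path, used instead of AtomicAdd when Argument::is_reduce is true).
void reduce_splitk_workspace(const std::vector<float>& workspace,
                             std::vector<float>& c,
                             std::size_t KBatch,
                             std::size_t M,
                             std::size_t N)
{
    for(std::size_t mn = 0; mn < M * N; ++mn)
    {
        float acc = 0.f;
        for(std::size_t k = 0; k < KBatch; ++k)
            acc += workspace[k * M * N + mn]; // accumulate the same (m, n) element across slices
        c[mn] = acc;
    }
}

With the AtomicAdd path (is_reduce false) no workspace is needed, since each z-block accumulates directly into C; that is why the CheckValidity change above only rejects KBatch > 1 for output data types without atomic-add support when the reduce path is not selected.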