Gemm reduce max (#209)

* [What] Rename the example
[Why] Prepare to add unary reduction

* Add global oparation to the parameter

* Add atomicmax

* Fix compile error

* Support atomicMax (hip library)

* Rename the reduction example

* Fix target name

* use p_d1_grid as the indicator directly

* Prevent performance issue. Let passthrough handle it.

* Implement the function template the specialize the float2

* No need to separate into two lines

* Remove empty line

* add comment

* Fix compile error due to merge from develop

* make the implementation of atomic_max / atomic_add explicit for each datatype

* Refine typo

* For future CI test

* Fix compiler error in ckProfiler

* Merge commit 'de2769e3a6695b38a20529261273ddc5cdaab2fe'

* simply use remove_pointer

* Rename type and var

* Refine example

* Modify reducemax example

* Fix bug in reduction

* Change initialize range

* Implement F64 version of atomicMax

* Move reduction  code together

* Add buffer atomic_max

* Fix coding style by clang-format

* Integrate new api of DeviceGemmReduce_Xdl_CShuffle

* Integrate Batch gemm reduction

* Fix example

* fix example

* clean up

* Fix batch gemm tensor operation

* Fix coding style

* Fix template augument

* Fix clang format

* Keep flexible of different stride for each D tensor

* Fix compile error for ckProfiler

* Fix typo

* [What] Fix naming
[Why] Prepare to add out elementop

* Add DoutElementOp

Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: rocking <chunylai@amd.com>
This commit is contained in:
rocking5566
2022-05-20 10:56:56 +08:00
committed by GitHub
parent aafc3ac27a
commit 0ffe956ab1
28 changed files with 1298 additions and 626 deletions

View File

@@ -15,11 +15,12 @@ namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename FloatD,
typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -34,12 +35,12 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatD* __restrict__ p_d0_grid,
FloatD* __restrict__ p_d1_grid,
DPtrsGlobal p_ds_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const D1ElementwiseOperation d1_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsOutElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -53,13 +54,13 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_d0_grid,
p_d1_grid,
p_ds_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
d1_element_op,
dxs_in_element_op,
dxs_out_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -69,12 +70,12 @@ __global__ void
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_d0_grid;
ignore = p_d1_grid;
ignore = p_ds_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = d1_element_op;
ignore = dxs_in_element_op;
ignore = dxs_out_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -88,15 +89,15 @@ template <typename FloatAB,
typename FloatCShuffle,
typename FloatC,
typename FloatReduceAcc,
typename FloatD,
typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum DGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
@@ -357,13 +358,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatD* __restrict__ p_d0_grid,
FloatD* __restrict__ p_d1_grid,
DPtrsGlobal p_ds_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const D1ElementwiseOperation& d1_element_op,
const DxsInElementwiseOperation& dxs_in_element_op,
const DxsOutElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
@@ -377,10 +378,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
@@ -527,7 +524,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
// shuffle C + reduction + write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -666,6 +663,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
// TODO: this should be implemented as a blockwise reduction
// LDS c_reduce_block_desc_mperblock_nperblock
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
@@ -716,16 +736,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto d_reduce_thread_desc_mblock_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
// TODO: this should be implemented as a blockwise reduction
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
// reduce: threadwise copy from LDS to VGPR
constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
@@ -749,47 +762,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
// reduce: copy from VGPR to global
auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
FloatD,
decltype(d_reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1, mreduce_per_thread>,
Sequence<0, 1>,
1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
DGlobalMemoryDataOperation,
1,
false>{d_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock
ck::tensor_operation::element_wise::PassThrough{}};
auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
[&](auto I) {
auto p_d_grid = p_ds_grid[I];
auto d_out_element_op = dxs_out_element_op[I];
auto d1_reduce_thread_copy_vgpr_to_global = d0_reduce_thread_copy_vgpr_to_global;
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
return ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
remove_pointer_t<decltype(p_d_grid)>,
decltype(d_reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock),
decltype(d_out_element_op),
Sequence<1, mreduce_per_thread>,
Sequence<0, 1>,
1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
DGlobalMemoryDataOperation::At(I),
1,
false>{d_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock
d_out_element_op};
},
Number<p_ds_grid.Size()>{});
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
@@ -816,64 +811,73 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
using ThreadwiseReduce_D0 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
D0ReduceOperation,
false>;
using ThreadwiseReduce_D1 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
D1ReduceOperation,
false>;
const auto d0_zeroVal = D0ReduceOperation::GetReductionZeroVal();
const auto d1_zeroVal = D0ReduceOperation::GetReductionZeroVal();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d0_thread_buf(I) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d1_thread_buf(I) = d1_zeroVal; });
// reduce
// TODO - extract following into reduction_blockwise
{
// copy from LDS to VGPR
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf);
// reduce in VGPR
ThreadwiseReduce_D0::Reduce(c_reduce_thread_buf, d0_thread_buf);
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
auto& p_d_grid = p_ds_grid[In];
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
d1_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
auto d_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
auto& d_in_element_op = dxs_in_element_op[In];
auto& d_reduce_thread_copy_vgpr_to_global =
dxs_reduce_thread_copy_vgpr_to_global(In);
using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
using ThreadwiseReduce =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
DReduceOperation,
false>;
// Global write Gemm shuffle + reduction
const auto d_zeroVal = DReduceOperation::GetReductionZeroVal();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
// reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
d_in_element_op(c_reduce_thread_buf(offset),
c_reduce_thread_buf(offset));
});
});
ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
// copy from VGPR to Global
d_reduce_thread_copy_vgpr_to_global.Run(
d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
d_thread_buf,
d_grid_desc_mblock_mperblock,
d_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
}
});
ThreadwiseReduce_D1::Reduce(c_reduce_thread_buf, d1_thread_buf);
// copy from VGPR to Global
d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
d0_thread_buf,
d_grid_desc_mblock_mperblock,
d0_grid_buf);
d1_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
d1_thread_buf,
d_grid_desc_mblock_mperblock,
d1_grid_buf);
}
if constexpr(access_id < num_access - 1)
@@ -883,18 +887,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on D0
d0_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
// move on D1
d1_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
}
});
// Reduction
}
}
};