mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Gemm reduce max (#209)
* [What] Rename the example [Why] Prepare to add unary reduction * Add global oparation to the parameter * Add atomicmax * Fix compile error * Support atomicMax (hip library) * Rename the reduction example * Fix target name * use p_d1_grid as the indicator directly * Prevent performance issue. Let passthrough handle it. * Implement the function template the specialize the float2 * No need to separate into two lines * Remove empty line * add comment * Fix compile error due to merge from develop * make the implementation of atomic_max / atomic_add explicit for each datatype * Refine typo * For future CI test * Fix compiler error in ckProfiler * Merge commit 'de2769e3a6695b38a20529261273ddc5cdaab2fe' * simply use remove_pointer * Rename type and var * Refine example * Modify reducemax example * Fix bug in reduction * Change initialize range * Implement F64 version of atomicMax * Move reduction code together * Add buffer atomic_max * Fix coding style by clang-format * Integrate new api of DeviceGemmReduce_Xdl_CShuffle * Integrate Batch gemm reduction * Fix example * fix example * clean up * Fix batch gemm tensor operation * Fix coding style * Fix template augument * Fix clang format * Keep flexible of different stride for each D tensor * Fix compile error for ckProfiler * Fix typo * [What] Fix naming [Why] Prepare to add out elementop * Add DoutElementOp Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: rocking <chunylai@amd.com>
This commit is contained in:
@@ -15,11 +15,12 @@ namespace ck {
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename FloatD,
|
||||
typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -34,12 +35,12 @@ __global__ void
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatD* __restrict__ p_d0_grid,
|
||||
FloatD* __restrict__ p_d1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const D1ElementwiseOperation d1_element_op,
|
||||
const DxsInElementwiseOperation dxs_in_element_op,
|
||||
const DxsOutElementwiseOperation dxs_out_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -53,13 +54,13 @@ __global__ void
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_d0_grid,
|
||||
p_d1_grid,
|
||||
p_ds_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -69,12 +70,12 @@ __global__ void
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = p_d0_grid;
|
||||
ignore = p_d1_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = d1_element_op;
|
||||
ignore = dxs_in_element_op;
|
||||
ignore = dxs_out_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
@@ -88,15 +89,15 @@ template <typename FloatAB,
|
||||
typename FloatCShuffle,
|
||||
typename FloatC,
|
||||
typename FloatReduceAcc,
|
||||
typename FloatD,
|
||||
typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename DxsReduceOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum DGlobalMemoryDataOperation,
|
||||
typename DGlobalMemoryDataOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDesc_M_N,
|
||||
@@ -357,13 +358,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatD* __restrict__ p_d0_grid,
|
||||
FloatD* __restrict__ p_d1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const D1ElementwiseOperation& d1_element_op,
|
||||
const DxsInElementwiseOperation& dxs_in_element_op,
|
||||
const DxsOutElementwiseOperation& dxs_out_element_op,
|
||||
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
@@ -377,10 +378,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
|
||||
auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
@@ -527,7 +524,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
// shuffle C + reduction + write out
|
||||
{
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
@@ -666,6 +663,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
|
||||
c_element_op};
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
1,
|
||||
1,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>>{};
|
||||
|
||||
// space filling curve for shuffled blockwise C in global mem
|
||||
constexpr auto sfc_c_global =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
|
||||
|
||||
// TODO: this should be implemented as a blockwise reduction
|
||||
// LDS c_reduce_block_desc_mperblock_nperblock
|
||||
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -716,16 +736,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
constexpr auto d_reduce_thread_desc_mblock_mperblock =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
|
||||
|
||||
// TODO: this should be implemented as a blockwise reduction
|
||||
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
|
||||
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
|
||||
// reduce: threadwise copy from LDS to VGPR
|
||||
constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
|
||||
CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
|
||||
@@ -749,47 +762,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
1,
|
||||
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
|
||||
|
||||
// reduce: copy from VGPR to global
|
||||
auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatReduceAcc,
|
||||
FloatD,
|
||||
decltype(d_reduce_thread_desc_mblock_mperblock),
|
||||
decltype(d_grid_desc_mblock_mperblock),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<1, mreduce_per_thread>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
DGlobalMemoryDataOperation,
|
||||
1,
|
||||
false>{d_grid_desc_mblock_mperblock,
|
||||
make_multi_index(block_work_idx[I0], // mblock
|
||||
c_reduce_thread_data_idx_begin[I0]), // mperblock
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto p_d_grid = p_ds_grid[I];
|
||||
auto d_out_element_op = dxs_out_element_op[I];
|
||||
|
||||
auto d1_reduce_thread_copy_vgpr_to_global = d0_reduce_thread_copy_vgpr_to_global;
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
1,
|
||||
1,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>>{};
|
||||
|
||||
// space filling curve for shuffled blockwise C in global mem
|
||||
constexpr auto sfc_c_global =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
|
||||
return ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatReduceAcc,
|
||||
remove_pointer_t<decltype(p_d_grid)>,
|
||||
decltype(d_reduce_thread_desc_mblock_mperblock),
|
||||
decltype(d_grid_desc_mblock_mperblock),
|
||||
decltype(d_out_element_op),
|
||||
Sequence<1, mreduce_per_thread>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
DGlobalMemoryDataOperation::At(I),
|
||||
1,
|
||||
false>{d_grid_desc_mblock_mperblock,
|
||||
make_multi_index(block_work_idx[I0], // mblock
|
||||
c_reduce_thread_data_idx_begin[I0]), // mperblock
|
||||
d_out_element_op};
|
||||
},
|
||||
Number<p_ds_grid.Size()>{});
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
@@ -816,64 +811,73 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
|
||||
using ThreadwiseReduce_D0 =
|
||||
ThreadwiseReduction<FloatReduceAcc,
|
||||
decltype(c_reduce_thread_desc_mperblock_nperblock),
|
||||
decltype(d_reduce_thread_desc_mperblock),
|
||||
D0ReduceOperation,
|
||||
false>;
|
||||
|
||||
using ThreadwiseReduce_D1 =
|
||||
ThreadwiseReduction<FloatReduceAcc,
|
||||
decltype(c_reduce_thread_desc_mperblock_nperblock),
|
||||
decltype(d_reduce_thread_desc_mperblock),
|
||||
D1ReduceOperation,
|
||||
false>;
|
||||
|
||||
const auto d0_zeroVal = D0ReduceOperation::GetReductionZeroVal();
|
||||
const auto d1_zeroVal = D0ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d0_thread_buf(I) = d0_zeroVal; });
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d1_thread_buf(I) = d1_zeroVal; });
|
||||
|
||||
// reduce
|
||||
// TODO - extract following into reduction_blockwise
|
||||
{
|
||||
// copy from LDS to VGPR
|
||||
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_reduce_thread_desc_mperblock_nperblock,
|
||||
make_tuple(I0, I0),
|
||||
c_reduce_thread_buf);
|
||||
|
||||
// reduce in VGPR
|
||||
ThreadwiseReduce_D0::Reduce(c_reduce_thread_buf, d0_thread_buf);
|
||||
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
|
||||
auto& p_d_grid = p_ds_grid[In];
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
|
||||
constexpr auto offset =
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
|
||||
|
||||
d1_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
|
||||
auto d_thread_buf =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
|
||||
auto& d_in_element_op = dxs_in_element_op[In];
|
||||
|
||||
auto& d_reduce_thread_copy_vgpr_to_global =
|
||||
dxs_reduce_thread_copy_vgpr_to_global(In);
|
||||
|
||||
using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
|
||||
using ThreadwiseReduce =
|
||||
ThreadwiseReduction<FloatReduceAcc,
|
||||
decltype(c_reduce_thread_desc_mperblock_nperblock),
|
||||
decltype(d_reduce_thread_desc_mperblock),
|
||||
DReduceOperation,
|
||||
false>;
|
||||
|
||||
// Global write Gemm shuffle + reduction
|
||||
const auto d_zeroVal = DReduceOperation::GetReductionZeroVal();
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
|
||||
|
||||
// reduce in VGPR
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
|
||||
constexpr auto offset =
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
|
||||
d_in_element_op(c_reduce_thread_buf(offset),
|
||||
c_reduce_thread_buf(offset));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
|
||||
|
||||
// copy from VGPR to Global
|
||||
d_reduce_thread_copy_vgpr_to_global.Run(
|
||||
d_reduce_thread_desc_mblock_mperblock,
|
||||
make_tuple(I0, I0),
|
||||
d_thread_buf,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
d_grid_buf);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
d_grid_desc_mblock_mperblock,
|
||||
make_tuple(c_global_step[I0], c_global_step[I1]));
|
||||
}
|
||||
});
|
||||
|
||||
ThreadwiseReduce_D1::Reduce(c_reduce_thread_buf, d1_thread_buf);
|
||||
|
||||
// copy from VGPR to Global
|
||||
d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
|
||||
make_tuple(I0, I0),
|
||||
d0_thread_buf,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
d0_grid_buf);
|
||||
|
||||
d1_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
|
||||
make_tuple(I0, I0),
|
||||
d1_thread_buf,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
d1_grid_buf);
|
||||
}
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
@@ -883,18 +887,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
// move on C
|
||||
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
|
||||
// move on D0
|
||||
d0_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
d_grid_desc_mblock_mperblock,
|
||||
make_tuple(c_global_step[I0], c_global_step[I1]));
|
||||
|
||||
// move on D1
|
||||
d1_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
d_grid_desc_mblock_mperblock,
|
||||
make_tuple(c_global_step[I0], c_global_step[I1]));
|
||||
}
|
||||
});
|
||||
|
||||
// Reduction
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user