mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
external api for gemm + layernorm (#285)
* Extract base class for elementwise
* Refactor interface of DeviceGemmReduce. Do not use tuple in interface
* [What] Rename d into reduce in gemm + reduction related code
[Why] Prepare to add d term for add
* Unify base class of gemm + reduce and gemm + bias + add + reduce
* 1. Rename gemm_bias_add_reduce for external api
2. Refine cmake
* Add normalize device operation
* [What] Reorder the argument
[Why] Because d0 is also the input of c.
* Add type string
* Add example of gemm_bias_add_layernorm via external api
* Refactor example code
* clang-format
* Fix compile error
* clang-format
* Add external api for gemm_add_add_layernorm and normalize
* Add client example
* clang-format
[ROCm/composable_kernel commit: 12235112a1]
This commit is contained in:
@@ -10,7 +10,7 @@
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
|
||||
#include "ck/device_utility/device_prop.hpp"
|
||||
#include "ck/device_utility/kernel_launch.hpp"
|
||||
@@ -35,7 +35,7 @@ template <typename ADataType,
|
||||
index_t DScalarPerVector,
|
||||
index_t EScalarPerVector,
|
||||
index_t FScalarPerVector>
|
||||
struct Device5AryElementwise : public BaseOperator
|
||||
struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
@@ -268,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator
|
||||
return true;
|
||||
};
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
const CDataType* p_c,
|
||||
const DDataType* p_d,
|
||||
const EDataType* p_e,
|
||||
FDataType* p_f,
|
||||
static auto MakeArgument(std::array<const void*, 5> p_inputs,
|
||||
std::array<void*, 1> p_outputs,
|
||||
std::vector<index_t> lengths,
|
||||
std::vector<index_t> a_strides,
|
||||
std::vector<index_t> b_strides,
|
||||
@@ -283,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator
|
||||
std::vector<index_t> f_strides,
|
||||
ElementwiseFunctor functor)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_d,
|
||||
p_e,
|
||||
p_f,
|
||||
return Argument{static_cast<const ADataType*>(p_inputs[0]),
|
||||
static_cast<const BDataType*>(p_inputs[1]),
|
||||
static_cast<const CDataType*>(p_inputs[2]),
|
||||
static_cast<const DDataType*>(p_inputs[3]),
|
||||
static_cast<const EDataType*>(p_inputs[4]),
|
||||
static_cast<FDataType*>(p_outputs[0]),
|
||||
lengths,
|
||||
a_strides,
|
||||
b_strides,
|
||||
@@ -299,40 +295,58 @@ struct Device5AryElementwise : public BaseOperator
|
||||
functor};
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_c,
|
||||
const void* p_d,
|
||||
const void* p_e,
|
||||
void* p_f,
|
||||
std::vector<index_t> lengths,
|
||||
std::vector<index_t> a_strides,
|
||||
std::vector<index_t> b_strides,
|
||||
std::vector<index_t> c_strides,
|
||||
std::vector<index_t> d_strides,
|
||||
std::vector<index_t> e_strides,
|
||||
std::vector<index_t> f_strides,
|
||||
ElementwiseFunctor functor)
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::array<const void*, 5> p_inputs,
|
||||
std::array<void*, 1> p_outputs,
|
||||
std::vector<index_t> lengths,
|
||||
std::vector<std::vector<index_t>> input_strides,
|
||||
std::vector<std::vector<index_t>> output_strides,
|
||||
ElementwiseFunctor functor) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const CDataType*>(p_c),
|
||||
static_cast<const DDataType*>(p_d),
|
||||
static_cast<const EDataType*>(p_e),
|
||||
static_cast<FDataType*>(p_f),
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
|
||||
static_cast<const BDataType*>(p_inputs[1]),
|
||||
static_cast<const CDataType*>(p_inputs[2]),
|
||||
static_cast<const DDataType*>(p_inputs[3]),
|
||||
static_cast<const EDataType*>(p_inputs[4]),
|
||||
static_cast<FDataType*>(p_outputs[0]),
|
||||
lengths,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
d_strides,
|
||||
e_strides,
|
||||
f_strides,
|
||||
input_strides[0],
|
||||
input_strides[1],
|
||||
input_strides[2],
|
||||
input_strides[3],
|
||||
input_strides[4],
|
||||
output_strides[0],
|
||||
functor);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
|
||||
};
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "Device5aryElementwise"
|
||||
<< "<"
|
||||
<< "NDim = " << NDim
|
||||
<< "MPerThread = " << MPerThread
|
||||
<< "AScalarPerVector = " << AScalarPerVector
|
||||
<< "BScalarPerVector = " << BScalarPerVector
|
||||
<< "CScalarPerVector = " << CScalarPerVector
|
||||
<< "DScalarPerVector = " << DScalarPerVector
|
||||
<< "EScalarPerVector = " << EScalarPerVector
|
||||
<< "FScalarPerVector = " << FScalarPerVector
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation>
|
||||
struct DeviceBatchedGemmReduce : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
void* p_dxs,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
ck::index_t Batch) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation>
|
||||
using DeviceBatchedGemmReducePtr =
|
||||
std::unique_ptr<DeviceBatchedGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/device_utility/device_prop.hpp"
|
||||
@@ -23,16 +23,16 @@ namespace device {
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename DPtrsGlobal,
|
||||
typename ReducePtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation,
|
||||
typename ReduceInElementwiseOperations,
|
||||
typename ReduceAccElementwiseOperations,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename DGridDescriptor_MBlock_MPerBlock,
|
||||
typename ReduceGridDescriptor_MBlock_MPerBlock,
|
||||
typename ComputeBasePrtOfBatch,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainK0BlockLoop>
|
||||
@@ -44,18 +44,18 @@ __global__ void
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
ReducePtrsGlobal p_reduces_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const DxsInElementwiseOperation dxs_in_element_op,
|
||||
const DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
const ReduceInElementwiseOperations reduce_in_element_ops,
|
||||
const ReduceAccElementwiseOperations reduce_out_element_ops,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
|
||||
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
|
||||
const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
@@ -71,10 +71,10 @@ __global__ void
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
|
||||
|
||||
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
|
||||
static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
|
||||
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
|
||||
p_ds_grid(In) = p_ds_grid(In) + d_batch_offset;
|
||||
p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset;
|
||||
});
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
@@ -82,33 +82,33 @@ __global__ void
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_c_grid + c_batch_offset,
|
||||
p_ds_grid,
|
||||
p_reduces_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
reduce_grid_desc_mblock_mperblock,
|
||||
block_2_ctile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_reduces_grid;
|
||||
ignore = batch_count;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = dxs_in_element_op;
|
||||
ignore = dxs_out_element_op;
|
||||
ignore = reduce_in_element_ops;
|
||||
ignore = reduce_out_element_ops;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = d_grid_desc_mblock_mperblock;
|
||||
ignore = reduce_grid_desc_mblock_mperblock;
|
||||
ignore = compute_base_ptr_of_batch_;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif
|
||||
@@ -126,14 +126,14 @@ template <typename ALayout,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename ReduceAccDataType,
|
||||
typename DPtrsGlobal,
|
||||
typename ReducePtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename DxsReduceOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation,
|
||||
typename DGlobalMemoryDataOperation,
|
||||
typename ReduceOperations,
|
||||
typename ReduceInElementwiseOperations,
|
||||
typename ReduceAccElementwiseOperations,
|
||||
typename ReduceGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
@@ -168,12 +168,7 @@ template <typename ALayout,
|
||||
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
: public DeviceBatchedGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation>
|
||||
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
|
||||
|
||||
@@ -446,7 +441,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// assume D is packed tensor
|
||||
static auto MakeDGridDescriptor_M(index_t MRaw)
|
||||
static auto MakeReduceGridDescriptor_M(index_t MRaw)
|
||||
{
|
||||
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
|
||||
|
||||
@@ -474,7 +469,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
|
||||
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
|
||||
using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
@@ -527,19 +522,19 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
ReduceAccDataType,
|
||||
DPtrsGlobal,
|
||||
ReducePtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsReduceOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
ReduceOperations,
|
||||
ReduceInElementwiseOperations,
|
||||
ReduceAccElementwiseOperations,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
DGlobalMemoryDataOperation,
|
||||
ReduceGlobalMemoryDataOperation,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
DGridDesc_M,
|
||||
ReduceGridDesc_M,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -582,7 +577,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
ReducePtrsGlobal p_reduces_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -592,31 +587,31 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
ReduceInElementwiseOperations reduce_in_element_ops,
|
||||
ReduceAccElementwiseOperations reduce_out_element_ops,
|
||||
index_t Batch)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
p_ds_grid_{p_ds_grid},
|
||||
p_reduces_grid_{p_reduces_grid},
|
||||
Batch_(Batch),
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
|
||||
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
|
||||
reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
d_grid_desc_mblock_mperblock_{},
|
||||
reduce_grid_desc_mblock_mperblock_{},
|
||||
compute_base_ptr_of_batch_{
|
||||
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
|
||||
type_convert<index_t>(reduce_grid_desc_m_.GetElementSpaceSize())},
|
||||
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
dxs_in_element_op_{dxs_in_element_op},
|
||||
dxs_out_element_op_{dxs_out_element_op}
|
||||
reduce_in_element_ops_{reduce_in_element_ops},
|
||||
reduce_out_element_ops_{reduce_out_element_ops}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
@@ -627,8 +622,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n_);
|
||||
|
||||
d_grid_desc_mblock_mperblock_ =
|
||||
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
|
||||
reduce_grid_desc_mblock_mperblock_ =
|
||||
GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -636,22 +631,23 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
DPtrsGlobal p_ds_grid_;
|
||||
ReducePtrsGlobal p_reduces_grid_;
|
||||
index_t Batch_;
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
DGridDesc_M d_grid_desc_m_;
|
||||
ReduceGridDesc_M reduce_grid_desc_m_;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
|
||||
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
|
||||
reduce_grid_desc_mblock_mperblock_;
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
DxsInElementwiseOperation dxs_in_element_op_;
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op_;
|
||||
ReduceInElementwiseOperations reduce_in_element_ops_;
|
||||
ReduceAccElementwiseOperations reduce_out_element_ops_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -678,7 +674,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
|
||||
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
#endif
|
||||
@@ -704,16 +700,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
DPtrsGlobal,
|
||||
ReducePtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
ReduceInElementwiseOperations,
|
||||
ReduceAccElementwiseOperations,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
|
||||
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
|
||||
ComputeBasePtrOfStridedBatch,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
true>;
|
||||
@@ -727,17 +723,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_reduces_grid_,
|
||||
arg.Batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.reduce_in_element_ops_,
|
||||
arg.reduce_out_element_ops_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.reduce_grid_desc_mblock_mperblock_,
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
@@ -747,16 +743,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
DPtrsGlobal,
|
||||
ReducePtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
ReduceInElementwiseOperations,
|
||||
ReduceAccElementwiseOperations,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
|
||||
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
|
||||
ComputeBasePtrOfStridedBatch,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
false>;
|
||||
@@ -770,17 +766,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_reduces_grid_,
|
||||
arg.Batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.reduce_in_element_ops_,
|
||||
arg.reduce_out_element_ops_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.reduce_grid_desc_mblock_mperblock_,
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
@@ -824,38 +820,76 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
DPtrsGlobal p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
static constexpr int NumReduce = ReduceOperations::Size();
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_bias,
|
||||
std::array<const void*, 0> p_ds,
|
||||
void* p_c,
|
||||
std::array<void*, NumReduce> p_reduces,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
std::array<ck::index_t, 0> StrideDs,
|
||||
std::array<void*, 3> gemm_element_ops,
|
||||
std::array<void*, 0> d_element_ops,
|
||||
std::array<void*, NumReduce> reduce_in_element_op,
|
||||
std::array<void*, NumReduce> reduce_out_element_op,
|
||||
index_t Batch)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_dxs,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
(void)p_bias;
|
||||
(void)p_ds;
|
||||
(void)StrideDs;
|
||||
(void)d_element_ops;
|
||||
|
||||
ReducePtrsGlobal reduce_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReducePtrsGlobal{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return static_cast<T*>(p_reduces[I]);
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReduceInElementwiseOperations{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return *(static_cast<T*>(reduce_in_element_op[I]));
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReduceAccElementwiseOperations{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return *(static_cast<T*>(reduce_out_element_op[I]));
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
AElementwiseOperation a_element_op =
|
||||
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
|
||||
BElementwiseOperation b_element_op =
|
||||
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
|
||||
CElementwiseOperation c_element_op =
|
||||
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
|
||||
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
reduce_tuple,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
Batch};
|
||||
}
|
||||
|
||||
@@ -865,37 +899,73 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_bias,
|
||||
std::array<const void*, 0> p_ds,
|
||||
void* p_c,
|
||||
void* p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
index_t Batch) override
|
||||
std::array<void*, NumReduce> p_reduces,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
std::array<ck::index_t, 0> StrideDs,
|
||||
std::array<void*, 3> gemm_element_ops,
|
||||
std::array<void*, 0> d_element_ops,
|
||||
std::array<void*, NumReduce> reduce_in_element_op,
|
||||
std::array<void*, NumReduce> reduce_out_element_op,
|
||||
index_t Batch = 1) override
|
||||
{
|
||||
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
|
||||
(void)p_bias;
|
||||
(void)p_ds;
|
||||
(void)StrideDs;
|
||||
(void)d_element_ops;
|
||||
|
||||
ReducePtrsGlobal reduce_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReducePtrsGlobal{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return static_cast<T*>(p_reduces[I]);
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReduceInElementwiseOperations{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return *(static_cast<T*>(reduce_in_element_op[I]));
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReduceAccElementwiseOperations{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return *(static_cast<T*>(reduce_out_element_op[I]));
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
AElementwiseOperation a_element_op =
|
||||
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
|
||||
BElementwiseOperation b_element_op =
|
||||
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
|
||||
CElementwiseOperation c_element_op =
|
||||
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
|
||||
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
dxs_tuple,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
reduce_tuple,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
Batch);
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck/device_utility/device_prop.hpp"
|
||||
#include "ck/device_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -25,7 +26,7 @@ template <typename ADataType,
|
||||
index_t AScalarPerVector,
|
||||
index_t BScalarPerVector,
|
||||
index_t CScalarPerVector>
|
||||
struct DeviceBinaryElementwise : public BaseOperator
|
||||
struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
@@ -198,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
return true;
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
std::vector<index_t> lengths,
|
||||
std::vector<index_t> a_strides,
|
||||
std::vector<index_t> b_strides,
|
||||
std::vector<index_t> c_strides,
|
||||
ElementwiseFunctor functor)
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::array<const void*, 2> p_inputs,
|
||||
std::array<void*, 1> p_outputs,
|
||||
std::vector<index_t> lengths,
|
||||
std::vector<std::vector<index_t>> input_strides,
|
||||
std::vector<std::vector<index_t>> output_strides,
|
||||
ElementwiseFunctor functor) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
|
||||
static_cast<const BDataType*>(p_inputs[1]),
|
||||
static_cast<CDataType*>(p_outputs[0]),
|
||||
lengths,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
input_strides[0],
|
||||
input_strides[1],
|
||||
output_strides[0],
|
||||
functor);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
@@ -226,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
// clang-format off
|
||||
str << "DeviceBinaryElementwise"
|
||||
<< "<"
|
||||
<< "NDim = " << NDim
|
||||
<< "MPerThread = " << MPerThread
|
||||
<< "AScalarPerVector = " << AScalarPerVector
|
||||
<< "BScalarPerVector = " << BScalarPerVector
|
||||
<< "CScalarPerVector = " << CScalarPerVector
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <ck::index_t NumInputTensor,
|
||||
ck::index_t NumOutputTensor,
|
||||
index_t NDim,
|
||||
typename ElementwiseFunctor>
|
||||
struct DeviceElementwise : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::array<const void*, NumInputTensor> p_inputs,
|
||||
std::array<void*, NumOutputTensor> p_outputs,
|
||||
std::vector<index_t> lengths,
|
||||
std::vector<std::vector<index_t>> input_strides,
|
||||
std::vector<std::vector<index_t>> output_strides,
|
||||
ElementwiseFunctor functor) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <ck::index_t NumInputTensor,
|
||||
ck::index_t NumOutputTensor,
|
||||
index_t NDim,
|
||||
typename ElementwiseFunctor>
|
||||
using DeviceElementwisePtr =
|
||||
std::unique_ptr<DeviceElementwise<NumInputTensor, NumOutputTensor, NDim, ElementwiseFunctor>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -29,20 +29,20 @@ template <typename ALayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename C0DataType,
|
||||
typename C1DataType,
|
||||
typename BiasDataType,
|
||||
typename D0DataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename ReduceAccDataType,
|
||||
typename DPtrsGlobal,
|
||||
typename ReducePtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename C1ElementwiseOperation,
|
||||
typename DxsReduceOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation,
|
||||
typename DGlobalMemoryDataOperation,
|
||||
typename D0ElementwiseOperation,
|
||||
typename ReduceOperations,
|
||||
typename ReduceInElementwiseOperations,
|
||||
typename ReduceAccElementwiseOperations,
|
||||
typename ReduceGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
@@ -77,13 +77,7 @@ template <typename ALayout,
|
||||
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
: public DeviceGemmBiasAddReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
C1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation>
|
||||
struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceOperations::Size()>
|
||||
{
|
||||
using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle;
|
||||
|
||||
@@ -356,7 +350,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// assume D is packed tensor
|
||||
static auto MakeDGridDescriptor_M(index_t MRaw)
|
||||
static auto MakeReduceGridDescriptor_M(index_t MRaw)
|
||||
{
|
||||
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
|
||||
|
||||
@@ -386,7 +380,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0));
|
||||
using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
|
||||
using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
|
||||
@@ -394,25 +388,25 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
C0DataType,
|
||||
C1DataType,
|
||||
BiasDataType,
|
||||
D0DataType,
|
||||
ReduceAccDataType,
|
||||
DPtrsGlobal,
|
||||
ReducePtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
C1ElementwiseOperation,
|
||||
DxsReduceOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
D0ElementwiseOperation,
|
||||
ReduceOperations,
|
||||
ReduceInElementwiseOperations,
|
||||
ReduceAccElementwiseOperations,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
DGlobalMemoryDataOperation,
|
||||
ReduceGlobalMemoryDataOperation,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
C0GridDesc_M_N,
|
||||
C1GridDesc_M_N,
|
||||
DGridDesc_M,
|
||||
ReduceGridDesc_M,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -455,9 +449,9 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
const C0DataType* p_c0_grid,
|
||||
const C1DataType* p_c1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
const BiasDataType* p_bias_grid,
|
||||
const D0DataType* p_d0_grid,
|
||||
ReducePtrsGlobal p_reduces_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -468,32 +462,32 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
C1ElementwiseOperation c1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op)
|
||||
D0ElementwiseOperation d0_element_op,
|
||||
ReduceInElementwiseOperations reduce_in_element_ops,
|
||||
ReduceAccElementwiseOperations reduce_out_element_ops)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
p_c0_grid_{p_c0_grid},
|
||||
p_c1_grid_{p_c1_grid},
|
||||
p_ds_grid_{p_ds_grid},
|
||||
p_bias_grid_{p_bias_grid},
|
||||
p_d0_grid_{p_d0_grid},
|
||||
p_reduces_grid_{p_reduces_grid},
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
|
||||
c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)},
|
||||
c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)},
|
||||
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
|
||||
reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
c0_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
d_grid_desc_mblock_mperblock_{},
|
||||
reduce_grid_desc_mblock_mperblock_{},
|
||||
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
c1_element_op_{c1_element_op},
|
||||
dxs_in_element_op_{dxs_in_element_op},
|
||||
dxs_out_element_op_{dxs_out_element_op}
|
||||
d0_element_op_{d0_element_op},
|
||||
reduce_in_element_ops_{reduce_in_element_ops},
|
||||
reduce_out_element_ops_{reduce_out_element_ops}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
@@ -512,8 +506,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c1_grid_desc_m_n_);
|
||||
|
||||
d_grid_desc_mblock_mperblock_ =
|
||||
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
|
||||
reduce_grid_desc_mblock_mperblock_ =
|
||||
GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -521,29 +515,30 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
const C0DataType* p_c0_grid_;
|
||||
const C1DataType* p_c1_grid_;
|
||||
DPtrsGlobal p_ds_grid_;
|
||||
const BiasDataType* p_bias_grid_;
|
||||
const D0DataType* p_d0_grid_;
|
||||
ReducePtrsGlobal p_reduces_grid_;
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
C0GridDesc_M_N c0_grid_desc_m_n_;
|
||||
C1GridDesc_M_N c1_grid_desc_m_n_;
|
||||
DGridDesc_M d_grid_desc_m_;
|
||||
ReduceGridDesc_M reduce_grid_desc_m_;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c0_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
|
||||
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
|
||||
reduce_grid_desc_mblock_mperblock_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
C1ElementwiseOperation c1_element_op_;
|
||||
DxsInElementwiseOperation dxs_in_element_op_;
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op_;
|
||||
D0ElementwiseOperation d0_element_op_;
|
||||
ReduceInElementwiseOperations reduce_in_element_ops_;
|
||||
ReduceAccElementwiseOperations reduce_out_element_ops_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -574,21 +569,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
C0DataType,
|
||||
C1DataType,
|
||||
DPtrsGlobal,
|
||||
BiasDataType,
|
||||
D0DataType,
|
||||
ReducePtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
C1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
D0ElementwiseOperation,
|
||||
ReduceInElementwiseOperations,
|
||||
ReduceAccElementwiseOperations,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
|
||||
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
true>;
|
||||
|
||||
@@ -601,21 +596,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_c0_grid_,
|
||||
arg.p_c1_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_bias_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_reduces_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.c1_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.d0_element_op_,
|
||||
arg.reduce_in_element_ops_,
|
||||
arg.reduce_out_element_ops_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.reduce_grid_desc_mblock_mperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
@@ -624,21 +619,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
C0DataType,
|
||||
C1DataType,
|
||||
DPtrsGlobal,
|
||||
BiasDataType,
|
||||
D0DataType,
|
||||
ReducePtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
C1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
D0ElementwiseOperation,
|
||||
ReduceInElementwiseOperations,
|
||||
ReduceAccElementwiseOperations,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
|
||||
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
false>;
|
||||
|
||||
@@ -651,21 +646,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_c0_grid_,
|
||||
arg.p_c1_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_bias_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_reduces_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.c1_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.d0_element_op_,
|
||||
arg.reduce_in_element_ops_,
|
||||
arg.reduce_out_element_ops_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.reduce_grid_desc_mblock_mperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
@@ -700,45 +695,76 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
const C0DataType* p_c0,
|
||||
const C1DataType* p_c1,
|
||||
DPtrsGlobal p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t StrideC1,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
C1ElementwiseOperation c1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op)
|
||||
static constexpr int NumReduce = ReduceOperations::Size();
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_bias,
|
||||
std::array<const void*, 1> p_ds,
|
||||
void* p_c,
|
||||
std::array<void*, NumReduce> p_reduces,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
std::array<ck::index_t, 1> StrideDs,
|
||||
std::array<void*, 3> gemm_element_ops,
|
||||
std::array<void*, 1> d_element_ops,
|
||||
std::array<void*, NumReduce> reduce_in_element_op,
|
||||
std::array<void*, NumReduce> reduce_out_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_c0,
|
||||
p_c1,
|
||||
p_dxs,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
ReducePtrsGlobal reduce_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReducePtrsGlobal{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return static_cast<T*>(p_reduces[I]);
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReduceInElementwiseOperations{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return *(static_cast<T*>(reduce_in_element_op[I]));
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReduceAccElementwiseOperations{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return *(static_cast<T*>(reduce_out_element_op[I]));
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
AElementwiseOperation a_element_op =
|
||||
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
|
||||
BElementwiseOperation b_element_op =
|
||||
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
|
||||
CElementwiseOperation c_element_op =
|
||||
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
|
||||
D0ElementwiseOperation d_element_op =
|
||||
*(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
|
||||
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
static_cast<const BiasDataType*>(p_bias),
|
||||
static_cast<const D0DataType*>(p_ds[0]),
|
||||
reduce_tuple,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
StrideC1,
|
||||
StrideDs[0],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
c1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op};
|
||||
d_element_op,
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -747,45 +773,74 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_bias,
|
||||
std::array<const void*, 1> p_ds,
|
||||
void* p_c,
|
||||
const void* p_c0,
|
||||
const void* p_c1,
|
||||
void* p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t StrideC1,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
C1ElementwiseOperation c1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
std::array<void*, NumReduce> p_reduces,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
std::array<ck::index_t, 1> StrideDs,
|
||||
std::array<void*, 3> gemm_element_ops,
|
||||
std::array<void*, 1> d_element_ops,
|
||||
std::array<void*, NumReduce> reduce_in_element_op,
|
||||
std::array<void*, NumReduce> reduce_out_element_op,
|
||||
index_t /* KBatch */ = 1) override
|
||||
{
|
||||
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
|
||||
ReducePtrsGlobal reduce_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReducePtrsGlobal{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return static_cast<T*>(p_reduces[I]);
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReduceInElementwiseOperations{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return *(static_cast<T*>(reduce_in_element_op[I]));
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReduceAccElementwiseOperations{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return *(static_cast<T*>(reduce_out_element_op[I]));
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
AElementwiseOperation a_element_op =
|
||||
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
|
||||
BElementwiseOperation b_element_op =
|
||||
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
|
||||
CElementwiseOperation c_element_op =
|
||||
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
|
||||
D0ElementwiseOperation d_element_op =
|
||||
*(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
|
||||
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
static_cast<const C0DataType*>(p_c0),
|
||||
static_cast<const C1DataType*>(p_c1),
|
||||
dxs_tuple,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
static_cast<const BiasDataType*>(p_bias),
|
||||
static_cast<const D0DataType*>(p_ds[0]),
|
||||
reduce_tuple,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
StrideC1,
|
||||
StrideDs[0],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
c1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op);
|
||||
d_element_op,
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -800,7 +855,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmReduce_Xdl_CShuffle"
|
||||
str << "DeviceGemmBiasAddReduce_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
|
||||
@@ -9,91 +9,34 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation>
|
||||
template <ck::index_t NumDTensor, ck::index_t NumReduce>
|
||||
struct DeviceGemmReduce : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_bias,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
void* p_dxs,
|
||||
std::array<void*, NumReduce> p_reduces,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
std::array<void*, 3> gemm_element_ops,
|
||||
std::array<void*, NumDTensor> d_element_ops,
|
||||
std::array<void*, NumReduce> reduce_in_element_ops,
|
||||
std::array<void*, NumReduce> reduce_out_element_ops,
|
||||
ck::index_t BatchCount = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation>
|
||||
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation>>;
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename C1ElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation>
|
||||
struct DeviceGemmBiasAddReduce : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
const void* p_c0,
|
||||
const void* p_c1,
|
||||
void* p_dxs,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t StrideC1,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
C1ElementwiseOperation c1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
ck::index_t BatchCount = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename C1ElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation>
|
||||
using DeviceGemmBiasAddReducePtr =
|
||||
std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
C1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation>>;
|
||||
template <ck::index_t NumDTensor, ck::index_t NumReduce>
|
||||
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -32,14 +32,14 @@ template <typename ALayout,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename ReduceAccDataType,
|
||||
typename DPtrsGlobal,
|
||||
typename ReducePtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename DxsReduceOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation,
|
||||
typename DGlobalMemoryDataOperation,
|
||||
typename ReduceOperations,
|
||||
typename ReduceInElementwiseOperations,
|
||||
typename ReduceAccElementwiseOperations,
|
||||
typename ReduceGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
@@ -74,11 +74,7 @@ template <typename ALayout,
|
||||
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation>
|
||||
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
|
||||
{
|
||||
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
|
||||
|
||||
@@ -350,8 +346,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
}
|
||||
}
|
||||
|
||||
// assume D is packed tensor
|
||||
static auto MakeDGridDescriptor_M(index_t MRaw)
|
||||
// assume Reduce is packed tensor
|
||||
static auto MakeReduceGridDescriptor_M(index_t MRaw)
|
||||
{
|
||||
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
|
||||
|
||||
@@ -379,7 +375,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
|
||||
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
|
||||
using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
|
||||
@@ -388,19 +384,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
ReduceAccDataType,
|
||||
DPtrsGlobal,
|
||||
ReducePtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsReduceOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
ReduceOperations,
|
||||
ReduceInElementwiseOperations,
|
||||
ReduceAccElementwiseOperations,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
DGlobalMemoryDataOperation,
|
||||
ReduceGlobalMemoryDataOperation,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
DGridDesc_M,
|
||||
ReduceGridDesc_M,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -443,7 +439,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
ReducePtrsGlobal p_reduces_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -453,24 +449,24 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op)
|
||||
ReduceInElementwiseOperations reduce_in_element_ops,
|
||||
ReduceAccElementwiseOperations reduce_out_element_ops)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
p_ds_grid_{p_ds_grid},
|
||||
p_reduces_grid_{p_reduces_grid},
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
|
||||
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
|
||||
reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
d_grid_desc_mblock_mperblock_{},
|
||||
reduce_grid_desc_mblock_mperblock_{},
|
||||
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
dxs_in_element_op_{dxs_in_element_op},
|
||||
dxs_out_element_op_{dxs_out_element_op}
|
||||
reduce_in_element_ops_{reduce_in_element_ops},
|
||||
reduce_out_element_ops_{reduce_out_element_ops}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
@@ -481,8 +477,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n_);
|
||||
|
||||
d_grid_desc_mblock_mperblock_ =
|
||||
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
|
||||
reduce_grid_desc_mblock_mperblock_ =
|
||||
GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -490,20 +486,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
DPtrsGlobal p_ds_grid_;
|
||||
ReducePtrsGlobal p_reduces_grid_;
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
DGridDesc_M d_grid_desc_m_;
|
||||
ReduceGridDesc_M reduce_grid_desc_m_;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
|
||||
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
|
||||
reduce_grid_desc_mblock_mperblock_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
DxsInElementwiseOperation dxs_in_element_op_;
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op_;
|
||||
ReduceInElementwiseOperations reduce_in_element_ops_;
|
||||
ReduceAccElementwiseOperations reduce_out_element_ops_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -528,7 +525,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
|
||||
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
#endif
|
||||
@@ -554,16 +551,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
DPtrsGlobal,
|
||||
ReducePtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
ReduceInElementwiseOperations,
|
||||
ReduceAccElementwiseOperations,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
|
||||
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
true>;
|
||||
|
||||
@@ -576,16 +573,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_reduces_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.reduce_in_element_ops_,
|
||||
arg.reduce_out_element_ops_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.reduce_grid_desc_mblock_mperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
@@ -594,16 +591,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
DPtrsGlobal,
|
||||
ReducePtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
ReduceInElementwiseOperations,
|
||||
ReduceAccElementwiseOperations,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
|
||||
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
false>;
|
||||
|
||||
@@ -616,16 +613,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_reduces_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.reduce_in_element_ops_,
|
||||
arg.reduce_out_element_ops_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.reduce_grid_desc_mblock_mperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
@@ -660,37 +657,75 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
DPtrsGlobal p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op)
|
||||
static constexpr int NumReduce = ReduceOperations::Size();
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_bias,
|
||||
std::array<const void*, 0> p_ds,
|
||||
void* p_c,
|
||||
std::array<void*, NumReduce> p_reduces,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
std::array<ck::index_t, 0> StrideDs,
|
||||
std::array<void*, 3> gemm_element_ops,
|
||||
std::array<void*, 0> d_element_ops,
|
||||
std::array<void*, NumReduce> reduce_in_element_op,
|
||||
std::array<void*, NumReduce> reduce_out_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_dxs,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
(void)p_bias;
|
||||
(void)p_ds;
|
||||
(void)StrideDs;
|
||||
(void)d_element_ops;
|
||||
|
||||
ReducePtrsGlobal reduce_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReducePtrsGlobal{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return static_cast<T*>(p_reduces[I]);
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReduceInElementwiseOperations{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return *(static_cast<T*>(reduce_in_element_op[I]));
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReduceAccElementwiseOperations{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return *(static_cast<T*>(reduce_out_element_op[I]));
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
AElementwiseOperation a_element_op =
|
||||
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
|
||||
BElementwiseOperation b_element_op =
|
||||
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
|
||||
CElementwiseOperation c_element_op =
|
||||
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
|
||||
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
reduce_tuple,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op};
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -699,37 +734,73 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_bias,
|
||||
std::array<const void*, 0> p_ds,
|
||||
void* p_c,
|
||||
void* p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
index_t /* KBatch */ = 1) override
|
||||
std::array<void*, NumReduce> p_reduces,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
std::array<ck::index_t, 0> StrideDs,
|
||||
std::array<void*, 3> gemm_element_ops,
|
||||
std::array<void*, 0> d_element_ops,
|
||||
std::array<void*, NumReduce> reduce_in_element_op,
|
||||
std::array<void*, NumReduce> reduce_out_element_op,
|
||||
ck::index_t = 1) override
|
||||
{
|
||||
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
|
||||
(void)p_bias;
|
||||
(void)p_ds;
|
||||
(void)StrideDs;
|
||||
(void)d_element_ops;
|
||||
|
||||
ReducePtrsGlobal reduce_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReducePtrsGlobal{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return static_cast<T*>(p_reduces[I]);
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReduceInElementwiseOperations{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return *(static_cast<T*>(reduce_in_element_op[I]));
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto tmp = ReduceAccElementwiseOperations{}[I];
|
||||
using T = remove_pointer_t<decltype(tmp)>;
|
||||
return *(static_cast<T*>(reduce_out_element_op[I]));
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
AElementwiseOperation a_element_op =
|
||||
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
|
||||
BElementwiseOperation b_element_op =
|
||||
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
|
||||
CElementwiseOperation c_element_op =
|
||||
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
|
||||
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
dxs_tuple,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
reduce_tuple,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op);
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -23,19 +23,19 @@ template <typename GridwiseGemm,
|
||||
typename FloatC,
|
||||
typename FloatC0,
|
||||
typename FloatC1,
|
||||
typename DPtrsGlobal,
|
||||
typename ReducePtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename C1ElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation,
|
||||
typename ReduceInElementwiseOperations,
|
||||
typename ReduceAccElementwiseOperations,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename DGridDescriptor_MBlock_MPerBlock,
|
||||
typename ReduceGridDescriptor_MBlock_MPerBlock,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
@@ -46,15 +46,15 @@ __global__ void
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const FloatC0* __restrict__ p_c0_grid,
|
||||
const FloatC1* __restrict__ p_c1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
const FloatC0* __restrict__ p_bias_grid,
|
||||
const FloatC1* __restrict__ p_d0_grid,
|
||||
ReducePtrsGlobal p_reduces_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const C1ElementwiseOperation c1_element_op,
|
||||
const DxsInElementwiseOperation dxs_in_element_op,
|
||||
const DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
const ReduceInElementwiseOperations reduce_in_element_ops,
|
||||
const ReduceAccElementwiseOperations reduce_out_element_ops,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -63,7 +63,7 @@ __global__ void
|
||||
c0_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
|
||||
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
@@ -72,42 +72,42 @@ __global__ void
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_c0_grid,
|
||||
p_c1_grid,
|
||||
p_ds_grid,
|
||||
p_bias_grid,
|
||||
p_d0_grid,
|
||||
p_reduces_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
c1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c0_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
reduce_grid_desc_mblock_mperblock,
|
||||
block_2_ctile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = p_c0_grid;
|
||||
ignore = p_c1_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_bias_grid;
|
||||
ignore = p_d0_grid;
|
||||
ignore = p_reduces_grid;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = c1_element_op;
|
||||
ignore = dxs_in_element_op;
|
||||
ignore = dxs_out_element_op;
|
||||
ignore = reduce_in_element_ops;
|
||||
ignore = reduce_out_element_ops;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = d_grid_desc_mblock_mperblock;
|
||||
ignore = reduce_grid_desc_mblock_mperblock;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
@@ -119,22 +119,22 @@ template <typename FloatAB,
|
||||
typename FloatC0,
|
||||
typename FloatC1,
|
||||
typename FloatReduceAcc,
|
||||
typename DPtrsGlobal,
|
||||
typename ReducePtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename C1ElementwiseOperation,
|
||||
typename DxsReduceOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation,
|
||||
typename ReduceOperations,
|
||||
typename ReduceInElementwiseOperations,
|
||||
typename ReduceAccElementwiseOperations,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename DGlobalMemoryDataOperation,
|
||||
typename ReduceGlobalMemoryDataOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDesc_M_N,
|
||||
typename C0GridDesc_M_N,
|
||||
typename C1GridDesc_M_N,
|
||||
typename DGridDesc_M,
|
||||
typename ReduceGridDesc_M,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
@@ -321,18 +321,18 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
|
||||
MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
|
||||
{
|
||||
const auto M = d_grid_desc_m.GetLength(I0);
|
||||
const auto MBlock = M / MPerBlock;
|
||||
|
||||
const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor(
|
||||
const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
|
||||
d_grid_desc_m,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
|
||||
return d_grid_desc_mblock_mperblock;
|
||||
return reduce_grid_desc_mblock_mperblock;
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
@@ -352,36 +352,37 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
|
||||
|
||||
using DGridDescriptor_MBlock_MPerBlock =
|
||||
remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;
|
||||
using ReduceGridDescriptor_MBlock_MPerBlock =
|
||||
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
|
||||
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
|
||||
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const FloatC0* __restrict__ p_c0_grid,
|
||||
const FloatC1* __restrict__ p_c1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const C1ElementwiseOperation& c1_element_op,
|
||||
const DxsInElementwiseOperation& dxs_in_element_op,
|
||||
const DxsReduceAccElementwiseOperation& dxs_out_element_op,
|
||||
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c0_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const FloatC0* __restrict__ p_bias_grid,
|
||||
const FloatC1* __restrict__ p_d0_grid,
|
||||
ReducePtrsGlobal p_reduces_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const C1ElementwiseOperation& c1_element_op,
|
||||
const ReduceInElementwiseOperations& reduce_in_element_ops,
|
||||
const ReduceAccElementwiseOperations& reduce_out_element_ops,
|
||||
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c0_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
@@ -390,9 +391,9 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c0_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
p_bias_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c1_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
p_d0_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
@@ -725,12 +726,12 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
|
||||
|
||||
// VGPR d_reduce_thread_desc_mperblock
|
||||
constexpr auto d_reduce_thread_desc_mperblock =
|
||||
// VGPR reduce_thread_desc_mperblock
|
||||
constexpr auto reduce_thread_desc_mperblock =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
|
||||
|
||||
// VGPR d_reduce_thread_desc_mblock_mperblock
|
||||
constexpr auto d_reduce_thread_desc_mblock_mperblock =
|
||||
// VGPR reduce_thread_desc_mblock_mperblock
|
||||
constexpr auto reduce_thread_desc_mblock_mperblock =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
|
||||
|
||||
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
@@ -759,29 +760,29 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
1,
|
||||
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
|
||||
|
||||
auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
|
||||
auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto p_d_grid = p_ds_grid[I];
|
||||
auto d_out_element_op = dxs_out_element_op[I];
|
||||
auto p_reduce_grid = p_reduces_grid[I];
|
||||
auto reduce_acc_element_op = reduce_out_element_ops[I];
|
||||
|
||||
return ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatReduceAcc,
|
||||
remove_pointer_t<decltype(p_d_grid)>,
|
||||
decltype(d_reduce_thread_desc_mblock_mperblock),
|
||||
decltype(d_grid_desc_mblock_mperblock),
|
||||
decltype(d_out_element_op),
|
||||
remove_pointer_t<decltype(p_reduce_grid)>,
|
||||
decltype(reduce_thread_desc_mblock_mperblock),
|
||||
decltype(reduce_grid_desc_mblock_mperblock),
|
||||
decltype(reduce_acc_element_op),
|
||||
Sequence<1, mreduce_per_thread>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
DGlobalMemoryDataOperation::At(I),
|
||||
ReduceGlobalMemoryDataOperation::At(I),
|
||||
1,
|
||||
false>{d_grid_desc_mblock_mperblock,
|
||||
false>{reduce_grid_desc_mblock_mperblock,
|
||||
make_multi_index(block_work_idx[I0], // mblock
|
||||
c_reduce_thread_data_idx_begin[I0]), // mperblock
|
||||
d_out_element_op};
|
||||
reduce_acc_element_op};
|
||||
},
|
||||
Number<p_ds_grid.Size()>{});
|
||||
Number<p_reduces_grid.Size()>{});
|
||||
|
||||
// c0 and c1
|
||||
constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
|
||||
@@ -909,35 +910,35 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
|
||||
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
|
||||
auto& p_d_grid = p_ds_grid[In];
|
||||
static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
|
||||
auto& p_reduce_grid = p_reduces_grid[In];
|
||||
|
||||
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
|
||||
auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
|
||||
|
||||
auto d_thread_buf =
|
||||
auto reduce_thread_buf =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
|
||||
auto& d_in_element_op = dxs_in_element_op[In];
|
||||
auto& reduce_in_element_op = reduce_in_element_ops[In];
|
||||
|
||||
auto& d_reduce_thread_copy_vgpr_to_global =
|
||||
dxs_reduce_thread_copy_vgpr_to_global(In);
|
||||
auto& reduce_thread_copy_vgpr_to_global =
|
||||
reduce_tuple_thread_copy_vgpr_to_global(In);
|
||||
|
||||
using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
|
||||
using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;
|
||||
using ThreadwiseReduce =
|
||||
ThreadwiseReduction<FloatReduceAcc,
|
||||
decltype(c_reduce_thread_desc_mperblock_nperblock),
|
||||
decltype(d_reduce_thread_desc_mperblock),
|
||||
DReduceOperation,
|
||||
decltype(reduce_thread_desc_mperblock),
|
||||
ReduceOperation,
|
||||
false>;
|
||||
|
||||
// Global write Gemm shuffle + reduction
|
||||
const auto d_zeroVal =
|
||||
DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
|
||||
const auto reduce_identityVal =
|
||||
ReduceOperation::template GetIdentityValue<FloatReduceAcc>();
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
|
||||
[&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
|
||||
|
||||
// reduce in VGPR
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
@@ -946,26 +947,25 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
|
||||
d_in_element_op(c_reduce_thread_buf(offset),
|
||||
c_reduce_thread_buf(offset));
|
||||
reduce_in_element_op(c_reduce_thread_buf(offset),
|
||||
c_reduce_thread_buf(offset));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
|
||||
ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
|
||||
|
||||
// copy from VGPR to Global
|
||||
d_reduce_thread_copy_vgpr_to_global.Run(
|
||||
d_reduce_thread_desc_mblock_mperblock,
|
||||
make_tuple(I0, I0),
|
||||
d_thread_buf,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
d_grid_buf);
|
||||
reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
|
||||
make_tuple(I0, I0),
|
||||
reduce_thread_buf,
|
||||
reduce_grid_desc_mblock_mperblock,
|
||||
reduce_grid_buf);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
d_grid_desc_mblock_mperblock,
|
||||
reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
reduce_grid_desc_mblock_mperblock,
|
||||
make_tuple(c_global_step[I0], c_global_step[I1]));
|
||||
}
|
||||
});
|
||||
|
||||
@@ -21,16 +21,16 @@ namespace ck {
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename DPtrsGlobal,
|
||||
typename ReducePtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation,
|
||||
typename ReduceInElementwiseOperations,
|
||||
typename ReduceAccElementwiseOperations,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename DGridDescriptor_MBlock_MPerBlock,
|
||||
typename ReduceGridDescriptor_MBlock_MPerBlock,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
@@ -41,17 +41,17 @@ __global__ void
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
ReducePtrsGlobal p_reduces_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const DxsInElementwiseOperation dxs_in_element_op,
|
||||
const DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
const ReduceInElementwiseOperations reduce_in_element_ops,
|
||||
const ReduceAccElementwiseOperations reduce_out_element_ops,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
|
||||
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
@@ -60,32 +60,32 @@ __global__ void
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_ds_grid,
|
||||
p_reduces_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
reduce_grid_desc_mblock_mperblock,
|
||||
block_2_ctile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_reduces_grid;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = dxs_in_element_op;
|
||||
ignore = dxs_out_element_op;
|
||||
ignore = reduce_in_element_ops;
|
||||
ignore = reduce_out_element_ops;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = d_grid_desc_mblock_mperblock;
|
||||
ignore = reduce_grid_desc_mblock_mperblock;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
@@ -95,19 +95,19 @@ template <typename FloatAB,
|
||||
typename FloatCShuffle,
|
||||
typename FloatC,
|
||||
typename FloatReduceAcc,
|
||||
typename DPtrsGlobal,
|
||||
typename ReducePtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename DxsReduceOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation,
|
||||
typename ReduceOperations,
|
||||
typename ReduceInElementwiseOperations,
|
||||
typename ReduceAccElementwiseOperations,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename DGlobalMemoryDataOperation,
|
||||
typename ReduceGlobalMemoryDataOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDesc_M_N,
|
||||
typename DGridDesc_M,
|
||||
typename ReduceGridDesc_M,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
@@ -293,18 +293,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
|
||||
MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
|
||||
{
|
||||
const auto M = d_grid_desc_m.GetLength(I0);
|
||||
const auto MBlock = M / MPerBlock;
|
||||
|
||||
const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor(
|
||||
const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
|
||||
d_grid_desc_m,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
|
||||
return d_grid_desc_mblock_mperblock;
|
||||
return reduce_grid_desc_mblock_mperblock;
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
@@ -318,29 +318,30 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
|
||||
|
||||
using DGridDescriptor_MBlock_MPerBlock =
|
||||
remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;
|
||||
using ReduceGridDescriptor_MBlock_MPerBlock =
|
||||
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
|
||||
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
|
||||
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const DxsInElementwiseOperation& dxs_in_element_op,
|
||||
const DxsReduceAccElementwiseOperation& dxs_out_element_op,
|
||||
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
ReducePtrsGlobal p_reduces_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const ReduceInElementwiseOperations& reduce_in_element_ops,
|
||||
const ReduceAccElementwiseOperations& reduce_out_element_ops,
|
||||
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
@@ -706,12 +707,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
|
||||
|
||||
// VGPR d_reduce_thread_desc_mperblock
|
||||
constexpr auto d_reduce_thread_desc_mperblock =
|
||||
// VGPR reduce_thread_desc_mperblock
|
||||
constexpr auto reduce_thread_desc_mperblock =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
|
||||
|
||||
// VGPR d_reduce_thread_desc_mblock_mperblock
|
||||
constexpr auto d_reduce_thread_desc_mblock_mperblock =
|
||||
// VGPR reduce_thread_desc_mblock_mperblock
|
||||
constexpr auto reduce_thread_desc_mblock_mperblock =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
|
||||
|
||||
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
@@ -740,29 +741,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
1,
|
||||
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
|
||||
|
||||
auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
|
||||
auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto p_d_grid = p_ds_grid[I];
|
||||
auto d_out_element_op = dxs_out_element_op[I];
|
||||
auto p_reduce_grid = p_reduces_grid[I];
|
||||
auto reduce_acc_element_op = reduce_out_element_ops[I];
|
||||
|
||||
return ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatReduceAcc,
|
||||
remove_pointer_t<decltype(p_d_grid)>,
|
||||
decltype(d_reduce_thread_desc_mblock_mperblock),
|
||||
decltype(d_grid_desc_mblock_mperblock),
|
||||
decltype(d_out_element_op),
|
||||
remove_pointer_t<decltype(p_reduce_grid)>,
|
||||
decltype(reduce_thread_desc_mblock_mperblock),
|
||||
decltype(reduce_grid_desc_mblock_mperblock),
|
||||
decltype(reduce_acc_element_op),
|
||||
Sequence<1, mreduce_per_thread>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
DGlobalMemoryDataOperation::At(I),
|
||||
ReduceGlobalMemoryDataOperation::At(I),
|
||||
1,
|
||||
false>{d_grid_desc_mblock_mperblock,
|
||||
false>{reduce_grid_desc_mblock_mperblock,
|
||||
make_multi_index(block_work_idx[I0], // mblock
|
||||
c_reduce_thread_data_idx_begin[I0]), // mperblock
|
||||
d_out_element_op};
|
||||
reduce_acc_element_op};
|
||||
},
|
||||
Number<p_ds_grid.Size()>{});
|
||||
Number<p_reduces_grid.Size()>{});
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
@@ -797,35 +798,35 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
make_tuple(I0, I0),
|
||||
c_reduce_thread_buf);
|
||||
|
||||
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
|
||||
auto& p_d_grid = p_ds_grid[In];
|
||||
static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
|
||||
auto& p_reduce_grid = p_reduces_grid[In];
|
||||
|
||||
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
|
||||
auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
|
||||
|
||||
auto d_thread_buf =
|
||||
auto reduce_thread_buf =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
|
||||
auto& d_in_element_op = dxs_in_element_op[In];
|
||||
auto& reduce_in_element_op = reduce_in_element_ops[In];
|
||||
|
||||
auto& d_reduce_thread_copy_vgpr_to_global =
|
||||
dxs_reduce_thread_copy_vgpr_to_global(In);
|
||||
auto& reduce_thread_copy_vgpr_to_global =
|
||||
reduce_tuple_thread_copy_vgpr_to_global(In);
|
||||
|
||||
using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
|
||||
using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;
|
||||
using ThreadwiseReduce =
|
||||
ThreadwiseReduction<FloatReduceAcc,
|
||||
decltype(c_reduce_thread_desc_mperblock_nperblock),
|
||||
decltype(d_reduce_thread_desc_mperblock),
|
||||
DReduceOperation,
|
||||
decltype(reduce_thread_desc_mperblock),
|
||||
ReduceOperation,
|
||||
false>;
|
||||
|
||||
// Global write Gemm shuffle + reduction
|
||||
const auto d_identityVal =
|
||||
DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
|
||||
const auto reduce_identityVal =
|
||||
ReduceOperation::template GetIdentityValue<FloatReduceAcc>();
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d_thread_buf(I) = d_identityVal; });
|
||||
[&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
|
||||
|
||||
// reduce in VGPR
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
@@ -834,26 +835,25 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
|
||||
d_in_element_op(c_reduce_thread_buf(offset),
|
||||
c_reduce_thread_buf(offset));
|
||||
reduce_in_element_op(c_reduce_thread_buf(offset),
|
||||
c_reduce_thread_buf(offset));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
|
||||
ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
|
||||
|
||||
// copy from VGPR to Global
|
||||
d_reduce_thread_copy_vgpr_to_global.Run(
|
||||
d_reduce_thread_desc_mblock_mperblock,
|
||||
make_tuple(I0, I0),
|
||||
d_thread_buf,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
d_grid_buf);
|
||||
reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
|
||||
make_tuple(I0, I0),
|
||||
reduce_thread_buf,
|
||||
reduce_grid_desc_mblock_mperblock,
|
||||
reduce_grid_buf);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
d_grid_desc_mblock_mperblock,
|
||||
reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
reduce_grid_desc_mblock_mperblock,
|
||||
make_tuple(c_global_step[I0], c_global_step[I1]));
|
||||
}
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user