mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
example for convnd bwd weight bf16 splitk (#265)
* add GetWorkSpaceSize to base arg and make an example on convnd_bwd_weight * add bwd weight for bf16: init * remove redundant compute * use datatype and split k to check whether a workspace is used * remove unused computation for work space size * add some code for bfp16 * add device/grid unary op * add unary type convert to bwd-weight example * support bf16 splitk kernel for convnd bwd weight * 1. remove comments. 2. add checkvalidity. 3. add gridsize computation * add workspace size check * fix format * change function name
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_bwd_weight.hpp"
|
||||
#include "gridwise_unary_elementwise_1d.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -628,6 +629,54 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
1);
|
||||
}
|
||||
|
||||
// type convert descs
|
||||
template <typename Desc_M0>
|
||||
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
|
||||
{
|
||||
const auto m0 = desc_m0.GetLength(I0);
|
||||
const index_t loop_step = gridSize * blockSize * 4;
|
||||
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
|
||||
const auto desc_m0_pad =
|
||||
transform_tensor_descriptor(desc_m0,
|
||||
make_tuple(make_right_pad_transform(m0, pad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return desc_m0_pad;
|
||||
}
|
||||
|
||||
template <index_t Dim>
|
||||
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
|
||||
const std::vector<index_t>& stride,
|
||||
index_t gridSize,
|
||||
index_t blockSize)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
|
||||
// merge nd to 1d desc - [s0 * s1 * ...]
|
||||
if constexpr(Dim > 1)
|
||||
{
|
||||
const auto desc_m0 = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(tupleOfShape)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
|
||||
}
|
||||
else
|
||||
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
|
||||
}
|
||||
|
||||
using TypeConvertFunctor =
|
||||
ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;
|
||||
using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1));
|
||||
using GridwiseUEltwise =
|
||||
GridwiseUnaryElementwise_1D<AccDataType, InDataType, GridDesc_M0, TypeConvertFunctor, 4>;
|
||||
|
||||
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
@@ -733,6 +782,55 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
true,
|
||||
true>;
|
||||
|
||||
using GridwiseGemmAtomicAddFloatBf16Splitk = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
ABlockLdsM1PerBlock,
|
||||
ABlockLdsM0PerBlock,
|
||||
ABlockLdsM1Padding,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
BBlockLdsN1PerBlock,
|
||||
BBlockLdsN0PerBlock,
|
||||
BBlockLdsN1Padding,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
true>;
|
||||
|
||||
// Argument
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
|
||||
@@ -802,6 +900,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
// init work space
|
||||
p_c_workspace_grid_ = nullptr;
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
|
||||
@@ -838,6 +939,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
index_t k_batch_;
|
||||
|
||||
// external work space
|
||||
void* p_c_workspace_grid_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -910,41 +1014,159 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
// run kernel for bf16 with splitk
|
||||
const auto run_bf16_splitk = [&](const auto& kernel) {
|
||||
hipGetErrorString(hipMemset(
|
||||
arg.p_c_workspace_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(AccDataType)));
|
||||
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
static_cast<AccDataType*>(arg.p_c_workspace_grid_),
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
// kernel for type conversion
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(arg.Conv_K_),
|
||||
static_cast<std::size_t>(arg.Conv_C_)};
|
||||
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(arg.filter_spatial_lengths_),
|
||||
std::end(arg.filter_spatial_lengths_));
|
||||
|
||||
int tensor_size =
|
||||
std::accumulate(filter_dims.begin(), filter_dims.end(), 1, std::multiplies<int>{});
|
||||
|
||||
const index_t type_convert_grid_size = GridwiseUEltwise::CalculateGridSize(tensor_size);
|
||||
GridDesc_M0 a_grid_desc_m0_ =
|
||||
MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256);
|
||||
GridDesc_M0 b_grid_desc_m0_ =
|
||||
MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256);
|
||||
|
||||
if(!GridwiseUEltwise::CheckValidity(a_grid_desc_m0_, b_grid_desc_m0_))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseUnaryElementwise_1D has invalid setting");
|
||||
}
|
||||
|
||||
// run kernel for type conversion
|
||||
void* p_c_grid_tmp_ = static_cast<void*>(arg.p_c_grid_);
|
||||
InDataType* p_c_grid_tmp_bf16_ = static_cast<InDataType*>(p_c_grid_tmp_);
|
||||
const auto Run_type_convert = [&](const auto& kernel) {
|
||||
float elapsed_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(type_convert_grid_size),
|
||||
dim3(256),
|
||||
0,
|
||||
static_cast<AccDataType*>(arg.p_c_workspace_grid_),
|
||||
p_c_grid_tmp_bf16_,
|
||||
a_grid_desc_m0_,
|
||||
b_grid_desc_m0_,
|
||||
TypeConvertFunctor{});
|
||||
return elapsed_time;
|
||||
};
|
||||
|
||||
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
true>;
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
Run(kernel);
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel_type_convert =
|
||||
kernel_unary_elementwise_1d<GridwiseUEltwise,
|
||||
AccDataType,
|
||||
InDataType,
|
||||
GridDesc_M0,
|
||||
TypeConvertFunctor>;
|
||||
|
||||
const auto kernel_conv = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemmAtomicAddFloatBf16Splitk,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
AccDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
run_bf16_splitk(kernel_conv);
|
||||
ave_time += Run_type_convert(kernel_type_convert);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
false>;
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
Run(kernel);
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemmAtomicAddFloatBf16Splitk,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
AccDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
run_bf16_splitk(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1226,6 +1448,11 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
{
|
||||
return GetWorkSpaceSize<NumDimSpatial>(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
|
||||
{
|
||||
dynamic_cast<Argument*>(p_arg)->p_c_workspace_grid_ = workspace_ptr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "gridwise_unary_elementwise_1d.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename ElementwiseFunctor,
|
||||
index_t Dim,
|
||||
index_t ScalarPerVector>
|
||||
struct DeviceUnaryElementwise : public BaseOperator
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
template <typename Desc_M0>
|
||||
static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
|
||||
{
|
||||
const auto m0 = desc_m0.GetLength(I0);
|
||||
const index_t loop_step = gridSize * blockSize * ScalarPerVector;
|
||||
const auto pad = math::integer_least_multiple(m0, loop_step) - m0;
|
||||
const auto desc_m0_pad =
|
||||
transform_tensor_descriptor(desc_m0,
|
||||
make_tuple(make_right_pad_transform(m0, pad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return desc_m0_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
|
||||
const std::vector<index_t>& stride,
|
||||
index_t gridSize,
|
||||
index_t blockSize)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
|
||||
// merge nd to 1d desc - [s0 * s1 * ...]
|
||||
if constexpr(Dim > 1)
|
||||
{
|
||||
const auto desc_m0 = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(tupleOfShape)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
|
||||
}
|
||||
else
|
||||
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
|
||||
}
|
||||
|
||||
using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
|
||||
using GridwiseUEltwise = GridwiseUnaryElementwise_1D<ADataType,
|
||||
BDataType,
|
||||
GridDesc_M0,
|
||||
ElementwiseFunctor,
|
||||
ScalarPerVector>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a,
|
||||
BDataType* p_b,
|
||||
const std::vector<index_t>& shape,
|
||||
const std::vector<index_t>& stride_a,
|
||||
const std::vector<index_t>& stride_b,
|
||||
ElementwiseFunctor functor)
|
||||
: p_a_(p_a),
|
||||
p_b_(p_b),
|
||||
shape_(shape),
|
||||
functor_(functor),
|
||||
blockSize_(256) // FIXME - Calculate the grid size by number of CU in the future
|
||||
{
|
||||
index_t tensor_size =
|
||||
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>{});
|
||||
gridSize_ = GridwiseUEltwise::CalculateGridSize(tensor_size);
|
||||
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
|
||||
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
|
||||
}
|
||||
|
||||
const ADataType* p_a_;
|
||||
BDataType* p_b_;
|
||||
std::vector<int> shape_;
|
||||
GridDesc_M0 a_grid_desc_m0_;
|
||||
GridDesc_M0 b_grid_desc_m0_;
|
||||
ElementwiseFunctor functor_;
|
||||
index_t blockSize_;
|
||||
index_t gridSize_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto kernel = kernel_unary_elementwise_1d<GridwiseUEltwise,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GridDesc_M0,
|
||||
ElementwiseFunctor>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(arg.blockSize_),
|
||||
0,
|
||||
arg.p_a_,
|
||||
arg.p_b_,
|
||||
arg.a_grid_desc_m0_,
|
||||
arg.b_grid_desc_m0_,
|
||||
arg.functor_);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(pArg == nullptr)
|
||||
return false;
|
||||
|
||||
if(pArg->shape_.back() % ScalarPerVector != 0)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
void* p_b,
|
||||
std::vector<index_t> shape,
|
||||
std::vector<index_t> stride_a,
|
||||
std::vector<index_t> stride_b,
|
||||
ElementwiseFunctor functor)
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<BDataType*>(p_b),
|
||||
shape,
|
||||
stride_a,
|
||||
stride_b,
|
||||
functor);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBinaryElementwise"
|
||||
<< "<"
|
||||
<< "ScalarPerVector = " << ScalarPerVector
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user