Hotfix eltiwseop (#242)

* Use vector constructor instead

* Fix typo

* Move blockSize to the MakeArgumentPointer

* Fix naming

* Fix clang format

* remove blockSize from DeviceBinaryElementwise::Argument()

Co-authored-by: rocking <chunylai@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
rocking5566
2022-05-20 11:02:06 +08:00
committed by GitHub
parent 0ffe956ab1
commit bb4b82a95a
5 changed files with 32 additions and 63 deletions

View File

@@ -19,8 +19,6 @@ template <typename ADataType,
index_t ScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator
{
DeviceBinaryElementwise(index_t blockSize = 256) : BaseOperator(), blockSize_(blockSize) {}
static constexpr auto I0 = Number<0>{};
template <typename Desc_M0>
@@ -81,18 +79,18 @@ struct DeviceBinaryElementwise : public BaseOperator
const std::vector<index_t>& stride_a,
const std::vector<index_t>& stride_b,
const std::vector<index_t>& stride_c,
ElementwiseFunctor functor,
index_t blockSize)
ElementwiseFunctor functor)
: p_a_(p_a),
p_b_(p_b),
p_c_(p_c),
shape_(shape),
functor_(functor),
blockSize_(256),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize);
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize);
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize_);
}
const ADataType* p_a_;
@@ -103,13 +101,12 @@ struct DeviceBinaryElementwise : public BaseOperator
GridDesc_M0 b_grid_desc_m0_;
GridDesc_M0 c_grid_desc_m0_;
ElementwiseFunctor functor_;
index_t blockSize_;
index_t gridSize_;
};
struct Invoker : public BaseInvoker
{
Invoker(index_t blockSize) : BaseInvoker(), blockSize_(blockSize) {}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = kernel_binary_elementwise_1d<GridwiseBinEltwise,
@@ -122,7 +119,7 @@ struct DeviceBinaryElementwise : public BaseOperator
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(blockSize_),
dim3(arg.blockSize_),
0,
arg.p_a_,
arg.p_b_,
@@ -140,8 +137,6 @@ struct DeviceBinaryElementwise : public BaseOperator
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
index_t blockSize_;
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
@@ -173,14 +168,10 @@ struct DeviceBinaryElementwise : public BaseOperator
stride_a,
stride_b,
stride_c,
functor,
blockSize_);
functor);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{blockSize_});
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
std::string GetTypeString() const override
{
@@ -195,8 +186,6 @@ struct DeviceBinaryElementwise : public BaseOperator
return str.str();
}
index_t blockSize_;
};
} // namespace device