mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Hotfix eltiwseop (#242)
* Use vector constructor instead * Fix typo * Move blockSize to the MakeArgumentPointer * Fix naming * Fix clang format * remove blockSize from DeviceBinaryElementwise::Argument() Co-authored-by: rocking <chunylai@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -19,8 +19,6 @@ template <typename ADataType,
|
||||
index_t ScalarPerVector>
|
||||
struct DeviceBinaryElementwise : public BaseOperator
|
||||
{
|
||||
DeviceBinaryElementwise(index_t blockSize = 256) : BaseOperator(), blockSize_(blockSize) {}
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
template <typename Desc_M0>
|
||||
@@ -81,18 +79,18 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
const std::vector<index_t>& stride_a,
|
||||
const std::vector<index_t>& stride_b,
|
||||
const std::vector<index_t>& stride_c,
|
||||
ElementwiseFunctor functor,
|
||||
index_t blockSize)
|
||||
ElementwiseFunctor functor)
|
||||
: p_a_(p_a),
|
||||
p_b_(p_b),
|
||||
p_c_(p_c),
|
||||
shape_(shape),
|
||||
functor_(functor),
|
||||
blockSize_(256),
|
||||
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
|
||||
{
|
||||
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize);
|
||||
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize);
|
||||
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize);
|
||||
a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
|
||||
b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
|
||||
c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, blockSize_);
|
||||
}
|
||||
|
||||
const ADataType* p_a_;
|
||||
@@ -103,13 +101,12 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
GridDesc_M0 b_grid_desc_m0_;
|
||||
GridDesc_M0 c_grid_desc_m0_;
|
||||
ElementwiseFunctor functor_;
|
||||
index_t blockSize_;
|
||||
index_t gridSize_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
Invoker(index_t blockSize) : BaseInvoker(), blockSize_(blockSize) {}
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto kernel = kernel_binary_elementwise_1d<GridwiseBinEltwise,
|
||||
@@ -122,7 +119,7 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(blockSize_),
|
||||
dim3(arg.blockSize_),
|
||||
0,
|
||||
arg.p_a_,
|
||||
arg.p_b_,
|
||||
@@ -140,8 +137,6 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
|
||||
index_t blockSize_;
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
@@ -173,14 +168,10 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
functor,
|
||||
blockSize_);
|
||||
functor);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{blockSize_});
|
||||
}
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
@@ -195,8 +186,6 @@ struct DeviceBinaryElementwise : public BaseOperator
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
index_t blockSize_;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
Reference in New Issue
Block a user