Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-16 02:54:21 +00:00)
Maxpool bwd (#750)
* Add maxpool f32 kernel and example
* Revise copyright
* Add pool bwd device op
* Support f16 and bf16
* Add compute data type to the reference code to prevent errors in bf16
* Fix type error
* Remove layout
* Fix bf16 error
* Add f16 and bf16 examples
* Add more operations
* Implement IsSupportedArgument
* Add changelog
* Add comment
* Add comment
* Remove unused header
* Move workspace initialization into Run
* Move setting din to zero into the device operator
* Save din_length_raw
* Remove unused header
* Calculate gridSize according to the number of CUs
* Calculate gridSize according to the number of CUs; remove unused header
* Add put example
* Remove unused header
* Fix CI failure
[ROCm/composable_kernel commit: 341ad95665]
include/ck/host_utility/stream_utility.hpp
@@ -8,7 +8,7 @@
 #include "ck/stream_config.hpp"
 #include "ck/host_utility/hip_check_error.hpp"

-static int getAvailableComputeUnitCount(const StreamConfig& stream_config)
+static inline int getAvailableComputeUnitCount(const StreamConfig& stream_config)
 {
     constexpr int MAX_MASK_DWORDS = 64;
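For orientation, a minimal sketch of querying a device's compute-unit count with plain HIP. This is not the body of getAvailableComputeUnitCount, which additionally inspects the stream's CU mask (hence MAX_MASK_DWORDS); the helper name is illustrative:

#include <hip/hip_runtime.h>

// Assumed-simplified sketch: total CU count of the current device. The real
// getAvailableComputeUnitCount additionally honors the stream's CU mask.
static inline int device_cu_count()
{
    int device = 0;
    if(hipGetDevice(&device) != hipSuccess)
        return 0;
    int cu = 0;
    if(hipDeviceGetAttribute(&cu, hipDeviceAttributeMultiprocessorCount, device) != hipSuccess)
        return 0;
    return cu;
}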
32 include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp Normal file
@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// For pooling operations that produce indices, such as MaxPool and MinPool.
template <typename DOutDataType, typename IndexDataType, typename DInDataType>
struct DeviceIndexPoolBwd : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_dout,
                        const void* p_indices,
                        void* p_din,
                        index_t dout_length,
                        index_t din_length,
                        std::vector<ck::index_t> window_lengths,
                        std::vector<ck::index_t> window_strides) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
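The contract behind this interface is a scatter: every dout element is routed to the din offset that the forward pass recorded in indices. A host-side reference in plain C++ (illustrative, not the CK device path; the function name is made up):

#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side reference for index-pool backward: din starts at zero, and each
// dout[i] accumulates into the flat din offset saved by the forward pass.
// Accumulation (+=) matters when pooling windows overlap, because several
// dout elements can then map to the same din offset.
template <typename T>
std::vector<T> index_pool_bwd_ref(const std::vector<T>& dout,
                                  const std::vector<int32_t>& indices,
                                  std::size_t din_length)
{
    std::vector<T> din(din_length, T{0});
    for(std::size_t i = 0; i < dout.size(); ++i)
        din[indices[i]] += dout[i];
    return din;
}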
36 include/ck/tensor_operation/gpu/device/device_put_element.hpp Normal file
@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/utility/reduction_enums.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// output[indices] = input
template <typename InDataType,
          typename IndexDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum Op>
struct DevicePutElement : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_input,
                        const void* p_indices,
                        void* p_output,
                        index_t input_length,
                        index_t output_length,
                        ElementwiseOperation elementwise_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
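In host terms, the Set flavor of this operator behaves as follows (plain C++ reference; the op is simplified to a value-returning functor, and the function name is illustrative):

#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side reference for DevicePutElement with Op = Set:
// output[indices[i]] = elementwise_op(input[i]). Negative indices are
// skipped, matching the gridwise kernel's treatment of padded lanes.
template <typename T, typename F>
void put_element_ref(const std::vector<T>& input,
                     const std::vector<int32_t>& indices,
                     std::vector<T>& output,
                     F elementwise_op)
{
    for(std::size_t i = 0; i < input.size(); ++i)
        if(indices[i] >= 0)
            output[indices[i]] = elementwise_op(input[i]);
}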
@@ -0,0 +1,316 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Backward of index pooling: din[indices] = dout when pooling windows do not
// overlap, din[indices] += dout (atomically) when they do.
template <typename DOutDataType,
          typename IndexDataType,
          typename DInDataType,
          ck::index_t InOutVectorSize>
struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDataType, DInDataType>
{
    // Hardware atomic add is used directly only for f32/f64; other types
    // (f16, bf16) accumulate into a float workspace and are cast back after.
    using DInDataType_AtomicAddPreCast =
        conditional_t<is_same_v<DInDataType, float> || is_same_v<DInDataType, double>,
                      DInDataType,
                      float>;

    using PassThrough  = ck::tensor_operation::element_wise::PassThrough;
    using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert;

    static constexpr auto I0 = Number<0>{};

    template <typename Desc_M>
    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t loop_step)
    {
        // Right-pad the 1D length to a multiple of loop_step so the kernel's
        // grid-stride loop needs no tail handling.
        const auto m   = desc_m.GetLength(I0);
        const auto pad = math::integer_least_multiple(m, loop_step) - m;
        const auto desc_m_pad =
            transform_tensor_descriptor(desc_m,
                                        make_tuple(make_right_pad_transform(m, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m_pad;
    }

    static auto MakeDescriptor_M(index_t length, index_t loop_step)
    {
        const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length));
        return PadDescriptor_M_1d(desc_m, loop_step);
    }

    using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1));

    using GridwisePutElementSet = GridwisePutElement_1D<InOutGrid1dDesc,
                                                        DOutDataType,
                                                        IndexDataType,
                                                        DInDataType,
                                                        PassThrough,
                                                        InMemoryDataOperationEnum::Set,
                                                        InOutVectorSize>;

    using GridwisePutElementAtomicAdd = GridwisePutElement_1D<InOutGrid1dDesc,
                                                              DOutDataType,
                                                              IndexDataType,
                                                              DInDataType_AtomicAddPreCast,
                                                              PassThrough,
                                                              InMemoryDataOperationEnum::AtomicAdd,
                                                              InOutVectorSize>;

    using GridwiseCasting = GridwiseElementwise_1D<Tuple<InOutGrid1dDesc>,
                                                   Tuple<InOutGrid1dDesc>,
                                                   Tuple<const DInDataType_AtomicAddPreCast*>,
                                                   Tuple<DInDataType*>,
                                                   UnaryConvert,
                                                   InOutVectorSize,
                                                   Sequence<InOutVectorSize>,
                                                   Sequence<InOutVectorSize>>;

    struct Argument : public BaseArgument
    {
        Argument(const DOutDataType* p_dout,
                 const IndexDataType* p_indices,
                 DInDataType* p_din,
                 index_t dout_length,
                 index_t din_length,
                 const std::vector<ck::index_t>& window_lengths,
                 const std::vector<ck::index_t>& window_strides)
            : p_dout_{p_dout},
              p_indices_{p_indices},
              p_din_{p_din},
              dout_length_raw_{dout_length},
              din_length_raw_{din_length},
              blockSize_{256},
              windowOverlap_{false}
        {
            // Windows overlap in some dimension iff the window length exceeds
            // the stride; overlap forces atomic accumulation in the backward.
            for(size_t i = 0; i < window_lengths.size(); ++i)
            {
                windowOverlap_ |= window_lengths.at(i) > window_strides.at(i);
            }
        }

        const DOutDataType* p_dout_;
        const IndexDataType* p_indices_;
        DInDataType* p_din_;
        index_t dout_length_raw_;
        index_t din_length_raw_;
        index_t blockSize_;
        bool windowOverlap_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            index_t gridSize  = getAvailableComputeUnitCount(stream_config);
            index_t loop_step = gridSize * arg.blockSize_ * InOutVectorSize;
            InOutGrid1dDesc din_grid_desc  = MakeDescriptor_M(arg.din_length_raw_, loop_step);
            InOutGrid1dDesc dout_grid_desc = MakeDescriptor_M(arg.dout_length_raw_, loop_step);

            if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>)
            {
                hip_check_error(hipMemsetAsync(arg.p_din_,
                                               0,
                                               arg.din_length_raw_ * sizeof(DInDataType),
                                               stream_config.stream_id_));

                if(arg.windowOverlap_)
                {
                    const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
                                                                  InOutGrid1dDesc,
                                                                  DOutDataType,
                                                                  IndexDataType,
                                                                  DInDataType,
                                                                  PassThrough>;

                    return launch_and_time_kernel(stream_config,
                                                  put_kernel,
                                                  dim3(gridSize),
                                                  dim3(arg.blockSize_),
                                                  0,
                                                  dout_grid_desc,
                                                  arg.p_dout_,
                                                  arg.p_indices_,
                                                  arg.p_din_,
                                                  PassThrough{});
                }
                else
                {
                    const auto put_kernel = kernel_put_element_1d<GridwisePutElementSet,
                                                                  InOutGrid1dDesc,
                                                                  DOutDataType,
                                                                  IndexDataType,
                                                                  DInDataType,
                                                                  PassThrough>;

                    return launch_and_time_kernel(stream_config,
                                                  put_kernel,
                                                  dim3(gridSize),
                                                  dim3(arg.blockSize_),
                                                  0,
                                                  dout_grid_desc,
                                                  arg.p_dout_,
                                                  arg.p_indices_,
                                                  arg.p_din_,
                                                  PassThrough{});
                }
            }
            else
            {
                if(arg.windowOverlap_)
                {
                    if(arg.p_workspace_ == nullptr)
                        throw std::runtime_error("wrong! WorkSpace pointer has not been set");

                    hip_check_error(
                        hipMemsetAsync(arg.p_workspace_,
                                       0,
                                       arg.din_length_raw_ * sizeof(DInDataType_AtomicAddPreCast),
                                       stream_config.stream_id_));

                    // Accumulate atomically into the float workspace, then
                    // cast the workspace back to DInDataType in a second pass.
                    const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
                                                                  InOutGrid1dDesc,
                                                                  DOutDataType,
                                                                  IndexDataType,
                                                                  DInDataType_AtomicAddPreCast,
                                                                  PassThrough>;

                    const auto cast_kernel =
                        kernel_elementwise_1d<GridwiseCasting,
                                              Tuple<InOutGrid1dDesc>,
                                              Tuple<InOutGrid1dDesc>,
                                              Tuple<const DInDataType_AtomicAddPreCast*>,
                                              Tuple<DInDataType*>,
                                              UnaryConvert>;

                    float elapsed_time = launch_and_time_kernel(
                        stream_config,
                        put_kernel,
                        dim3(gridSize),
                        dim3(arg.blockSize_),
                        0,
                        dout_grid_desc,
                        arg.p_dout_,
                        arg.p_indices_,
                        static_cast<DInDataType_AtomicAddPreCast*>(arg.p_workspace_),
                        PassThrough{});

                    elapsed_time += launch_and_time_kernel(
                        stream_config,
                        cast_kernel,
                        dim3(gridSize),
                        dim3(arg.blockSize_),
                        0,
                        ck::make_tuple(din_grid_desc),
                        ck::make_tuple(din_grid_desc),
                        static_cast<DInDataType_AtomicAddPreCast*>(arg.p_workspace_),
                        arg.p_din_,
                        UnaryConvert{});

                    return elapsed_time;
                }
                else
                {
                    const auto put_kernel = kernel_put_element_1d<GridwisePutElementSet,
                                                                  InOutGrid1dDesc,
                                                                  DOutDataType,
                                                                  IndexDataType,
                                                                  DInDataType,
                                                                  PassThrough>;

                    hip_check_error(hipMemsetAsync(arg.p_din_,
                                                   0,
                                                   arg.din_length_raw_ * sizeof(DInDataType),
                                                   stream_config.stream_id_));

                    return launch_and_time_kernel(stream_config,
                                                  put_kernel,
                                                  dim3(gridSize),
                                                  dim3(arg.blockSize_),
                                                  0,
                                                  dout_grid_desc,
                                                  arg.p_dout_,
                                                  arg.p_indices_,
                                                  arg.p_din_,
                                                  PassThrough{});
                }
            }
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
    {
        const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);

        // A cast workspace is needed only when atomic accumulation is
        // required (overlapping windows) and DInDataType lacks a hardware
        // atomic add.
        bool needCast = pArg_->windowOverlap_ &&
                        !(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>);

        if(!needCast)
            return 0;
        else
            return pArg_->din_length_raw_ * sizeof(DInDataType_AtomicAddPreCast);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
        // Vectorized access requires both flat lengths to be multiples of the
        // vector size.
        if(pArg->din_length_raw_ % InOutVectorSize != 0 ||
           pArg->dout_length_raw_ % InOutVectorSize != 0)
        {
            return false;
        }
        return true;
    }

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_dout,
                        const void* p_indices,
                        void* p_din,
                        index_t dout_length,
                        index_t din_length,
                        std::vector<ck::index_t> window_lengths,
                        std::vector<ck::index_t> window_strides) override
    {
        // Assume p_dout, p_indices, and p_din are packed memory; dout_length
        // and din_length are the physical sizes of the packed tensors.
        return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
                                          static_cast<const IndexDataType*>(p_indices),
                                          static_cast<DInDataType*>(p_din),
                                          dout_length,
                                          din_length,
                                          window_lengths,
                                          window_strides);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
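The invoker's launch arithmetic in isolation (standalone sketch; the CU count 120 and the lengths are made-up values, and the rounding expression mirrors math::integer_least_multiple):

#include <cstdio>

int main()
{
    const int cu_count   = 120; // assumed; really getAvailableComputeUnitCount(stream_config)
    const int block_size = 256; // blockSize_ set in Argument
    const int vec_size   = 4;   // InOutVectorSize template parameter

    // Elements the whole grid consumes per loop iteration.
    const int loop_step = cu_count * block_size * vec_size;

    // Lengths are right-padded up to the next multiple of loop_step, so the
    // kernel's iteration count padded / loop_step divides exactly.
    const int din_length = 1000000;
    const int padded     = (din_length + loop_step - 1) / loop_step * loop_step;

    std::printf("gridSize=%d loop_step=%d padded=%d iterations=%d\n",
                cu_count, loop_step, padded, padded / loop_step);
    return 0;
}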
@@ -0,0 +1,155 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_put_element.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// output[indices] = input
template <typename InDataType,
          typename IndexDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum MemOp,
          ck::index_t InVectorSize>
struct DevicePutElementImpl
    : public DevicePutElement<InDataType, IndexDataType, OutDataType, ElementwiseOperation, MemOp>
{
    template <typename Desc_M>
    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
    {
        constexpr auto I0 = Number<0>{};

        // Right-pad the length to a multiple of the per-iteration grid stride.
        const auto m            = desc_m.GetLength(I0);
        const index_t loop_step = gridSize * blockSize * InVectorSize;
        const auto pad          = math::integer_least_multiple(m, loop_step) - m;
        const auto desc_m_pad =
            transform_tensor_descriptor(desc_m,
                                        make_tuple(make_right_pad_transform(m, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m_pad;
    }

    static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize)
    {
        const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length));
        return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
    }

    using InGrid1dDesc = decltype(MakeDescriptor_M(1, 1, 1));

    using GridwisePutElement = GridwisePutElement_1D<InGrid1dDesc,
                                                     InDataType,
                                                     IndexDataType,
                                                     OutDataType,
                                                     ElementwiseOperation,
                                                     MemOp,
                                                     InVectorSize>;

    struct Argument : public BaseArgument
    {
        Argument(const InDataType* p_input,
                 const IndexDataType* p_indices,
                 OutDataType* p_output,
                 index_t input_length,
                 ElementwiseOperation elementwise_op)
            : p_input_{p_input},
              p_indices_{p_indices},
              p_output_{p_output},
              input_length_raw_{input_length},
              elementwise_op_{elementwise_op},
              blockSize_{256}
        {
        }

        const InDataType* p_input_;
        const IndexDataType* p_indices_;
        OutDataType* p_output_;
        index_t input_length_raw_;
        ElementwiseOperation elementwise_op_;
        index_t blockSize_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            index_t gridSize = getAvailableComputeUnitCount(stream_config);
            InGrid1dDesc in_grid_desc =
                MakeDescriptor_M(arg.input_length_raw_, gridSize, arg.blockSize_);

            const auto kernel = kernel_put_element_1d<GridwisePutElement,
                                                      InGrid1dDesc,
                                                      InDataType,
                                                      IndexDataType,
                                                      OutDataType,
                                                      ElementwiseOperation>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(gridSize),
                                                        dim3(arg.blockSize_),
                                                        0,
                                                        in_grid_desc,
                                                        arg.p_input_,
                                                        arg.p_indices_,
                                                        arg.p_output_,
                                                        arg.elementwise_op_);
            return elapsed_time;
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        // Vectorized access requires the input length to be a multiple of the
        // vector size.
        if(pArg->input_length_raw_ % InVectorSize != 0)
        {
            return false;
        }
        return true;
    }

    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_input,
                                                      const void* p_indices,
                                                      void* p_output,
                                                      index_t input_length,
                                                      index_t, // output_length is not needed here
                                                      ElementwiseOperation elementwise_op) override
    {
        return std::make_unique<Argument>(static_cast<const InDataType*>(p_input),
                                          static_cast<const IndexDataType*>(p_indices),
                                          static_cast<OutDataType*>(p_output),
                                          input_length,
                                          elementwise_op);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
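A hedged sketch of the intended call pattern for this device op (the include path of the impl header, the template arguments, and the StreamConfig constructor usage are assumptions drawn from CK conventions; all pointers must be device memory):

#include <cstdint>
#include <stdexcept>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using Pass = ck::tensor_operation::element_wise::PassThrough;
using DevicePut = ck::tensor_operation::device::DevicePutElementImpl<
    float,    // InDataType
    int32_t,  // IndexDataType
    float,    // OutDataType
    Pass,     // ElementwiseOperation
    ck::InMemoryDataOperationEnum::Set,
    1>;       // InVectorSize

// Scatter in_len device floats through a device index buffer into p_out.
float run_put(const float* p_in, const int32_t* p_idx, float* p_out,
              ck::index_t in_len, ck::index_t out_len)
{
    auto op       = DevicePut{};
    auto argument = op.MakeArgumentPointer(p_in, p_idx, p_out, in_len, out_len, Pass{});
    auto invoker  = op.MakeInvokerPointer();

    if(!op.IsSupportedArgument(argument.get()))
        throw std::runtime_error("DevicePutElementImpl: unsupported argument");

    // StreamConfig{nullptr, true}: default stream, time the kernel.
    return invoker->Run(argument.get(), StreamConfig{nullptr, true});
}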
155 include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp Normal file
@@ -0,0 +1,155 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwisePutElementwise1dFunctor,
          typename InGrid1dDesc,
          typename InDataType,
          typename IndexDataType,
          typename OutDataType,
          typename ElementwiseOperation>
__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc,
                                      const InDataType* __restrict__ p_in_global,
                                      const IndexDataType* __restrict__ p_indices_global,
                                      OutDataType* __restrict__ p_out_global,
                                      const ElementwiseOperation elementwise_op)
{
    GridwisePutElementwise1dFunctor::Run(
        in_grid_1d_desc, p_in_global, p_indices_global, p_out_global, elementwise_op);
}

// output[indices] = input
template <typename InGrid1dDesc,
          typename InDataType,
          typename IndexDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum MemOp,
          index_t InVectorSize>
struct GridwisePutElement_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<InVectorSize>{}));

    __device__ static void Run(const InGrid1dDesc& in_grid_1d_desc,
                               const InDataType* __restrict__ p_in_global,
                               const IndexDataType* __restrict__ p_indices_global,
                               OutDataType* __restrict__ p_out_global,
                               const ElementwiseOperation& elementwise_op)
    {
        // Global memory; reads from the padded tail of the index buffer
        // return NumericLimits<IndexDataType>::Lowest(), which the scatter
        // below skips because it is negative.
        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global, in_grid_1d_desc.GetElementSpaceSize());

        const auto indices_global_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global>(p_indices_global,
                                                          in_grid_1d_desc.GetElementSpaceSize(),
                                                          NumericLimits<IndexDataType>::Lowest());

        // VGPR
        StaticBuffer<AddressSpaceEnum::Vgpr, InDataType, InVectorSize, true> in_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, InVectorSize, true> indices_thread_buf;

        // Thread id, block id, and index
        const index_t thread_global_id = get_thread_global_1d_id();
        const auto thread_global_offset = make_multi_index(thread_global_id * InVectorSize);
        const index_t blockSize = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto M = in_grid_1d_desc.GetLength(I0);
        const index_t loop_step = blockPerGrid * blockSize * InVectorSize;
        const auto loop_step_index = make_multi_index(loop_step);

        auto in_global_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             InDataType,
                                             decltype(in_grid_1d_desc),
                                             decltype(thread_buffer_desc_m),
                                             Sequence<InVectorSize>, // SliceLengths
                                             Sequence<0>,            // DimAccessOrder
                                             0,                      // SrcVectorDim
                                             InVectorSize,           // ScalarPerVector
                                             1,                      // SrcScalarStrideInVector
                                             false>{in_grid_1d_desc, thread_global_offset};

        auto indices_global_load =
            ThreadwiseTensorSliceTransfer_v2<IndexDataType,
                                             IndexDataType,
                                             decltype(in_grid_1d_desc),
                                             decltype(thread_buffer_desc_m),
                                             Sequence<InVectorSize>, // SliceLengths
                                             Sequence<0>,            // DimAccessOrder
                                             0,                      // SrcVectorDim
                                             InVectorSize,           // ScalarPerVector
                                             1,                      // SrcScalarStrideInVector
                                             false>{in_grid_1d_desc, thread_global_offset};

        // M is padded to a multiple of loop_step by the device op, so the
        // division is exact and the do/while runs at least once.
        index_t num_iter = M / loop_step;
        do
        {
            in_global_load.Run(in_grid_1d_desc,
                               in_global_buf,
                               thread_buffer_desc_m,
                               make_tuple(I0),
                               in_thread_buf);

            in_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index);

            // Apply the elementwise op in place on the loaded vector.
            static_for<0, InVectorSize, 1>{}(
                [&](auto iM) { elementwise_op(in_thread_buf(iM), in_thread_buf[iM]); });

            indices_global_load.Run(in_grid_1d_desc,
                                    indices_global_buf,
                                    thread_buffer_desc_m,
                                    make_tuple(I0),
                                    indices_thread_buf);

            indices_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index);

            static_for<0, InVectorSize, 1>{}([&](auto iM) {
                // Negative indices mark padded or invalid lanes; skip them.
                if(indices_thread_buf[iM] >= 0)
                {
                    if constexpr(MemOp == InMemoryDataOperationEnum::Set)
                    {
                        // User should guarantee each index in p_indices_global is different
                        *(p_out_global + indices_thread_buf[iM]) =
                            ck::type_convert<OutDataType>(in_thread_buf[iM]);
                    }
                    else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicAdd)
                    {
                        atomic_add<OutDataType>(p_out_global + indices_thread_buf[iM],
                                                ck::type_convert<OutDataType>(in_thread_buf[iM]));
                    }
                    else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicMax)
                    {
                        atomic_max<OutDataType>(p_out_global + indices_thread_buf[iM],
                                                ck::type_convert<OutDataType>(in_thread_buf[iM]));
                    }
                    else if constexpr(MemOp == InMemoryDataOperationEnum::Add)
                    {
                        // User should guarantee each index in p_indices_global is different
                        *(p_out_global + indices_thread_buf[iM]) +=
                            ck::type_convert<OutDataType>(in_thread_buf[iM]);
                    }
                    else
                    {
                        static_assert(MemOp == InMemoryDataOperationEnum::Set ||
                                          MemOp == InMemoryDataOperationEnum::AtomicAdd ||
                                          MemOp == InMemoryDataOperationEnum::AtomicMax ||
                                          MemOp == InMemoryDataOperationEnum::Add,
                                      "unsupported InMemoryDataOperationEnum");
                    }
                }
            });

        } while(--num_iter);
    }
};

} // namespace ck
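Run above is a grid-stride scatter expressed in CK's descriptor vocabulary. For orientation, the same traversal in plain HIP, with vectorization and descriptor padding stripped away (illustrative only; the kernel name is made up):

#include <hip/hip_runtime.h>
#include <cstdint>

// Plain-HIP analogue of GridwisePutElement_1D<..., Set, 1>. The CK version
// additionally vectorizes loads (InVectorSize) and pads m so that the loop
// count m / loop_step divides exactly.
__global__ void put_element_1d_naive(const float* __restrict__ in,
                                     const int32_t* __restrict__ idx,
                                     float* __restrict__ out,
                                     int m)
{
    const int stride = gridDim.x * blockDim.x;
    for(int i = blockIdx.x * blockDim.x + threadIdx.x; i < m; i += stride)
    {
        if(idx[i] >= 0) // negative index marks a lane with no destination
            out[idx[i]] = in[i];
    }
}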