Batchnorm-forward and Batchnorm-infer Implemented using generic kernels (#320)
* Implement multiple reductions in one kernel (kernels, device ops, examples)
* Add a generic elementwise kernel and device interface
* Add a generator for normally-distributed data initialization
* Add host reference implementations of batchnorm-forward and batchnorm-infer
* Add examples implementing batchnorm-forward and batchnorm-infer using the generic kernels
* Remove unneeded includes in the batchnorm example
* Rename generic_elementwise to elementwise in kernel and device classes/functions
* Change the gemm_layernorm examples to use DeviceElementwise instead of Device5AryElementwise
* Change example 19_binary_elementwise to use DeviceElementwise instead of DeviceBinaryElementwise
* Change device_cgemm_4gemm_xdl_cshuffle.hpp to use kernel_elementwise instead of kernel_binary_elementwise
* Add DeviceElementwiseBase and use it in device_normalize_instance.cpp
* Remove and rename files
* Update the gemm_layernorm client example to match the generic elementwise device op API
* Update to the latest headers directory and the renamed HostTensorDescriptor interface
* Merge two static member functions in device_elementwise.hpp
* Remove the unary_elementwise_1d kernel and device
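The change set is easiest to follow through the new generic elementwise device op. Below is a minimal host-side sketch of driving it, written against the interface added in this commit; the tensor extents, vector widths, and the p_*_dev buffers are illustrative placeholders, and the usual CK support headers (StreamConfig, the element-wise operator definitions) are assumed to be available:

#include <array>
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"

using Add = ck::tensor_operation::element_wise::Add;

// Two fp32 inputs, one fp32 output, 2-D data, 8 elements per thread,
// vector width 8 on every tensor (all configuration values illustrative).
using DeviceAddInstance =
    ck::tensor_operation::device::DeviceElementwise<ck::Tuple<float, float>,
                                                    ck::Tuple<float>,
                                                    Add,
                                                    2,                  // NumDim
                                                    8,                  // MPerThread
                                                    ck::Sequence<8, 8>, // InScalarPerVectorSeq
                                                    ck::Sequence<8>>;   // OutScalarPerVectorSeq

void run_add(const void* p_a_dev, const void* p_b_dev, void* p_c_dev)
{
    // Packed row-major 1024 x 1024 tensors; last-dim stride 1 enables vector access.
    std::array<ck::index_t, 2> lengths{1024, 1024};
    std::array<std::array<ck::index_t, 2>, 2> in_strides{{{1024, 1}, {1024, 1}}};
    std::array<std::array<ck::index_t, 2>, 1> out_strides{{{1024, 1}}};

    DeviceAddInstance op;
    auto arg = op.MakeArgumentPointer(
        lengths, in_strides, out_strides, {p_a_dev, p_b_dev}, {p_c_dev}, Add{});

    if(op.IsSupportedArgument(arg.get()))
        op.MakeInvokerPointer()->Run(arg.get(), StreamConfig{});
}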
@@ -1,353 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <vector>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename DDataType,
          typename EDataType,
          typename FDataType,
          typename ComputeDataType,
          typename ElementwiseFunctor,
          index_t NDim,
          index_t MPerThread,
          index_t AScalarPerVector,
          index_t BScalarPerVector,
          index_t CScalarPerVector,
          index_t DScalarPerVector,
          index_t EScalarPerVector,
          index_t FScalarPerVector>
struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor>
{
    static constexpr auto I0 = Number<0>{};

    template <typename Desc_M>
    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
    {
        const auto m            = desc_m.GetLength(I0);
        const index_t loop_step = gridSize * blockSize * MPerThread;
        const auto pad          = math::integer_least_multiple(m, loop_step) - m;
        const auto desc_m_pad =
            transform_tensor_descriptor(desc_m,
                                        make_tuple(make_right_pad_transform(m, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m_pad;
    }

    static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
                                 const std::vector<index_t>& stride,
                                 index_t gridSize,
                                 index_t blockSize)
    {
        auto tupleOfShape  = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NDim>{});

        // nd desc - [s0, s1, s2, ...]
        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);

        // merge nd to 1d desc - [s0 * s1 * ...]
        if constexpr(NDim > 1)
        {
            const auto desc_m = transform_tensor_descriptor(
                desc,
                make_tuple(make_merge_transform(tupleOfShape)),
                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
                make_tuple(Sequence<0>{}));

            return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
        }
        else
            return PadDescriptor_M_1d(desc, gridSize, blockSize);
    }

    using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
    using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
    using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
    using DGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
    using EGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
    using FGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));

    using Gridwise5AryEltwise = Gridwise5AryElementwise_1D<ADataType,
                                                           BDataType,
                                                           CDataType,
                                                           DDataType,
                                                           EDataType,
                                                           FDataType,
                                                           ComputeDataType,
                                                           AGridDesc_M,
                                                           BGridDesc_M,
                                                           CGridDesc_M,
                                                           DGridDesc_M,
                                                           EGridDesc_M,
                                                           FGridDesc_M,
                                                           ElementwiseFunctor,
                                                           MPerThread,
                                                           AScalarPerVector,
                                                           BScalarPerVector,
                                                           CScalarPerVector,
                                                           DScalarPerVector,
                                                           EScalarPerVector,
                                                           FScalarPerVector>;

    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a,
                 const BDataType* p_b,
                 const CDataType* p_c,
                 const DDataType* p_d,
                 const EDataType* p_e,
                 FDataType* p_f,
                 const std::vector<index_t>& lengths,
                 const std::vector<index_t>& a_strides,
                 const std::vector<index_t>& b_strides,
                 const std::vector<index_t>& c_strides,
                 const std::vector<index_t>& d_strides,
                 const std::vector<index_t>& e_strides,
                 const std::vector<index_t>& f_strides,
                 ElementwiseFunctor functor)
            : p_a_(p_a),
              p_b_(p_b),
              p_c_(p_c),
              p_d_(p_d),
              p_e_(p_e),
              p_f_(p_f),
              lengths_(lengths),
              a_strides_(a_strides),
              b_strides_(b_strides),
              c_strides_(c_strides),
              d_strides_(d_strides),
              e_strides_(e_strides),
              f_strides_(f_strides),
              functor_(functor),
              blockSize_(256),
              gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
        {
            a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
            b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
            c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
            d_grid_desc_m_ = MakeDescriptor_M(lengths, d_strides, gridSize_, blockSize_);
            e_grid_desc_m_ = MakeDescriptor_M(lengths, e_strides, gridSize_, blockSize_);
            f_grid_desc_m_ = MakeDescriptor_M(lengths, f_strides, gridSize_, blockSize_);
        }

        const ADataType* p_a_;
        const BDataType* p_b_;
        const CDataType* p_c_;
        const DDataType* p_d_;
        const EDataType* p_e_;
        FDataType* p_f_;
        std::vector<index_t> lengths_;
        AGridDesc_M a_grid_desc_m_;
        BGridDesc_M b_grid_desc_m_;
        CGridDesc_M c_grid_desc_m_;
        DGridDesc_M d_grid_desc_m_;
        EGridDesc_M e_grid_desc_m_;
        FGridDesc_M f_grid_desc_m_;
        std::vector<index_t> a_strides_;
        std::vector<index_t> b_strides_;
        std::vector<index_t> c_strides_;
        std::vector<index_t> d_strides_;
        std::vector<index_t> e_strides_;
        std::vector<index_t> f_strides_;
        ElementwiseFunctor functor_;
        index_t blockSize_;
        index_t gridSize_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto kernel = kernel_5ary_elementwise_1d<Gridwise5AryEltwise,
                                                           ADataType,
                                                           BDataType,
                                                           CDataType,
                                                           DDataType,
                                                           EDataType,
                                                           FDataType,
                                                           AGridDesc_M,
                                                           BGridDesc_M,
                                                           CGridDesc_M,
                                                           DGridDesc_M,
                                                           EGridDesc_M,
                                                           FGridDesc_M,
                                                           ElementwiseFunctor>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(arg.gridSize_),
                                                        dim3(arg.blockSize_),
                                                        0,
                                                        arg.p_a_,
                                                        arg.p_b_,
                                                        arg.p_c_,
                                                        arg.p_d_,
                                                        arg.p_e_,
                                                        arg.p_f_,
                                                        arg.a_grid_desc_m_,
                                                        arg.b_grid_desc_m_,
                                                        arg.c_grid_desc_m_,
                                                        arg.d_grid_desc_m_,
                                                        arg.e_grid_desc_m_,
                                                        arg.f_grid_desc_m_,
                                                        arg.functor_);
            return elapsed_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    bool IsSupportedArgument(const BaseArgument& p_arg) { return IsSupportedArgument(&p_arg); }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if(pArg == nullptr)
            return false;

        if(pArg->lengths_.size() != NDim)
            return false;

        if(pArg->lengths_.back() % MPerThread != 0)
            return false;

        auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
            bool ret = true;

            if(!isLastDimensionCoalesced)
                ret = scalarPerVector == 1;
            else
                ret = MPerThread % scalarPerVector == 0;

            return ret;
        };

        if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
            return false;

        if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
            return false;

        if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
            return false;

        if(!IsScalarPerVectorValid(pArg->d_strides_.back() == 1, DScalarPerVector))
            return false;

        if(!IsScalarPerVectorValid(pArg->e_strides_.back() == 1, EScalarPerVector))
            return false;

        if(!IsScalarPerVectorValid(pArg->f_strides_.back() == 1, FScalarPerVector))
            return false;

        return true;
    };

    static auto MakeArgument(std::array<const void*, 5> p_inputs,
                             std::array<void*, 1> p_outputs,
                             std::vector<index_t> lengths,
                             std::vector<index_t> a_strides,
                             std::vector<index_t> b_strides,
                             std::vector<index_t> c_strides,
                             std::vector<index_t> d_strides,
                             std::vector<index_t> e_strides,
                             std::vector<index_t> f_strides,
                             ElementwiseFunctor functor)
    {
        return Argument{static_cast<const ADataType*>(p_inputs[0]),
                        static_cast<const BDataType*>(p_inputs[1]),
                        static_cast<const CDataType*>(p_inputs[2]),
                        static_cast<const DDataType*>(p_inputs[3]),
                        static_cast<const EDataType*>(p_inputs[4]),
                        static_cast<FDataType*>(p_outputs[0]),
                        lengths,
                        a_strides,
                        b_strides,
                        c_strides,
                        d_strides,
                        e_strides,
                        f_strides,
                        functor};
    }

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::array<const void*, 5> p_inputs,
                        std::array<void*, 1> p_outputs,
                        std::vector<index_t> lengths,
                        std::vector<std::vector<index_t>> input_strides,
                        std::vector<std::vector<index_t>> output_strides,
                        ElementwiseFunctor functor) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
                                          static_cast<const BDataType*>(p_inputs[1]),
                                          static_cast<const CDataType*>(p_inputs[2]),
                                          static_cast<const DDataType*>(p_inputs[3]),
                                          static_cast<const EDataType*>(p_inputs[4]),
                                          static_cast<FDataType*>(p_outputs[0]),
                                          lengths,
                                          input_strides[0],
                                          input_strides[1],
                                          input_strides[2],
                                          input_strides[3],
                                          input_strides[4],
                                          output_strides[0],
                                          functor);
    }

    static auto MakeInvoker() { return Invoker{}; }
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "Device5aryElementwise"
            << "<"
            << "NDim = " << NDim
            << "MPerThread = " << MPerThread
            << "AScalarPerVector = " << AScalarPerVector
            << "BScalarPerVector = " << BScalarPerVector
            << "CScalarPerVector = " << CScalarPerVector
            << "DScalarPerVector = " << DScalarPerVector
            << "EScalarPerVector = " << EScalarPerVector
            << "FScalarPerVector = " << FScalarPerVector
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
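Aside: the MakeDescriptor_M / PadDescriptor_M_1d pair above, which recurs in the device ops below, merges the n-d view into one dimension and right-pads it to the grid-stride loop. A worked instance of the arithmetic, with illustrative numbers (not from the source):

// lengths = {2, 3, 4}, so the merged 1-D length is m = 2 * 3 * 4 = 24
// gridSize = 120, blockSize = 256, MPerThread = 8
//   loop_step = 120 * 256 * 8 = 245760
//   pad       = integer_least_multiple(24, 245760) - 24 = 245736
// Every thread then executes the same number of grid-stride iterations,
// and the right-pad transform masks accesses beyond the real extent.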
@@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <index_t Rank, index_t NumBatchNormReduceDim>
struct DeviceBatchNormFwd : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, Rank> xyLengths,
        const std::array<index_t, Rank> xStrides,
        const std::array<index_t, Rank> yStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
        const void* p_x,
        const void* bnScale,
        const void* bnBias,
        void* p_y,
        double exponentialAverageFactor,
        void* resultRunningMean,
        void* resultRunningVariance,
        double epsilon,
        void* resultSaveMean,
        void* resultSaveInvVariance) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
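For orientation, a hedged sketch of driving this interface for an NHWC tensor (Rank = 4, reducing N/H/W so statistics are per channel). Here bn_fwd_op stands for a concrete DeviceBatchNormFwd<4, 3> implementation obtained elsewhere (e.g. from one of the new examples), and the p_* device pointers are placeholders:

// Hypothetical call sequence; all values are illustrative only.
constexpr ck::index_t N = 128, H = 16, W = 16, C = 512;
std::array<ck::index_t, 4> xyLengths{N, H, W, C};
std::array<ck::index_t, 4> xStrides{H * W * C, W * C, C, 1}; // packed NHWC
std::array<ck::index_t, 1> scaleBiasMeanVarLengths{C};
std::array<ck::index_t, 1> scaleBiasMeanVarStrides{1};

auto arg_ptr = bn_fwd_op->MakeArgumentPointer(xyLengths,
                                              xStrides,
                                              xStrides, // y laid out like x
                                              scaleBiasMeanVarLengths,
                                              scaleBiasMeanVarStrides,
                                              p_x, p_scale, p_bias, p_y,
                                              0.1,  // exponentialAverageFactor
                                              p_running_mean,
                                              p_running_var,
                                              1e-5, // epsilon
                                              p_save_mean,
                                              p_save_inv_var);
bn_fwd_op->MakeInvokerPointer()->Run(arg_ptr.get());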
@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <index_t Rank, index_t NumBatchNormReduceDim>
struct DeviceBatchNormInfer : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, Rank> xyLengths,
        const std::array<index_t, Rank> xStrides,
        const std::array<index_t, Rank> yStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
        const void* p_x,
        const void* bnScale,
        const void* bnBias,
        double epsilon,
        const void* estimatedMean,
        const void* estimatedInvVariance,
        void* p_y) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<Rank, NumBatchNormReduceDim>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
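A note on semantics: at inference time the op consumes precomputed statistics, so per element it reduces to one fused multiply-add chain. A reference formula (sketch; i indexes the full tensor, c its non-reduced dimensions):

// y[i] = bnScale[c] * (x[i] - estimatedMean[c]) * estimatedInvVariance[c] + bnBias[c]
//
// estimatedInvVariance is the already-inverted normalization term (the same
// quantity the forward op emits as resultSaveInvVariance, conventionally
// 1 / sqrt(variance + epsilon)), so no division happens on the device at
// inference time; epsilon only enters when the estimate is formed.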
@@ -1,247 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ComputeDataType,
          typename ElementwiseFunctor,
          index_t NDim,
          index_t MPerThread,
          index_t AScalarPerVector,
          index_t BScalarPerVector,
          index_t CScalarPerVector>
struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
{
    static constexpr auto I0 = Number<0>{};

    template <typename Desc_M>
    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
    {
        const auto M            = desc_m.GetLength(I0);
        const index_t loop_step = gridSize * blockSize * MPerThread;
        const auto pad          = math::integer_least_multiple(M, loop_step) - M;
        const auto desc_m_pad =
            transform_tensor_descriptor(desc_m,
                                        make_tuple(make_right_pad_transform(M, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m_pad;
    }

    static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
                                 const std::vector<index_t>& strides,
                                 index_t gridSize,
                                 index_t blockSize)
    {
        auto tupleOfShape  = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
        auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<NDim>{});

        // nd desc - [s0, s1, s2, ...]
        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);

        // merge nd to 1d desc - [s0 * s1 * ...]
        if constexpr(NDim > 1)
        {
            const auto desc_m = transform_tensor_descriptor(
                desc,
                make_tuple(make_merge_transform(tupleOfShape)),
                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
                make_tuple(Sequence<0>{}));

            return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
        }
        else
            return PadDescriptor_M_1d(desc, gridSize, blockSize);
    }

    using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
    using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
    using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
    using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
                                                            BDataType,
                                                            CDataType,
                                                            ComputeDataType,
                                                            AGridDesc_M,
                                                            BGridDesc_M,
                                                            CGridDesc_M,
                                                            ElementwiseFunctor,
                                                            MPerThread,
                                                            AScalarPerVector,
                                                            BScalarPerVector,
                                                            CScalarPerVector>;

    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a,
                 const BDataType* p_b,
                 CDataType* p_c,
                 const std::vector<index_t>& lengths,
                 const std::vector<index_t>& a_strides,
                 const std::vector<index_t>& b_strides,
                 const std::vector<index_t>& c_strides,
                 ElementwiseFunctor functor)
            : p_a_(p_a),
              p_b_(p_b),
              p_c_(p_c),
              lengths_(lengths),
              a_strides_(a_strides),
              b_strides_(b_strides),
              c_strides_(c_strides),
              functor_(functor),
              blockSize_(256),
              gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
        {
            a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
            b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
            c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
        }

        const ADataType* p_a_;
        const BDataType* p_b_;
        CDataType* p_c_;
        std::vector<index_t> lengths_;
        AGridDesc_M a_grid_desc_m_;
        BGridDesc_M b_grid_desc_m_;
        CGridDesc_M c_grid_desc_m_;
        std::vector<index_t> a_strides_;
        std::vector<index_t> b_strides_;
        std::vector<index_t> c_strides_;
        ElementwiseFunctor functor_;
        index_t blockSize_;
        index_t gridSize_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto kernel = kernel_binary_elementwise_1d<GridwiseBinEltwise,
                                                             ADataType,
                                                             BDataType,
                                                             CDataType,
                                                             AGridDesc_M,
                                                             BGridDesc_M,
                                                             CGridDesc_M,
                                                             ElementwiseFunctor>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(arg.gridSize_),
                                                        dim3(arg.blockSize_),
                                                        0,
                                                        arg.p_a_,
                                                        arg.p_b_,
                                                        arg.p_c_,
                                                        arg.a_grid_desc_m_,
                                                        arg.b_grid_desc_m_,
                                                        arg.c_grid_desc_m_,
                                                        arg.functor_);
            return elapsed_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if(pArg == nullptr)
            return false;

        if(pArg->lengths_.size() != NDim)
            return false;

        if(pArg->lengths_.back() % MPerThread != 0)
            return false;

        auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
            bool ret = true;

            if(!isLastDimensionCoalesced)
                ret = scalarPerVector == 1;
            else
                ret = MPerThread % scalarPerVector == 0;

            return ret;
        };

        if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
            return false;

        if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
            return false;

        if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
            return false;

        return true;
    };

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::array<const void*, 2> p_inputs,
                        std::array<void*, 1> p_outputs,
                        std::vector<index_t> lengths,
                        std::vector<std::vector<index_t>> input_strides,
                        std::vector<std::vector<index_t>> output_strides,
                        ElementwiseFunctor functor) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
                                          static_cast<const BDataType*>(p_inputs[1]),
                                          static_cast<CDataType*>(p_outputs[0]),
                                          lengths,
                                          input_strides[0],
                                          input_strides[1],
                                          output_strides[0],
                                          functor);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBinaryElementwise"
            << "<"
            << "NDim = " << NDim
            << "MPerThread = " << MPerThread
            << "AScalarPerVector = " << AScalarPerVector
            << "BScalarPerVector = " << BScalarPerVector
            << "CScalarPerVector = " << CScalarPerVector
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
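Per the commit message, example 19 migrates off this class. The mapping between the deleted template and its generic replacement is mechanical; a sketch based on the two parameter lists in this commit (note the new interface has no counterpart for ComputeDataType):

// Old (deleted above):
using OldOp = DeviceBinaryElementwise<ADataType, BDataType, CDataType,
                                      ComputeDataType, Functor, NDim,
                                      MPerThread, AScalarPerVector,
                                      BScalarPerVector, CScalarPerVector>;
// New (generic replacement, inputs/outputs grouped into tuples):
using NewOp = DeviceElementwise<ck::Tuple<ADataType, BDataType>,
                                ck::Tuple<CDataType>, Functor, NDim,
                                MPerThread,
                                ck::Sequence<AScalarPerVector, BScalarPerVector>,
                                ck::Sequence<CScalarPerVector>>;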
@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_cgemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -538,48 +538,43 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle

        float ave_time = 0;

        using Add      = ck::tensor_operation::element_wise::Add;
        using Subtract = ck::tensor_operation::element_wise::Subtract;
        using GridwiseBinAdd      = GridwiseBinaryElementwise_1D<CDataType,
                                                                 CDataType,
                                                                 CDataType,
                                                                 CDataType,
                                                                 CGridDesc_M,
                                                                 CGridDesc_M,
                                                                 CGridDesc_M,
                                                                 Add,
                                                                 MPerThread,
                                                                 AScalarPerVector,
                                                                 BScalarPerVector,
                                                                 CScalarPerVector>;
        using GridwiseBinSubtract = GridwiseBinaryElementwise_1D<CDataType,
                                                                 CDataType,
                                                                 CDataType,
                                                                 CDataType,
                                                                 CGridDesc_M,
                                                                 CGridDesc_M,
                                                                 CGridDesc_M,
                                                                 Subtract,
                                                                 MPerThread,
                                                                 AScalarPerVector,
                                                                 BScalarPerVector,
                                                                 CScalarPerVector>;
        const auto add_kernel      = kernel_binary_elementwise_1d<GridwiseBinAdd,
                                                                  CDataType,
                                                                  CDataType,
                                                                  CDataType,
                                                                  CGridDesc_M,
                                                                  CGridDesc_M,
                                                                  CGridDesc_M,
                                                                  Add>;
        const auto subtract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubtract,
                                                                  CDataType,
                                                                  CDataType,
                                                                  CDataType,
                                                                  CGridDesc_M,
                                                                  CGridDesc_M,
                                                                  CGridDesc_M,
                                                                  Subtract>;
        using Add      = ck::tensor_operation::element_wise::Add;
        using Subtract = ck::tensor_operation::element_wise::Subtract;

        using GridwiseBinAdd =
            GridwiseElementwise_1D<Tuple<CGridDesc_M, CGridDesc_M>,
                                   Tuple<CGridDesc_M>,
                                   Tuple<const CDataType*, const CDataType*>,
                                   Tuple<CDataType*>,
                                   Add,
                                   MPerThread,
                                   Sequence<AScalarPerVector, BScalarPerVector>,
                                   Sequence<CScalarPerVector>>;

        using GridwiseBinSubtract =
            GridwiseElementwise_1D<Tuple<CGridDesc_M, CGridDesc_M>,
                                   Tuple<CGridDesc_M>,
                                   Tuple<const CDataType*, const CDataType*>,
                                   Tuple<CDataType*>,
                                   Subtract,
                                   MPerThread,
                                   Sequence<AScalarPerVector, BScalarPerVector>,
                                   Sequence<CScalarPerVector>>;

        const auto add_kernel = kernel_elementwise_1d<GridwiseBinAdd,
                                                      Tuple<CGridDesc_M, CGridDesc_M>,
                                                      Tuple<CGridDesc_M>,
                                                      Tuple<const CDataType*, const CDataType*>,
                                                      Tuple<CDataType*>,
                                                      Add>;

        const auto subtract_kernel =
            kernel_elementwise_1d<GridwiseBinSubtract,
                                  Tuple<CGridDesc_M, CGridDesc_M>,
                                  Tuple<CGridDesc_M>,
                                  Tuple<const CDataType*, const CDataType*>,
                                  Tuple<CDataType*>,
                                  Subtract>;

        if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
        {
@@ -631,18 +626,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                   arg.block_2_ctile_map_);

            // c_real = aux - aux_2
            ave_time += launch_and_time_kernel(stream_config,
                                               subtract_kernel,
                                               dim3(grid_size),
                                               dim3(BlockSize),
                                               0,
                                               arg.p_aux_grid_,
                                               arg.p_aux_2_grid_,
                                               arg.p_c_grid_real_,
                                               arg.c_grid_desc_m_,
                                               arg.c_grid_desc_m_,
                                               arg.c_grid_desc_m_,
                                               Subtract{});
            ave_time += launch_and_time_kernel(
                stream_config,
                subtract_kernel,
                dim3(grid_size),
                dim3(BlockSize),
                0,
                make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
                make_tuple(arg.c_grid_desc_m_),
                make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
                           const_cast<const CDataType*>(arg.p_aux_2_grid_)),
                make_tuple(arg.p_c_grid_real_),
                Subtract{});

            ave_time +=
                launch_and_time_kernel(stream_config,
@@ -679,18 +674,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                   arg.block_2_ctile_map_);

            // c_imag = aux + aux_2
            ave_time += launch_and_time_kernel(stream_config,
                                               add_kernel,
                                               dim3(grid_size),
                                               dim3(BlockSize),
                                               0,
                                               arg.p_aux_grid_,
                                               arg.p_aux_2_grid_,
                                               arg.p_c_grid_imag_,
                                               arg.c_grid_desc_m_,
                                               arg.c_grid_desc_m_,
                                               arg.c_grid_desc_m_,
                                               Add{});
            ave_time += launch_and_time_kernel(
                stream_config,
                add_kernel,
                dim3(grid_size),
                dim3(BlockSize),
                0,
                make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
                make_tuple(arg.c_grid_desc_m_),
                make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
                           const_cast<const CDataType*>(arg.p_aux_2_grid_)),
                make_tuple(arg.p_c_grid_imag_),
                Add{});
        }
        else
        {
@@ -742,18 +737,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                   arg.block_2_ctile_map_);

            // c_real = aux - aux_2
            ave_time += launch_and_time_kernel(stream_config,
                                               subtract_kernel,
                                               dim3(grid_size),
                                               dim3(BlockSize),
                                               0,
                                               arg.p_aux_grid_,
                                               arg.p_aux_2_grid_,
                                               arg.p_c_grid_real_,
                                               arg.c_grid_desc_m_,
                                               arg.c_grid_desc_m_,
                                               arg.c_grid_desc_m_,
                                               Subtract{});
            ave_time += launch_and_time_kernel(
                stream_config,
                subtract_kernel,
                dim3(grid_size),
                dim3(BlockSize),
                0,
                make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
                make_tuple(arg.c_grid_desc_m_),
                make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
                           const_cast<const CDataType*>(arg.p_aux_2_grid_)),
                make_tuple(arg.p_c_grid_real_),
                Subtract{});

            ave_time +=
                launch_and_time_kernel(stream_config,
@@ -790,18 +785,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                   arg.block_2_ctile_map_);

            // c_imag = aux + aux_2
            ave_time += launch_and_time_kernel(stream_config,
                                               add_kernel,
                                               dim3(grid_size),
                                               dim3(BlockSize),
                                               0,
                                               arg.p_aux_grid_,
                                               arg.p_aux_2_grid_,
                                               arg.p_c_grid_imag_,
                                               arg.c_grid_desc_m_,
                                               arg.c_grid_desc_m_,
                                               arg.c_grid_desc_m_,
                                               Add{});
            ave_time += launch_and_time_kernel(
                stream_config,
                add_kernel,
                dim3(grid_size),
                dim3(BlockSize),
                0,
                make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
                make_tuple(arg.c_grid_desc_m_),
                make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
                           const_cast<const CDataType*>(arg.p_aux_2_grid_)),
                make_tuple(arg.p_c_grid_imag_),
                Add{});
        }

        return ave_time;
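Between the old and new launches above, the call shape is the only semantic change; a condensed sketch of the new convention, using the names from the surrounding code:

// Descriptors and pointers now travel as ck::Tuple arguments, so one
// kernel_elementwise_1d instantiation serves any input/output arity:
//
//   launch_and_time_kernel(stream_config, kernel,
//                          dim3(grid_size), dim3(BlockSize), 0,
//                          make_tuple(in_desc...),  // input descriptors
//                          make_tuple(out_desc...), // output descriptors
//                          make_tuple(in_ptr...),   // const input pointers
//                          make_tuple(out_ptr...),  // mutable output pointers
//                          op);
//
// The const_casts above only add constness: p_aux_grid_ / p_aux_2_grid_ are
// mutable CDataType*, while the input tuple is Tuple<const CDataType*, ...>.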
@@ -13,7 +13,6 @@
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -13,7 +13,6 @@
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -2,38 +2,286 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include <iostream>
#include <vector>

#include "device_base.hpp"
#include <iostream>
#include <sstream>

#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"

#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <ck::index_t NumInputTensor,
          ck::index_t NumOutputTensor,
          index_t NDim,
          typename ElementwiseFunctor>
struct DeviceElementwise : public BaseOperator
template <typename InDataTypeTuple,
          typename OutDataTypeTuple,
          typename ElementwiseOperation,
          index_t NumDim,
          index_t MPerThread,
          typename InScalarPerVectorSeq,
          typename OutScalarPerVectorSeq>
struct DeviceElementwise
    : public DeviceElementwiseBase<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::array<const void*, NumInputTensor> p_inputs,
                        std::array<void*, NumOutputTensor> p_outputs,
                        std::vector<index_t> lengths,
                        std::vector<std::vector<index_t>> input_strides,
                        std::vector<std::vector<index_t>> output_strides,
                        ElementwiseFunctor functor) = 0;
    static constexpr int NumInput  = InDataTypeTuple::Size();
    static constexpr int NumOutput = OutDataTypeTuple::Size();

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
    static_assert(NumInput == InScalarPerVectorSeq::Size() &&
                      NumOutput == OutScalarPerVectorSeq::Size(),
                  "Tuple size is inconsistent with the number of in/out!");

template <ck::index_t NumInputTensor,
          ck::index_t NumOutputTensor,
          index_t NDim,
          typename ElementwiseFunctor>
using DeviceElementwisePtr =
    std::unique_ptr<DeviceElementwise<NumInputTensor, NumOutputTensor, NDim, ElementwiseFunctor>>;
    static auto GenerateInDataTypePointerTuple()
    {
        return generate_tuple(
            [&](auto I) {
                using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;

                return static_cast<const DataType*>(nullptr);
            },
            Number<NumInput>{});
    };

    static auto GenerateOutDataTypePointerTuple()
    {
        return generate_tuple(
            [&](auto I) {
                using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;

                return static_cast<DataType*>(nullptr);
            },
            Number<NumOutput>{});
    };

    using InDataTypePointerTuple  = decltype(GenerateInDataTypePointerTuple());
    using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());

    template <typename Desc_M>
    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
    {
        constexpr auto I0 = Number<0>{};

        const auto m            = desc_m.GetLength(I0);
        const index_t loop_step = gridSize * blockSize * MPerThread;
        const auto pad          = math::integer_least_multiple(m, loop_step) - m;
        const auto desc_m_pad =
            transform_tensor_descriptor(desc_m,
                                        make_tuple(make_right_pad_transform(m, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m_pad;
    }

    static auto MakeDescriptor_M(const std::array<index_t, NumDim>& lengths,
                                 const std::array<index_t, NumDim>& stride,
                                 index_t gridSize,
                                 index_t blockSize)
    {
        auto tupleOfShape  = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});

        // nd desc - [s0, s1, s2, ...]
        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);

        // merge nd to 1d desc - [s0 * s1 * ...]
        if constexpr(NumDim > 1)
        {
            const auto desc_m = transform_tensor_descriptor(
                desc,
                make_tuple(make_merge_transform(tupleOfShape)),
                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim>{})),
                make_tuple(Sequence<0>{}));

            return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
        }
        else
            return PadDescriptor_M_1d(desc, gridSize, blockSize);
    }

    template <index_t TupleSize>
    static auto GenerateInOutGrid1dDescTuple(Number<TupleSize>)
    {
        return generate_tuple(
            [&](auto) {
                if constexpr(NumDim > 1)
                {
                    return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1);
                }
                else
                {
                    return MakeDescriptor_M({1}, {1}, 1, 1);
                };
            },
            Number<TupleSize>{});
    };

    using InGrid1dDescTuple  = decltype(GenerateInOutGrid1dDescTuple(Number<NumInput>{}));
    using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumOutput>{}));

    using GridwiseElementwise = GridwiseElementwise_1D<InGrid1dDescTuple,
                                                       OutGrid1dDescTuple,
                                                       InDataTypePointerTuple,
                                                       OutDataTypePointerTuple,
                                                       ElementwiseOperation,
                                                       MPerThread,
                                                       InScalarPerVectorSeq,
                                                       OutScalarPerVectorSeq>;

    struct Argument : public BaseArgument
    {
        Argument(const std::array<index_t, NumDim> lengths,
                 const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
                 const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
                 const std::array<const void*, NumInput> in_dev_buffers,
                 const std::array<void*, NumOutput> out_dev_buffers,
                 ElementwiseOperation elementwise_op)

            : lengths_(lengths),
              inStridesArray_(inStridesArray),
              outStridesArray_(outStridesArray),
              elementwise_op_(elementwise_op),
              blockSize_(256),
              gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
        {
            in_dev_buffers_ = generate_tuple(
                [&](auto I) {
                    using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
                    return static_cast<const DataType*>(in_dev_buffers[I.value]);
                },
                Number<NumInput>{});

            out_dev_buffers_ = generate_tuple(
                [&](auto I) {
                    using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
                    return static_cast<DataType*>(out_dev_buffers[I.value]);
                },
                Number<NumOutput>{});

            in_grid_1d_desc_tuple_ = generate_tuple(
                [&](auto I) {
                    return MakeDescriptor_M(
                        lengths, inStridesArray[I.value], gridSize_, blockSize_);
                },
                Number<NumInput>{});

            out_grid_1d_desc_tuple_ = generate_tuple(
                [&](auto I) {
                    return MakeDescriptor_M(
                        lengths, outStridesArray[I.value], gridSize_, blockSize_);
                },
                Number<NumOutput>{});
        }

        InDataTypePointerTuple in_dev_buffers_;
        OutDataTypePointerTuple out_dev_buffers_;
        InGrid1dDescTuple in_grid_1d_desc_tuple_;
        OutGrid1dDescTuple out_grid_1d_desc_tuple_;

        std::array<index_t, NumDim> lengths_;
        std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
        std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;

        ElementwiseOperation elementwise_op_;
        index_t blockSize_;
        index_t gridSize_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto kernel = kernel_elementwise_1d<GridwiseElementwise,
                                                      InGrid1dDescTuple,
                                                      OutGrid1dDescTuple,
                                                      InDataTypePointerTuple,
                                                      OutDataTypePointerTuple,
                                                      ElementwiseOperation>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(arg.gridSize_),
                                                        dim3(arg.blockSize_),
                                                        0,
                                                        arg.in_grid_1d_desc_tuple_,
                                                        arg.out_grid_1d_desc_tuple_,
                                                        arg.in_dev_buffers_,
                                                        arg.out_dev_buffers_,
                                                        arg.elementwise_op_);
            return elapsed_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if(pArg == nullptr)
            return false;

        if(pArg->lengths_.back() % MPerThread != 0)
            return false;

        auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
                                          const std::array<index_t, NumDim>& strides,
                                          index_t scalarPerVector) {
            if(strides.back() == 1 && lengths.back() % scalarPerVector == 0)
                return true;

            if(strides.back() != 1 && scalarPerVector == 1)
                return true;

            return false;
        };

        bool valid = true;
        static_for<0, NumInput, 1>{}([&](auto I) {
            if(!IsScalarPerVectorValid(
                   pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
                valid = false;
        });

        static_for<0, NumOutput, 1>{}([&](auto I) {
            if(!IsScalarPerVectorValid(
                   pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
                valid = false;
        });

        return valid;
    };

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
                        const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
                        const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
                        const std::array<const void*, NumInput> in_dev_buffers,
                        const std::array<void*, NumOutput> out_dev_buffers,
                        ElementwiseOperation elementwise_op) override
    {
        return std::make_unique<Argument>(lengths,
                                          inStridesArray,
                                          outStridesArray,
                                          in_dev_buffers,
                                          out_dev_buffers,
                                          elementwise_op);
    }

    static auto MakeInvoker() { return Invoker{}; }
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };
};

} // namespace device
} // namespace tensor_operation
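One detail worth calling out in the rewritten op above: the pointer tuples are derived purely at the type level. A sketch of what the decltype construction yields (the example element types are illustrative):

// GenerateInDataTypePointerTuple returns a tuple of typed null pointers;
// only its decltype is used, so no tuple value exists at runtime.
//   InDataTypeTuple         = Tuple<half_t, float>
//   InDataTypePointerTuple  = Tuple<const half_t*, const float*>
//   OutDataTypeTuple        = Tuple<float>
//   OutDataTypePointerTuple = Tuple<float*>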
@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <memory>
#include <array>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename InDataTypeTuple,
          typename OutDataTypeTuple,
          typename ElementwiseOperation,
          index_t NumDim>
struct DeviceElementwiseBase : public BaseOperator
{
    static constexpr int NumInput  = InDataTypeTuple::Size();
    static constexpr int NumOutput = OutDataTypeTuple::Size();

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
                        const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
                        const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
                        const std::array<const void*, NumInput> in_dev_buffers,
                        const std::array<void*, NumOutput> out_dev_buffers,
                        ElementwiseOperation elementwise_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename InDataTypeTuple,
          typename OutDataTypeTuple,
          typename ElementwiseOperation,
          index_t NumDim>
using DeviceElementwiseBasePtr = std::unique_ptr<
    DeviceElementwiseBase<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>
#include <array>
#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/utility/reduction_enums.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <index_t Rank,
          index_t NumReduceDim,
          index_t NumReduction,
          typename InElementwiseOperationTuple,
          typename AccElementwiseOperationTuple>
struct DeviceMultipleReduce : public BaseOperator
{
    static constexpr index_t NumInputDim  = Rank;
    static constexpr index_t NumOutputDim = (Rank - NumReduceDim > 1) ? Rank - NumReduceDim : 1;

    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, NumInputDim> inLengths,
        const std::array<index_t, NumInputDim> inStrides,
        const std::array<index_t, NumOutputDim> outLengths,
        const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStrides,
        const std::array<int, NumReduceDim> reduceDims,
        const std::array<const void*, NumReduction> alphas,
        const std::array<const void*, NumReduction> betas,
        const void* in_dev,
        const std::array<void*, NumReduction> out_dev_buffers,
        const InElementwiseOperationTuple in_elementwise_op_tuple,
        const AccElementwiseOperationTuple acc_elementwise_op_tuple) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <index_t Rank,
          index_t NumReduceDim,
          index_t NumReduction,
          typename InElementwiseOperationTuple,
          typename AccElementwiseOperationTuple>
using DeviceMultipleReducePtr = std::unique_ptr<DeviceMultipleReduce<Rank,
                                                                     NumReduceDim,
                                                                     NumReduction,
                                                                     InElementwiseOperationTuple,
                                                                     AccElementwiseOperationTuple>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
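This interface is what lets batchnorm-forward compute its two statistics in a single pass over the input. A sketch of the intended pairing (the per-reduction op names are placeholders for whichever unary ops the examples actually use):

// NumReduction = 2, ReduceOperation = add, in-elementwise ops = {identity, square}:
//   sum_x  = reduce_add(x)       ->  mean   = sum_x  / reduce_total_length
//   sum_xx = reduce_add(x * x)   ->  E[x^2] = sum_xx / reduce_total_length
//   variance = E[x^2] - mean * mean
// One read of the input produces both moments, which is the point of fusing
// multiple reductions into one kernel.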
@@ -0,0 +1,595 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/reduction_operator.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t NumReduction,
|
||||
typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataTypeTuple,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperationTuple,
|
||||
typename AccElementwiseOperationTuple,
|
||||
InMemoryDataOperationEnum OutMemoryDataOperation,
|
||||
bool PropagateNan,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
typename OutDstVectorSizeSeq>
|
||||
struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
|
||||
NumReduceDim,
|
||||
NumReduction,
|
||||
InElementwiseOperationTuple,
|
||||
AccElementwiseOperationTuple>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"Invalid thread cluster size assignments!");
|
||||
|
||||
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
|
||||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static_assert(NumReduction == OutDataTypeTuple::Size() &&
|
||||
NumReduction == InElementwiseOperationTuple::Size() &&
|
||||
NumReduction == AccElementwiseOperationTuple::Size() &&
|
||||
NumReduction == OutDstVectorSizeSeq::Size(),
|
||||
"All tuple should have the same size as the number of Reductions!");
|
||||
|
||||
static_assert(sequence_all_of(OutDstVectorSizeSeq{},
|
||||
[](auto vectorSize) {
|
||||
return (MThreadSliceSize % vectorSize == 0);
|
||||
}),
|
||||
"The OutDstVectorSize should completely divide the MThreadSliceSize!");
|
||||
|
||||
static constexpr bool CheckDataTypeTuple()
|
||||
{
|
||||
bool flag = true;
|
||||
|
||||
static_for<0, NumReduction, 1>{}([&](auto I) {
|
||||
using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
flag =
|
||||
flag && ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
|
||||
OutDataType>::value;
|
||||
});
|
||||
|
||||
return flag;
|
||||
};
|
||||
|
||||
static_assert(CheckDataTypeTuple(),
|
||||
"The OutDataType must support the specified OutMemoryDataOperation!");
|
||||
|
||||
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
|
||||
|
||||
static constexpr index_t NumInputDim = Rank;
|
||||
static constexpr index_t NumOutputDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
|
||||
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
|
||||
|
||||
// So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
|
||||
// later
|
||||
static constexpr bool use_multiblock =
|
||||
(OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
|
||||
|
||||
static_assert(
|
||||
ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation),
|
||||
"The reduction accumulation operation must be compatible with the OutMemoryDataOperation!");
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto GenerateOutDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<DataType*>(nullptr);
|
||||
},
|
||||
Number<NumReduction>{});
|
||||
};
|
||||
|
||||
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
|
||||
|
||||
static auto MakeSrc2dDescriptor(const std::array<index_t, NumInputDim>& inLengths,
|
||||
const std::array<index_t, NumInputDim>& inStrides,
|
||||
int blkGroupSize,
|
||||
int numBlockTileIteration)
|
||||
{
|
||||
const auto tupleSrcLengths =
|
||||
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInputDim>{});
|
||||
const auto tupleSrcStrides =
|
||||
generate_tuple([&](auto I) { return inStrides[I]; }, Number<NumInputDim>{});
|
||||
|
||||
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
|
||||
const auto in_grid_desc_m_k = [&]() {
|
||||
if constexpr(reduceAllDim)
|
||||
{
|
||||
const auto one_dim_inDesc = transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(tupleSrcLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, NumInputDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return transform_tensor_descriptor(one_dim_inDesc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
1, one_dim_inDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
|
||||
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
|
||||
|
||||
const auto reduceDimLengths = generate_tuple(
|
||||
[&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
|
||||
const auto invariantDimLengths =
|
||||
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(reduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
|
||||
const auto inPad_M =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
|
||||
|
||||
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
|
||||
in_grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
|
||||
make_right_pad_transform(reduceLength, inPad_K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (in_grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeDst1dDescriptor(const std::array<index_t, NumOutputDim>& outLengths,
|
||||
const std::array<index_t, NumOutputDim>& outStrides)
|
||||
{
|
||||
const auto tupleDstLengths =
|
||||
generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumOutputDim>{});
|
||||
const auto tupleDstStrides =
|
||||
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumOutputDim>{});
|
||||
|
||||
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
auto out_grid_desc_m = transform_tensor_descriptor(
|
||||
outDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto outPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto out_grid_desc_m_padded = transform_tensor_descriptor(
|
||||
out_grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(invariantLength, outPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (out_grid_desc_m_padded);
|
||||
};
|
||||
|
||||
static auto GenerateOutGrid1dDescTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
(void)I;
|
||||
return MakeDst1dDescriptor(std::array<index_t, NumOutputDim>{},
|
||||
std::array<index_t, NumOutputDim>{});
|
||||
},
|
||||
Number<NumReduction>{});
|
||||
};
|
||||
|
||||
using InGridDesc_M_K = decltype(MakeSrc2dDescriptor(
|
||||
std::array<index_t, NumInputDim>{}, std::array<index_t, NumInputDim>{}, 1, 1));
|
||||
using OutGridDesc_M_Tuple = decltype(GenerateOutGrid1dDescTuple());
|
||||
|
||||
    static auto MakeDst1dDescriptorForBufferSet(const std::array<index_t, NumOutputDim>& outLengths,
                                                const std::array<index_t, NumOutputDim>& outStrides)
    {
        const auto tupleDstLengths =
            generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumOutputDim>{});
        const auto tupleDstStrides =
            generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumOutputDim>{});

        auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

        auto out_grid_desc_m = transform_tensor_descriptor(
            outDesc,
            make_tuple(make_merge_transform(tupleDstLengths)),
            make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

        const auto length = out_grid_desc_m.GetLength(Number<0>{});

        const auto pad = math::integer_least_multiple(length, BlockSize) - length;

        auto out_grid_desc_m_padded =
            transform_tensor_descriptor(out_grid_desc_m,
                                        make_tuple(make_right_pad_transform(length, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));

        return (out_grid_desc_m_padded);
    };

    static auto GenerateOutGrid1dDescTuple_2()
    {
        return generate_tuple(
            [&](auto I) {
                (void)I;
                return MakeDst1dDescriptorForBufferSet(std::array<index_t, NumOutputDim>{},
                                                       std::array<index_t, NumOutputDim>{});
            },
            Number<NumReduction>{});
    };

    using OutGridDesc_M_Tuple_2 = decltype(GenerateOutGrid1dDescTuple_2());
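    // MakeDst1dDescriptorForBufferSet pads the merged output length to a multiple of
    // BlockSize (one element per thread) rather than M_BlockTileSize, because the
    // buffer-initialization (pre) kernel writes one value per thread instead of a whole
    // M-tile per block. E.g. (hypothetical): a length of 1000 with BlockSize = 256 is
    // padded by 24 up to 1024.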
    struct Argument : public BaseArgument
    {
        Argument(const std::array<index_t, NumInputDim>& inLengths,
                 const std::array<index_t, NumInputDim>& inStrides,
                 const std::array<index_t, NumOutputDim>& outLengths,
                 const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
                 const std::array<int, NumReduceDim>& reduceDims,
                 const std::array<const void*, NumReduction>& alphas,
                 const std::array<const void*, NumReduction>& betas,
                 const void* in_dev,
                 const std::array<void*, NumReduction>& out_dev_buffers,
                 const InElementwiseOperationTuple in_elementwise_op_tuple,
                 const AccElementwiseOperationTuple acc_elementwise_op_tuple)
            : outLengths_{outLengths},
              outStridesArray_{outStridesArray},
              in_elementwise_op_tuple_{in_elementwise_op_tuple},
              acc_elementwise_op_tuple_{acc_elementwise_op_tuple}
        {
            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

            for(size_t i = 0; i < NumReduction; i++)
            {
                alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
                beta_values_(i)  = *static_cast<const AccDataType*>(betas[i]);
            };

            in_dev_ = static_cast<const InDataType*>(in_dev);

            out_dev_buffers_ = generate_tuple(
                [&](auto iR) {
                    using OutDataTypePointer =
                        remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
                    using OutDataType = remove_cvref_t<remove_pointer_t<OutDataTypePointer>>;
                    return static_cast<OutDataType*>(out_dev_buffers[iR]);
                },
                Number<NumReduction>{});

            std::tie(invariant_total_length, reduce_total_length) =
                get_2d_lengths<Rank, NumReduceDim>(inLengths_);

            if constexpr(use_multiblock)
            {
                int iterations = 1;
                while(true)
                {
                    int testBlkGroupSize =
                        (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                        (K_BlockTileSize * iterations);

                    // we want blkGroupSize to be no more than 128
                    if(testBlkGroupSize <= 128)
                        break;

                    iterations++;
                };

                blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                               (K_BlockTileSize * iterations);

                numBlockTileIteration = iterations;
            }
            else
            {
                blkGroupSize = 1;
                numBlockTileIteration =
                    (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
            };

            in_grid_desc_m_k =
                MakeSrc2dDescriptor(inLengths_, inStrides_, blkGroupSize, numBlockTileIteration);

            out_grid_desc_m_tuple = generate_tuple(
                [&](auto I) { return MakeDst1dDescriptor(outLengths, outStridesArray[I]); },
                Number<NumReduction>{});

            out_grid_desc_m_tuple_2 = generate_tuple(
                [&](auto I) {
                    return MakeDst1dDescriptorForBufferSet(outLengths, outStridesArray[I]);
                },
                Number<NumReduction>{});

            gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                       M_BlockTileSize * blkGroupSize;

            gridSize_pre =
                math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
        }

        std::array<index_t, NumInputDim> inLengths_;
        std::array<index_t, NumInputDim> inStrides_;

        std::array<index_t, NumOutputDim> outLengths_;
        std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray_;

        Array<AccDataType, NumReduction> alpha_values_;
        Array<AccDataType, NumReduction> beta_values_;

        const InDataType* in_dev_;
        OutDataTypePointerTuple out_dev_buffers_;

        InGridDesc_M_K in_grid_desc_m_k;
        OutGridDesc_M_Tuple out_grid_desc_m_tuple;
        OutGridDesc_M_Tuple_2 out_grid_desc_m_tuple_2;

        InElementwiseOperationTuple in_elementwise_op_tuple_;
        AccElementwiseOperationTuple acc_elementwise_op_tuple_;

        long_index_t invariant_total_length;
        long_index_t reduce_total_length;

        int blkGroupSize;
        int numBlockTileIteration;
        size_t gridSize;

        size_t gridSize_pre;
    };
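    // Worked example for the blkGroupSize search in the constructor above (hypothetical
    // numbers): with reduce_total_length = 100000 and K_BlockTileSize = 512,
    // iterations = 1 gives testBlkGroupSize = ceil(100000 / 512) = 196 > 128;
    // iterations = 2 gives ceil(100000 / 1024) = 98 <= 128, so blkGroupSize = 98 and
    // numBlockTileIteration = 2.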
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            using GridwiseMultipleReduce =
                GridwiseMultipleReduction_mk_to_m_multiblock<NumReduction,
                                                             InDataType,
                                                             OutDataTypePointerTuple,
                                                             AccDataType,
                                                             InGridDesc_M_K,
                                                             OutGridDesc_M_Tuple,
                                                             ReduceOperation,
                                                             InElementwiseOperationTuple,
                                                             AccElementwiseOperationTuple,
                                                             OutMemoryDataOperation,
                                                             PropagateNan,
                                                             BlockSize,
                                                             MThreadClusterSize,
                                                             KThreadClusterSize,
                                                             MThreadSliceSize,
                                                             KThreadSliceSize,
                                                             InSrcVectorDim,
                                                             InSrcVectorSize,
                                                             OutDstVectorSizeSeq>;

            const auto kernel_main =
                kernel_multiple_reduce_multiblock<GridwiseMultipleReduce,
                                                  NumReduction,
                                                  InDataType,
                                                  OutDataTypePointerTuple,
                                                  AccDataType,
                                                  InGridDesc_M_K,
                                                  OutGridDesc_M_Tuple,
                                                  InElementwiseOperationTuple,
                                                  AccElementwiseOperationTuple>;

            float avg_time = 0;

            if constexpr(use_multiblock)
            {
                auto identity_values = generate_tuple(
                    [&](auto iR) {
                        using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[iR])>;
                        return ck::reduce::GetIdentityValueForInMemoryDataOperation<OutDataType>(
                            OutMemoryDataOperation);
                    },
                    Number<NumReduction>{});

                const auto kernel_pre = kernel_multiple_buffer_set_value<OutGridDesc_M_Tuple_2,
                                                                         NumReduction,
                                                                         BlockSize,
                                                                         OutDataTypePointerTuple,
                                                                         OutDataTypeTuple>;

                avg_time += launch_and_time_kernel(stream_config,
                                                   kernel_pre,
                                                   dim3(arg.gridSize_pre),
                                                   dim3(BlockSize),
                                                   0,
                                                   arg.out_grid_desc_m_tuple_2,
                                                   arg.out_dev_buffers_,
                                                   identity_values);
            };

            avg_time += launch_and_time_kernel(stream_config,
                                               kernel_main,
                                               dim3(arg.gridSize),
                                               dim3(BlockSize),
                                               0,
                                               arg.in_grid_desc_m_k,
                                               arg.out_grid_desc_m_tuple,
                                               arg.in_elementwise_op_tuple_,
                                               arg.acc_elementwise_op_tuple_,
                                               arg.blkGroupSize,
                                               arg.numBlockTileIteration,
                                               arg.alpha_values_,
                                               arg.in_dev_,
                                               arg.beta_values_,
                                               arg.out_dev_buffers_);

            return (avg_time);
        };

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        };
    };
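    // In the multiblock path the partial results of the block groups are combined directly
    // in the output buffers (e.g. via an atomic memory operation), so kernel_pre first fills
    // each output buffer with the identity value of that operation -- 0 in the AtomicAdd
    // case -- before kernel_main accumulates into it.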
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if constexpr(use_multiblock)
        {
            for(size_t i = 0; i < pArg->beta_values_.Size(); i++)
                if(pArg->beta_values_[i] != 0.0f)
                    return (false);
        };

        if constexpr(InSrcVectorDim == 0)
        {
            if constexpr(NumInvariantDim == 0)
            {
                return (false);
            }
            else
            {
                if(pArg->inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
                    return (false);

                if(pArg->inLengths_[NumInvariantDim - 1] % InSrcVectorSize != 0)
                    return (false);
            };
        }
        else
        {
            if(pArg->inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
                return (false);

            if(pArg->inLengths_[Rank - 1] % InSrcVectorSize != 0)
                return (false);
        };

        // To improve: this per-reduction check of output strides/lengths against the
        // output vector sizes could be refined
        bool valid = true;
        static_for<0, NumReduction, 1>{}([&](auto I) {
            if(pArg->outStridesArray_[I.value][NumOutputDim - 1] != 1 &&
               OutDstVectorSizeSeq::At(I) != 1)
                valid = false;

            if(pArg->outLengths_[NumOutputDim - 1] % OutDstVectorSizeSeq::At(I) != 0)
                valid = false;
        });

        if(!valid)
            return (false);

        if constexpr(use_multiblock)
        {
            // a blkGroupSize of 1 should be handled by the Blockwise path using
            // InMemoryDataOperationEnum::Set
            if(pArg->blkGroupSize == 1)
                return (false);

            // this is a very strong restriction, but it is needed to avoid some failures
            if(pArg->outLengths_[NumOutputDim - 1] % M_BlockTileSize != 0)
                return (false);
        }
        else
        {
            // cases with a very small reduce_total_length should be handled by the
            // ThreadWise kernel
            if(pArg->reduce_total_length / KThreadSliceSize < 2)
                return (false);
        };

        return (true);
    };
    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, NumInputDim> inLengths,
        const std::array<index_t, NumInputDim> inStrides,
        const std::array<index_t, NumOutputDim> outLengths,
        const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
        const std::array<int, NumReduceDim> reduceDims,
        const std::array<const void*, NumReduction> alphas,
        const std::array<const void*, NumReduction> betas,
        const void* in_dev,
        const std::array<void*, NumReduction> out_dev_buffers,
        const InElementwiseOperationTuple in_elementwise_op_tuple,
        const AccElementwiseOperationTuple acc_elementwise_op_tuple) override
    {
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
                                          outLengths,
                                          outStridesArray,
                                          reduceDims,
                                          alphas,
                                          betas,
                                          in_dev,
                                          out_dev_buffers,
                                          in_elementwise_op_tuple,
                                          acc_elementwise_op_tuple);
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set ? "DeviceMultipleReduceBlockWise<" : "DeviceMultipleReduceMultiBlock<") << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << ",";
        str << "OutDstVectorSize";
        static_for<0, OutDstVectorSizeSeq::Size(), 1>{}([&](auto I) { str << "_" << OutDstVectorSizeSeq::At(I); });
        str << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -0,0 +1,422 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/sequence.hpp"
#include "ck/utility/reduction_operator.hpp"

#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp"

#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumReduction,
          typename InDataType,
          typename AccDataType,
          typename OutDataTypeTuple,
          index_t Rank,
          index_t NumReduceDim,
          typename ReduceOperation,
          typename InElementwiseOperationTuple,
          typename AccElementwiseOperationTuple,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          typename OutDstVectorSizeSeq>
struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
                                                                    NumReduceDim,
                                                                    NumReduction,
                                                                    InElementwiseOperationTuple,
                                                                    AccElementwiseOperationTuple>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");

    static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                      (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(NumReduction == OutDataTypeTuple::Size() &&
                      NumReduction == InElementwiseOperationTuple::Size() &&
                      NumReduction == AccElementwiseOperationTuple::Size() &&
                      NumReduction == OutDstVectorSizeSeq::Size(),
                  "All tuples should have the same size as the number of reductions!");

    static_assert(sequence_all_of(OutDstVectorSizeSeq{},
                                  [](auto vectorSize) {
                                      return (MThreadSliceSize % vectorSize == 0);
                                  }),
                  "Each OutDstVectorSize should evenly divide MThreadSliceSize!");

    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

    static constexpr index_t NumInputDim  = Rank;
    static constexpr index_t NumOutputDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
    static constexpr bool reduceAllDim    = (NumInvariantDim == 0);

    static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
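    // For illustration (hypothetical numbers): with BlockSize = 256 and MThreadSliceSize = 4,
    // one block covers M_BlockTileSize = 256 * 4 = 1024 rows of the merged invariant
    // dimension. The thread-wise kernel keeps a single thread per reduction row, so
    // K_BlockTileSize is just KThreadSliceSize (the K thread-cluster size is effectively 1).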
    static auto GenerateOutDataTypePointerTuple()
    {
        return generate_tuple(
            [&](auto I) {
                using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;

                return static_cast<DataType*>(nullptr);
            },
            Number<NumReduction>{});
    };

    using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());

    static auto MakeSrc2dDescriptor(const std::array<index_t, NumInputDim>& inLengths,
                                    const std::array<index_t, NumInputDim>& inStrides)
    {
        const auto tupleSrcLengths =
            generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInputDim>{});
        const auto tupleSrcStrides =
            generate_tuple([&](auto I) { return inStrides[I]; }, Number<NumInputDim>{});

        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

        const auto in_grid_desc_m_k = [&]() {
            if constexpr(reduceAllDim)
            {
                const auto one_dim_inDesc = transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(tupleSrcLengths)),
                    make_tuple(typename arithmetic_sequence_gen<0, NumInputDim, 1>::type{}),
                    make_tuple(Sequence<0>{}));

                return transform_tensor_descriptor(one_dim_inDesc,
                                                   make_tuple(make_unmerge_transform(make_tuple(
                                                       1, one_dim_inDesc.GetLength(Number<0>{})))),
                                                   make_tuple(Sequence<0>{}),
                                                   make_tuple(Sequence<0, 1>{}));
            }
            else
            {
                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
                using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

                const auto reduceDimLengths = generate_tuple(
                    [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
                const auto invariantDimLengths =
                    generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});

                return transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(invariantDimLengths),
                               make_merge_transform(reduceDimLengths)),
                    make_tuple(InvariantDims{}, ReduceDims{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
        }();

        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

        const auto inPad_M =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto inPad_K =
            math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;

        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
            in_grid_desc_m_k,
            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
                       make_right_pad_transform(reduceLength, inPad_K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (in_grid_desc_m_k_padded);
    };
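    // For example (hypothetical shape): when all dimensions are reduced, a {4, 8} input is
    // first merged into a length-32 1-D descriptor and then unmerged into a (1, 32) M x K
    // view, so the same 2-D kernel handles the reduce-all case with M = 1.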
    static auto MakeDst1dDescriptor(const std::array<index_t, NumOutputDim>& outLengths,
                                    const std::array<index_t, NumOutputDim>& outStrides)
    {
        const auto tupleDstLengths =
            generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumOutputDim>{});
        const auto tupleDstStrides =
            generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumOutputDim>{});

        auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

        auto out_grid_desc_m = transform_tensor_descriptor(
            outDesc,
            make_tuple(make_merge_transform(tupleDstLengths)),
            make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});

        const auto outPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto out_grid_desc_m_padded = transform_tensor_descriptor(
            out_grid_desc_m,
            make_tuple(make_right_pad_transform(invariantLength, outPad)),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0>{}));

        return (out_grid_desc_m_padded);
    };

    static auto GenerateOutGrid1dDescTuple()
    {
        return generate_tuple(
            [&](auto I) {
                (void)I;
                return MakeDst1dDescriptor(std::array<index_t, NumOutputDim>{},
                                           std::array<index_t, NumOutputDim>{});
            },
            Number<NumReduction>{});
    };

    using InGridDesc_M_K = decltype(MakeSrc2dDescriptor(std::array<index_t, NumInputDim>{},
                                                        std::array<index_t, NumInputDim>{}));
    using OutGridDesc_M_Tuple = decltype(GenerateOutGrid1dDescTuple());
    struct Argument : public BaseArgument
    {
        Argument(const std::array<index_t, NumInputDim>& inLengths,
                 const std::array<index_t, NumInputDim>& inStrides,
                 const std::array<index_t, NumOutputDim>& outLengths,
                 const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
                 const std::array<int, NumReduceDim>& reduceDims,
                 const std::array<const void*, NumReduction>& alphas,
                 const std::array<const void*, NumReduction>& betas,
                 const void* in_dev,
                 const std::array<void*, NumReduction>& out_dev_buffers,
                 const InElementwiseOperationTuple in_elementwise_op_tuple,
                 const AccElementwiseOperationTuple acc_elementwise_op_tuple)
            : outLengths_{outLengths},
              outStridesArray_{outStridesArray},
              in_elementwise_op_tuple_{in_elementwise_op_tuple},
              acc_elementwise_op_tuple_{acc_elementwise_op_tuple}
        {
            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

            for(size_t i = 0; i < NumReduction; i++)
            {
                alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
                beta_values_(i)  = *static_cast<const AccDataType*>(betas[i]);
            };

            in_dev_ = static_cast<const InDataType*>(in_dev);

            out_dev_buffers_ = generate_tuple(
                [&](auto iR) {
                    using OutDataTypePointer =
                        remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
                    using OutDataType = remove_cvref_t<remove_pointer_t<OutDataTypePointer>>;
                    return static_cast<OutDataType*>(out_dev_buffers[iR]);
                },
                Number<NumReduction>{});

            std::tie(invariant_total_length, reduce_total_length) =
                get_2d_lengths<Rank, NumReduceDim>(inLengths_);

            in_grid_desc_m_k = MakeSrc2dDescriptor(inLengths_, inStrides_);

            out_grid_desc_m_tuple = generate_tuple(
                [&](auto I) { return MakeDst1dDescriptor(outLengths, outStridesArray[I]); },
                Number<NumReduction>{});

            gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                       M_BlockTileSize;
        }

        std::array<index_t, NumInputDim> inLengths_;
        std::array<index_t, NumInputDim> inStrides_;

        std::array<index_t, NumOutputDim> outLengths_;
        std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray_;

        Array<AccDataType, NumReduction> alpha_values_;
        Array<AccDataType, NumReduction> beta_values_;

        const InDataType* in_dev_;
        OutDataTypePointerTuple out_dev_buffers_;

        InGridDesc_M_K in_grid_desc_m_k;
        OutGridDesc_M_Tuple out_grid_desc_m_tuple;

        InElementwiseOperationTuple in_elementwise_op_tuple_;
        AccElementwiseOperationTuple acc_elementwise_op_tuple_;

        long_index_t invariant_total_length;
        long_index_t reduce_total_length;

        size_t gridSize;
    };
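    // For illustration (hypothetical numbers): with invariant_total_length = 1000 and
    // M_BlockTileSize = 1024, the padded length is 1024 and gridSize = 1. No blkGroupSize
    // factor is needed here, since in the thread-wise kernel each row is reduced by a
    // single thread.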
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            using GridwiseMultipleReduce =
                GridwiseMultipleReduction_mk_to_m_threadwise<NumReduction,
                                                             InDataType,
                                                             OutDataTypePointerTuple,
                                                             AccDataType,
                                                             InGridDesc_M_K,
                                                             OutGridDesc_M_Tuple,
                                                             ReduceOperation,
                                                             InElementwiseOperationTuple,
                                                             AccElementwiseOperationTuple,
                                                             InMemoryDataOperationEnum::Set,
                                                             PropagateNan,
                                                             BlockSize,
                                                             MThreadSliceSize,
                                                             KThreadSliceSize,
                                                             InSrcVectorDim,
                                                             InSrcVectorSize,
                                                             OutDstVectorSizeSeq>;

            const auto kernel_main =
                kernel_multiple_reduce_threadwise<GridwiseMultipleReduce,
                                                  NumReduction,
                                                  InDataType,
                                                  OutDataTypePointerTuple,
                                                  AccDataType,
                                                  InGridDesc_M_K,
                                                  OutGridDesc_M_Tuple,
                                                  InElementwiseOperationTuple,
                                                  AccElementwiseOperationTuple>;

            float avg_time = 0;

            avg_time += launch_and_time_kernel(stream_config,
                                               kernel_main,
                                               dim3(arg.gridSize),
                                               dim3(BlockSize),
                                               0,
                                               arg.in_grid_desc_m_k,
                                               arg.out_grid_desc_m_tuple,
                                               arg.in_elementwise_op_tuple_,
                                               arg.acc_elementwise_op_tuple_,
                                               arg.alpha_values_,
                                               arg.in_dev_,
                                               arg.beta_values_,
                                               arg.out_dev_buffers_);

            return (avg_time);
        };

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        };
    };
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if constexpr(InSrcVectorDim == 0)
        {
            if constexpr(NumInvariantDim == 0)
            {
                return (false);
            }
            else
            {
                if(pArg->inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
                    return (false);

                if(pArg->inLengths_[NumInvariantDim - 1] % InSrcVectorSize != 0)
                    return (false);
            };
        }
        else
        {
            if(pArg->inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
                return (false);

            if(pArg->inLengths_[Rank - 1] % InSrcVectorSize != 0)
                return (false);
        };

        // To improve: this per-reduction check of output strides/lengths against the
        // output vector sizes could be refined
        bool valid = true;
        static_for<0, NumReduction, 1>{}([&](auto I) {
            if(pArg->outStridesArray_[I.value][NumOutputDim - 1] != 1 &&
               OutDstVectorSizeSeq::At(I) != 1)
                valid = false;

            if(pArg->outLengths_[NumOutputDim - 1] % OutDstVectorSizeSeq::At(I) != 0)
                valid = false;
        });

        if(!valid)
            return (false);

        return (true);
    };
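    // For example (hypothetical values): with InSrcVectorDim = 1 and InSrcVectorSize = 4,
    // the innermost (fastest-varying) input dimension must have stride 1 and a length
    // divisible by 4, e.g. inLengths_[Rank - 1] = 64 passes the check above while 66 is
    // rejected.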
    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, NumInputDim> inLengths,
        const std::array<index_t, NumInputDim> inStrides,
        const std::array<index_t, NumOutputDim> outLengths,
        const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
        const std::array<int, NumReduceDim> reduceDims,
        const std::array<const void*, NumReduction> alphas,
        const std::array<const void*, NumReduction> betas,
        const void* in_dev,
        const std::array<void*, NumReduction> out_dev_buffers,
        const InElementwiseOperationTuple in_elementwise_op_tuple,
        const AccElementwiseOperationTuple acc_elementwise_op_tuple) override
    {
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
                                          outLengths,
                                          outStridesArray,
                                          reduceDims,
                                          alphas,
                                          betas,
                                          in_dev,
                                          out_dev_buffers,
                                          in_elementwise_op_tuple,
                                          acc_elementwise_op_tuple);
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceMultipleReduceThreadWise<" << BlockSize << ",";
        str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
        str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << ",";
        str << "OutDstVectorSize";
        static_for<0, OutDstVectorSizeSeq::Size(), 1>{}([&](auto I) { str << "_" << OutDstVectorSizeSeq::At(I); });
        str << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
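// A minimal usage sketch (illustrative only; the data types, element-wise operations and
// tuning parameters below are assumptions, not fixed by this header):
//
//   using PassThrough = ck::tensor_operation::element_wise::PassThrough;
//   using Op = ck::tensor_operation::device::DeviceMultipleReduceThreadWise<
//       2,                                   // NumReduction
//       float,                               // InDataType
//       float,                               // AccDataType
//       ck::Tuple<float, float>,             // OutDataTypeTuple
//       4,                                   // Rank
//       2,                                   // NumReduceDim
//       ck::reduce::Add,                     // ReduceOperation
//       ck::Tuple<PassThrough, PassThrough>, // InElementwiseOperationTuple
//       ck::Tuple<PassThrough, PassThrough>, // AccElementwiseOperationTuple
//       false,                               // PropagateNan
//       256,                                 // BlockSize
//       4,                                   // MThreadSliceSize
//       1,                                   // KThreadSliceSize
//       0,                                   // InSrcVectorDim
//       1,                                   // InSrcVectorSize
//       ck::Sequence<1, 1>>;                 // OutDstVectorSizeSeq
//
//   auto op      = Op{};
//   auto arg_ptr = op.MakeArgumentPointer(inLengths, inStrides, outLengths, outStridesArray,
//                                         reduceDims, alphas, betas, in_dev, out_dev_buffers,
//                                         in_ops, acc_ops);
//   if(op.IsSupportedArgument(arg_ptr.get()))
//       op.MakeInvokerPointer()->Run(arg_ptr.get());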
@@ -35,6 +35,25 @@ std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>&
    return std::make_pair(invariant_total_length, reduce_total_length);
};

template <index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> get_2d_lengths(const std::array<index_t, Rank>& inLengths)
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");

    long_index_t invariant_total_length = 1;
    long_index_t reduce_total_length    = 1;

    constexpr int NumInvariantDim = Rank - NumReduceDim;

    for(int i = NumInvariantDim; i < Rank; i++)
        reduce_total_length *= inLengths[i];

    for(int i = 0; i < NumInvariantDim; i++)
        invariant_total_length *= inLengths[i];

    return std::make_pair(invariant_total_length, reduce_total_length);
};
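// For example (hypothetical shape): with Rank = 4, NumReduceDim = 2 and
// inLengths = {8, 16, 32, 64} (already shuffled so the reduce dimensions come last),
// invariant_total_length = 8 * 16 = 128 and reduce_total_length = 32 * 64 = 2048.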
// helper functions using variadic template arguments
template <index_t... Ns>
auto make_tuple_from_array_and_index_seq(const std::vector<index_t>& lengths, Sequence<Ns...>)
@@ -85,6 +104,39 @@ std::vector<index_t> shuffle_tensor_dimensions(const std::vector<index_t>& origL
    return newLengthsStrides;
};

template <index_t Rank, index_t NumReduceDim>
std::array<index_t, Rank>
shuffle_tensor_dimensions(const std::array<index_t, Rank>& origLengthsStrides,
                          const std::array<int, NumReduceDim>& reduceDims)
{
    std::array<index_t, Rank> newLengthsStrides;

    int reduceFlag = 0;

    // set a bit for each dimension that is reduced
    for(int i = 0; i < NumReduceDim; i++)
    {
        reduceFlag |= 1 << reduceDims[i];
    };

    // collect the invariant dimensions first
    int pos = 0;
    for(int i = 0; i < Rank; i++)
        if((reduceFlag & (1 << i)) == 0)
        {
            newLengthsStrides[pos++] = origLengthsStrides[i];
        };

    // then collect the reduce dimensions
    for(int i = 0; i < Rank; i++)
        if((reduceFlag & (1 << i)) > 0)
        {
            newLengthsStrides[pos++] = origLengthsStrides[i];
        };

    return newLengthsStrides;
};
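// For example (hypothetical dims): with Rank = 4 and reduceDims = {0, 2},
// reduceFlag = 0b0101, and lengths {N, C, H, W} are reordered to {C, W, N, H} --
// invariant dimensions first, reduce dimensions last.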
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -1,183 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename ElementwiseFunctor,
          index_t Dim,
          index_t ScalarPerVector>
struct DeviceUnaryElementwise : public BaseOperator
{
    static constexpr auto I0 = Number<0>{};

    template <typename Desc_M0>
    static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize)
    {
        const auto m0           = desc_m0.GetLength(I0);
        const index_t loop_step = gridSize * blockSize * ScalarPerVector;
        const auto pad          = math::integer_least_multiple(m0, loop_step) - m0;
        const auto desc_m0_pad =
            transform_tensor_descriptor(desc_m0,
                                        make_tuple(make_right_pad_transform(m0, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m0_pad;
    }
    static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
                                  const std::vector<index_t>& stride,
                                  index_t gridSize,
                                  index_t blockSize)
    {
        auto tupleOfShape  = generate_tuple([&](auto I) { return shape[I]; }, Number<Dim>{});
        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<Dim>{});

        // nd desc - [s0, s1, s2, ...]
        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);

        // merge nd to 1d desc - [s0 * s1 * ...]
        if constexpr(Dim > 1)
        {
            const auto desc_m0 = transform_tensor_descriptor(
                desc,
                make_tuple(make_merge_transform(tupleOfShape)),
                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<Dim>{})),
                make_tuple(Sequence<0>{}));

            return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize);
        }
        else
            return PadDescriptor_M0_1d(desc, gridSize, blockSize);
    }

    using GridDesc_M0      = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
    using GridwiseUEltwise = GridwiseUnaryElementwise_1D<ADataType,
                                                         BDataType,
                                                         GridDesc_M0,
                                                         ElementwiseFunctor,
                                                         ScalarPerVector>;

    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a,
                 BDataType* p_b,
                 const std::vector<index_t>& shape,
                 const std::vector<index_t>& stride_a,
                 const std::vector<index_t>& stride_b,
                 ElementwiseFunctor functor)
            : p_a_(p_a),
              p_b_(p_b),
              shape_(shape),
              functor_(functor),
              blockSize_(256) // FIXME - calculate the block/grid size from the number of CUs in the future
        {
            index_t tensor_size =
                std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>{});
            gridSize_       = GridwiseUEltwise::CalculateGridSize(tensor_size);
            a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_);
            b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_);
        }

        const ADataType* p_a_;
        BDataType* p_b_;
        std::vector<int> shape_;
        GridDesc_M0 a_grid_desc_m0_;
        GridDesc_M0 b_grid_desc_m0_;
        ElementwiseFunctor functor_;
        index_t blockSize_;
        index_t gridSize_;
    };
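    // For illustration (hypothetical shape): shape = {2, 3, 8} gives
    // tensor_size = 2 * 3 * 8 = 48, and the grid size is whatever
    // GridwiseUEltwise::CalculateGridSize(48) returns for this tile configuration.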
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto kernel = kernel_unary_elementwise_1d<GridwiseUEltwise,
                                                            ADataType,
                                                            BDataType,
                                                            GridDesc_M0,
                                                            ElementwiseFunctor>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(arg.gridSize_),
                                                        dim3(arg.blockSize_),
                                                        0,
                                                        arg.p_a_,
                                                        arg.p_b_,
                                                        arg.a_grid_desc_m0_,
                                                        arg.b_grid_desc_m0_,
                                                        arg.functor_);
            return elapsed_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if(pArg == nullptr)
            return false;

        if(pArg->shape_.back() % ScalarPerVector != 0)
            return false;

        return true;
    };
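    // For example (hypothetical values): with ScalarPerVector = 4, a tensor whose innermost
    // length is 8 passes the check above (8 % 4 == 0), while an innermost length of 6 is
    // rejected.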
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      void* p_b,
                                                      std::vector<index_t> shape,
                                                      std::vector<index_t> stride_a,
                                                      std::vector<index_t> stride_b,
                                                      ElementwiseFunctor functor)
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<BDataType*>(p_b),
                                          shape,
                                          stride_a,
                                          stride_b,
                                          functor);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceUnaryElementwise"
            << "<"
            << "ScalarPerVector = " << ScalarPerVector
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck