Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-04 05:31:24 +00:00)
Refactor elementwise kernels (#1222)
* Refactor elementwise kernels
* Instances fixes
* Fix cmake
* Fix max pool bwd test
* Update two stage gemm split k
* Restore elementwise scale for hiptensor backward compatibility
* Fix Acc data type check in conv fwd multiple abd
* Disable conv fp64 fwd example
* Update grouped conv weight multi d
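At a glance, the refactor replaces the 1D elementwise path (GridwiseElementwise_1D / kernel_elementwise_1d) with a tiled 2D path (GridwiseElementwise / kernel_elementwise) that splits the output into MPerBlock x NPerBlock tiles and hands each tile to a workgroup through a BlockToCTileMap. A minimal sketch of that tiling idea, with a plain row-major mapping standing in for BlockToCTileMap_M00_N0_M01Adapt (illustrative only, not the library implementation):

#include <utility>

// Illustrative stand-in for a block-to-tile map: flat workgroup id -> (m, n) tile.
// The real BlockToCTileMap_M00_N0_M01Adapt walks the tiles in a more elaborate order.
struct SimpleBlock2TileMap
{
    int m_tiles, n_tiles;

    SimpleBlock2TileMap(int M, int N, int MPerBlock, int NPerBlock)
        : m_tiles((M + MPerBlock - 1) / MPerBlock), n_tiles((N + NPerBlock - 1) / NPerBlock)
    {
    }

    // One workgroup per tile of the padded output.
    int CalculateGridSize() const { return m_tiles * n_tiles; }

    // Which tile a given workgroup handles.
    std::pair<int, int> TileIndex(int block_id) const
    {
        return {block_id / n_tiles, block_id % n_tiles};
    }
};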
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_cgemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -80,42 +80,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};

static constexpr auto MPerThread = Number<4>{};
static constexpr index_t MPerThread =
MPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t NPerThread =
NPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);

static constexpr auto AScalarPerVector = Number<4>{};
static constexpr auto BScalarPerVector = Number<4>{};
static constexpr auto CScalarPerVector = Number<4>{};

template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
template <typename Desc_M_N>
static auto PadDescriptor_M_N(Desc_M_N desc)
{
const auto M = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * MPerThread;
const auto pad = math::integer_least_multiple(M, loop_step) - M;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(M, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m_pad;
const auto M = desc.GetLength(I0);
const auto N = desc.GetLength(I1);
const auto pad_M = math::integer_divide_ceil(M, MPerThread) * MPerThread - M;
const auto pad_N = math::integer_divide_ceil(N, NPerThread) * NPerThread - N;

const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_right_pad_transform(M, pad_M), make_right_pad_transform(N, pad_N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

return padded_desc;
}

static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides,
index_t gridSize,
index_t blockSize)
static auto MakeDescriptor_M_N(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{});
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<2>{});

// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
const auto desc_m = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})),
make_tuple(Sequence<0>{}));

return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
return PadDescriptor_M_N(desc);
}

// GridwiseGemm
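For reference, the pad_M / pad_N arithmetic in PadDescriptor_M_N above, written out as a standalone sketch (hypothetical helper; in ck, integer_divide_ceil(a, b) computes (a + b - 1) / b):

// Right padding that rounds a length up to the next multiple of the per-thread tile.
// Example: len = 30, per_thread = 4 -> 32 - 30 = 2 padded elements.
int right_pad(int len, int per_thread)
{
    return (len + per_thread - 1) / per_thread * per_thread - len;
}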
@@ -166,7 +165,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;

using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using CGridDesc_M_N = decltype(MakeDescriptor_M_N({1, 1}, {1, 1}));

// Argument
struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem
@@ -195,17 +194,13 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_c_grid_imag{p_c_grid_imag_},
p_aux_grid{p_workspace}
{
const index_t grid_size = std::get<1>(GridwiseGemm::CalculateGridSize(M_, N_));

if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
c_grid_desc_m =
DeviceOp::MakeDescriptor_M({M_, N_}, {StrideC_, I1}, grid_size, BlockSize);
c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {StrideC_, I1});
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
c_grid_desc_m =
DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize);
c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {I1, StrideC_});
}

p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_);
@@ -220,7 +215,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType* p_c_grid_imag;
CDataType* p_aux_grid;
CDataType* p_aux_2_grid;
CGridDesc_M c_grid_desc_m;
CGridDesc_M_N c_grid_desc_m_n;
};

// Invoker
@@ -248,40 +243,63 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
using Add = ck::tensor_operation::element_wise::Add;
using Subtract = ck::tensor_operation::element_wise::Subtract;

using GridwiseBinAdd =
GridwiseElementwise_1D<Tuple<CGridDesc_M, CGridDesc_M>,
Tuple<CGridDesc_M>,
Tuple<const CDataType*, const CDataType*>,
Tuple<CDataType*>,
Add,
MPerThread,
Sequence<AScalarPerVector, BScalarPerVector>,
Sequence<CScalarPerVector>>;
using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

using GridwiseBinAdd = GridwiseElementwise<Tuple<CGridDesc_M_N, CGridDesc_M_N>,
Tuple<CGridDesc_M_N>,
Tuple<const CDataType*, const CDataType*>,
Tuple<CDataType*>,
Block2TileMap,
Add,
BlockSize,
MPerBlock,
NPerBlock,
MPerThread,
NPerThread,
Sequence<0, 1>,
Sequence<AScalarPerVector, BScalarPerVector>,
Sequence<CScalarPerVector>,
I1,
I1>;

using GridwiseBinSubtract =
GridwiseElementwise_1D<Tuple<CGridDesc_M, CGridDesc_M>,
Tuple<CGridDesc_M>,
Tuple<const CDataType*, const CDataType*>,
Tuple<CDataType*>,
Subtract,
MPerThread,
Sequence<AScalarPerVector, BScalarPerVector>,
Sequence<CScalarPerVector>>;
GridwiseElementwise<Tuple<CGridDesc_M_N, CGridDesc_M_N>,
Tuple<CGridDesc_M_N>,
Tuple<const CDataType*, const CDataType*>,
Tuple<CDataType*>,
Block2TileMap,
Subtract,
BlockSize,
MPerBlock,
NPerBlock,
MPerThread,
NPerThread,
Sequence<0, 1>,
Sequence<AScalarPerVector, BScalarPerVector>,
Sequence<CScalarPerVector>,
I1,
I1>;

const auto add_kernel = kernel_elementwise_1d<GridwiseBinAdd,
Tuple<CGridDesc_M, CGridDesc_M>,
Tuple<CGridDesc_M>,
Tuple<const CDataType*, const CDataType*>,
Tuple<CDataType*>,
Add>;
const index_t M = arg.c_grid_desc_m_n.GetLength(I0);
const index_t N = arg.c_grid_desc_m_n.GetLength(I1);
const auto block_2_tile_map = Block2TileMap(M, N);

const auto add_kernel = kernel_elementwise<GridwiseBinAdd,
Tuple<CGridDesc_M_N, CGridDesc_M_N>,
Tuple<CGridDesc_M_N>,
Tuple<const CDataType*, const CDataType*>,
Tuple<CDataType*>,
Block2TileMap,
Add>;

const auto subtract_kernel =
kernel_elementwise_1d<GridwiseBinSubtract,
Tuple<CGridDesc_M, CGridDesc_M>,
Tuple<CGridDesc_M>,
Tuple<const CDataType*, const CDataType*>,
Tuple<CDataType*>,
Subtract>;
kernel_elementwise<GridwiseBinSubtract,
Tuple<CGridDesc_M_N, CGridDesc_M_N>,
Tuple<CGridDesc_M_N>,
Tuple<const CDataType*, const CDataType*>,
Tuple<CDataType*>,
Block2TileMap,
Subtract>;

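The two elementwise launches below are where the four partial GEMM products of the complex GEMM get combined: Subtract writes the real output and Add writes the imaginary output from the two auxiliary buffers. A per-element sketch of that math (the usual 4-GEMM complex-multiply decomposition; which product sits in each aux buffer at each phase is an assumption here, not spelled out in this hunk):

// C_re = A_re*B_re - A_im*B_im  -> Subtract over (p_aux_grid, p_aux_2_grid) into p_c_grid_real
// C_im = A_re*B_im + A_im*B_re  -> Add over (p_aux_grid, p_aux_2_grid) into p_c_grid_imag
float combine_real(float aux, float aux_2) { return aux - aux_2; }
float combine_imag(float aux, float aux_2) { return aux + aux_2; }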
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
@@ -318,11 +336,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
make_tuple(arg.c_grid_desc_m_n),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(arg.p_c_grid_real),
block_2_tile_map,
Subtract{});

ave_time += launch_and_time_kernel(stream_config,
@@ -352,11 +371,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
make_tuple(arg.c_grid_desc_m_n),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(arg.p_c_grid_imag),
block_2_tile_map,
Add{});
}
else
@@ -394,11 +414,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
make_tuple(arg.c_grid_desc_m_n),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(arg.p_c_grid_real),
block_2_tile_map,
Subtract{});

ave_time += launch_and_time_kernel(stream_config,
@@ -428,11 +449,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m),
make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
make_tuple(arg.c_grid_desc_m_n),
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(arg.p_c_grid_imag),
block_2_tile_map,
Add{});
}

@@ -1,338 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/stream_utility.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
index_t NumDim_m,
|
||||
index_t NumDim_n,
|
||||
index_t MPerThread,
|
||||
index_t NPerThread,
|
||||
typename InScalarPerVectorSeq,
|
||||
typename OutScalarPerVectorSeq>
|
||||
struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
|
||||
OutDataTypeTuple,
|
||||
ElementwiseOperation,
|
||||
NumDim_m + NumDim_n>
|
||||
{
|
||||
static constexpr index_t NumDim = NumDim_m + NumDim_n;
|
||||
|
||||
static constexpr int NumInput = InDataTypeTuple::Size();
|
||||
static constexpr int NumOutput = OutDataTypeTuple::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
|
||||
NumOutput == OutScalarPerVectorSeq::Size(),
|
||||
"Tuple size is inconsistent with the number of in/out!");
|
||||
|
||||
static auto GenerateInDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<const DataType*>(nullptr);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
};
|
||||
|
||||
static auto GenerateOutDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<DataType*>(nullptr);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
};
|
||||
|
||||
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
|
||||
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
|
||||
|
||||
template <typename Desc_MN>
|
||||
static auto PadDescriptor_MN_2d(Desc_MN desc_mn,
|
||||
index_t gridSize,
|
||||
index_t blockSize,
|
||||
index_t num_threads_m,
|
||||
index_t num_threads_n)
|
||||
{
|
||||
std::ignore = blockSize;
|
||||
std::ignore = gridSize;
|
||||
const auto m = desc_mn.GetLength(I0);
|
||||
const auto n = desc_mn.GetLength(I1);
|
||||
const index_t loop_step_m = num_threads_m * MPerThread;
|
||||
const index_t loop_step_n = num_threads_n * NPerThread;
|
||||
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
|
||||
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
|
||||
|
||||
const auto desc_mn_pad = transform_tensor_descriptor(
|
||||
desc_mn,
|
||||
make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return desc_mn_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_MN(const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& stride,
|
||||
index_t gridSize,
|
||||
index_t blockSize,
|
||||
index_t num_threads_m,
|
||||
index_t num_threads_n)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
|
||||
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
|
||||
constexpr auto nDimIds =
|
||||
typename arithmetic_sequence_gen<NumDim_m, NumDim_m + NumDim_n, 1>::type();
|
||||
|
||||
const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
|
||||
const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
|
||||
|
||||
// merge nd to 2d desc - [s0 * s1 * ...]
|
||||
|
||||
if constexpr(NumDim > 2)
|
||||
{
|
||||
const auto desc_mn = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
|
||||
make_tuple(mDimIds, nDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize, num_threads_m, num_threads_n);
|
||||
}
|
||||
else
|
||||
return PadDescriptor_MN_2d(desc, gridSize, blockSize, num_threads_m, num_threads_n);
|
||||
}
|
||||
|
||||
template <index_t TupleSize>
|
||||
static auto GenerateInOutGrid2dDescTuple(Number<TupleSize>)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto) {
|
||||
if constexpr(NumDim > 2)
|
||||
{
|
||||
return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1, 1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return MakeDescriptor_MN({1}, {1}, 1, 1, 1, 1);
|
||||
};
|
||||
},
|
||||
Number<TupleSize>{});
|
||||
};
|
||||
|
||||
using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumOutput>{}));
|
||||
using InGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number<NumInput>{}));
|
||||
|
||||
using GridwiseElementwise = GridwiseElementwise_2D<InGrid2dDescTuple,
|
||||
OutGrid2dDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
ElementwiseOperation,
|
||||
MPerThread,
|
||||
NPerThread,
|
||||
InScalarPerVectorSeq,
|
||||
OutScalarPerVectorSeq>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op)
|
||||
|
||||
: lengths_(lengths),
|
||||
inStridesArray_(inStridesArray),
|
||||
outStridesArray_(outStridesArray),
|
||||
elementwise_op_(elementwise_op),
|
||||
blockSize_(256)
|
||||
{
|
||||
static_assert(NumDim_m > 0, "");
|
||||
static_assert(NumDim_n > 0, "");
|
||||
|
||||
in_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
return static_cast<const DataType*>(in_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
out_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
return static_cast<DataType*>(out_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
}
|
||||
|
||||
InDataTypePointerTuple in_dev_buffers_;
|
||||
OutDataTypePointerTuple out_dev_buffers_;
|
||||
|
||||
std::array<index_t, NumDim> lengths_;
|
||||
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
|
||||
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
|
||||
|
||||
ElementwiseOperation elementwise_op_;
|
||||
index_t blockSize_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
index_t gridSize = getAvailableComputeUnitCount(stream_config);
|
||||
index_t num_threads_m = (gridSize * arg.blockSize_) / 16;
|
||||
index_t num_threads_n = 16;
|
||||
|
||||
auto in_grid_2d_desc_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
return MakeDescriptor_MN(arg.lengths_,
|
||||
arg.inStridesArray_[I.value],
|
||||
gridSize,
|
||||
arg.blockSize_,
|
||||
num_threads_m,
|
||||
num_threads_n);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_grid_2d_desc_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
return MakeDescriptor_MN(arg.lengths_,
|
||||
arg.outStridesArray_[I.value],
|
||||
gridSize,
|
||||
arg.blockSize_,
|
||||
num_threads_m,
|
||||
num_threads_n);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
const auto kernel = kernel_elementwise_2d<GridwiseElementwise,
|
||||
InGrid2dDescTuple,
|
||||
OutGrid2dDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
ElementwiseOperation>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gridSize),
|
||||
dim3(arg.blockSize_),
|
||||
0,
|
||||
in_grid_2d_desc_tuple,
|
||||
out_grid_2d_desc_tuple,
|
||||
arg.in_dev_buffers_,
|
||||
arg.out_dev_buffers_,
|
||||
arg.elementwise_op_,
|
||||
num_threads_m,
|
||||
num_threads_n);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(pArg == nullptr)
|
||||
return false;
|
||||
|
||||
if(pArg->lengths_.back() % MPerThread != 0)
|
||||
return false;
|
||||
|
||||
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& strides,
|
||||
index_t scalarPerVector,
|
||||
index_t vectorDim) {
|
||||
if(strides[vectorDim] == 1 &&
|
||||
(lengths[vectorDim] % scalarPerVector == 0 ||
|
||||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim])
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
bool valid = true;
|
||||
static_for<0, NumInput, 1>{}([&](auto I) {
|
||||
if(!IsScalarPerVectorValid(pArg->lengths_,
|
||||
pArg->inStridesArray_[I.value],
|
||||
InScalarPerVectorSeq::At(I),
|
||||
NumDim_m - 1))
|
||||
valid = false;
|
||||
});
|
||||
|
||||
static_for<0, NumOutput, 1>{}([&](auto I) {
|
||||
if(!IsScalarPerVectorValid(pArg->lengths_,
|
||||
pArg->outStridesArray_[I.value],
|
||||
OutScalarPerVectorSeq::At(I),
|
||||
NumDim - 1))
|
||||
valid = false;
|
||||
});
|
||||
|
||||
return valid;
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(lengths,
|
||||
inStridesArray,
|
||||
outStridesArray,
|
||||
in_dev_buffers,
|
||||
out_dev_buffers,
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,371 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/stream_utility.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
index_t NumDim_m, // choose how to set dims
|
||||
index_t NumDim_n,
|
||||
index_t NumDim_k,
|
||||
index_t MPerThread,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
typename InScalarPerVectorSeq,
|
||||
typename OutScalarPerVectorSeq>
|
||||
struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
|
||||
OutDataTypeTuple,
|
||||
ElementwiseOperation,
|
||||
NumDim_m + NumDim_n + NumDim_k>
|
||||
{
|
||||
static constexpr index_t NumDim = NumDim_m + NumDim_n + NumDim_k;
|
||||
|
||||
static constexpr int NumInput = InDataTypeTuple::Size();
|
||||
static constexpr int NumOutput = OutDataTypeTuple::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
|
||||
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
|
||||
NumOutput == OutScalarPerVectorSeq::Size(),
|
||||
"Tuple size is inconsistent with the number of in/out!");
|
||||
|
||||
static auto GenerateInDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<const DataType*>(nullptr);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
}
|
||||
|
||||
static auto GenerateOutDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<DataType*>(nullptr);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
}
|
||||
|
||||
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
|
||||
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
|
||||
|
||||
template <typename Desc_MNK>
|
||||
static auto PadDescriptor_MNK(Desc_MNK desc_mnk,
|
||||
index_t gridSize,
|
||||
index_t blockSize,
|
||||
index_t num_threads_m,
|
||||
index_t num_threads_n,
|
||||
index_t num_threads_k)
|
||||
{
|
||||
std::ignore = blockSize;
|
||||
std::ignore = gridSize;
|
||||
|
||||
const auto m = desc_mnk.GetLength(I0);
|
||||
const auto n = desc_mnk.GetLength(I1);
|
||||
const auto k = desc_mnk.GetLength(I2);
|
||||
|
||||
const index_t loop_step_m = num_threads_m * MPerThread;
|
||||
const index_t loop_step_n = num_threads_n * NPerThread;
|
||||
const index_t loop_step_k = num_threads_k * KPerThread;
|
||||
|
||||
const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m;
|
||||
const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n;
|
||||
const auto pad_k = math::integer_least_multiple(k, loop_step_k) - k;
|
||||
|
||||
const auto desc_mnk_pad =
|
||||
transform_tensor_descriptor(desc_mnk,
|
||||
make_tuple(make_right_pad_transform(m, pad_m),
|
||||
make_right_pad_transform(n, pad_n),
|
||||
make_right_pad_transform(k, pad_k)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
return desc_mnk_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_MNK(const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& stride,
|
||||
index_t gridSize,
|
||||
index_t blockSize,
|
||||
index_t num_threads_m,
|
||||
index_t num_threads_n,
|
||||
index_t num_threads_k)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
|
||||
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type();
|
||||
constexpr auto nDimIds =
|
||||
typename arithmetic_sequence_gen<NumDim_m, NumDim_m + NumDim_n, 1>::type();
|
||||
constexpr auto kDimIds =
|
||||
typename arithmetic_sequence_gen<NumDim_m + NumDim_n, NumDim, 1>::type();
|
||||
|
||||
const auto mLengths = get_container_subset(tupleOfShape, mDimIds);
|
||||
const auto nLengths = get_container_subset(tupleOfShape, nDimIds);
|
||||
const auto kLengths = get_container_subset(tupleOfShape, kDimIds);
|
||||
|
||||
// merge nd to 3d desc - [s0 * s1 * ...]
|
||||
if constexpr(NumDim > 3)
|
||||
{
|
||||
const auto desc_mnk = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(mLengths),
|
||||
make_merge_transform(nLengths),
|
||||
make_merge_transform(kLengths)),
|
||||
make_tuple(mDimIds, nDimIds, kDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return PadDescriptor_MNK(
|
||||
desc_mnk, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
|
||||
}
|
||||
else
|
||||
return PadDescriptor_MNK(
|
||||
desc, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k);
|
||||
}
|
||||
|
||||
template <index_t TupleSize>
|
||||
static auto GenerateInOutGrid3dDescTuple(Number<TupleSize>)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto) {
|
||||
if constexpr(NumDim > 3)
|
||||
{
|
||||
return MakeDescriptor_MNK({1, 1, 1}, {1, 1, 1}, 1, 1, 1, 1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return MakeDescriptor_MNK({1}, {1}, 1, 1, 1, 1, 1);
|
||||
};
|
||||
},
|
||||
Number<TupleSize>{});
|
||||
}
|
||||
|
||||
using OutGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number<NumOutput>{}));
|
||||
using InGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number<NumInput>{}));
|
||||
|
||||
using GridwiseElementwise = GridwiseElementwise_3D<InGrid3dDescTuple,
|
||||
OutGrid3dDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
ElementwiseOperation,
|
||||
MPerThread,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
InScalarPerVectorSeq,
|
||||
OutScalarPerVectorSeq>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op)
|
||||
|
||||
: lengths_(lengths),
|
||||
inStridesArray_(inStridesArray),
|
||||
outStridesArray_(outStridesArray),
|
||||
elementwise_op_(elementwise_op),
|
||||
blockSize_(256)
|
||||
{
|
||||
static_assert(NumDim_m > 0, "");
|
||||
static_assert(NumDim_n > 0, "");
|
||||
static_assert(NumDim_k > 0, "");
|
||||
|
||||
in_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
return static_cast<const DataType*>(in_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
out_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
return static_cast<DataType*>(out_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
}
|
||||
|
||||
InDataTypePointerTuple in_dev_buffers_;
|
||||
OutDataTypePointerTuple out_dev_buffers_;
|
||||
|
||||
std::array<index_t, NumDim> lengths_;
|
||||
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
|
||||
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
|
||||
|
||||
ElementwiseOperation elementwise_op_;
|
||||
index_t blockSize_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
index_t gridSize = getAvailableComputeUnitCount(stream_config) * arg.blockSize_;
|
||||
index_t num_threads_m = gridSize / (16 * 16);
|
||||
index_t num_threads_n = 16;
|
||||
index_t num_threads_k = 16;
|
||||
|
||||
auto in_grid_3d_desc_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
return MakeDescriptor_MNK(arg.lengths_,
|
||||
arg.inStridesArray_[I.value],
|
||||
gridSize,
|
||||
arg.blockSize_,
|
||||
num_threads_m,
|
||||
num_threads_n,
|
||||
num_threads_k);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_grid_3d_desc_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
return MakeDescriptor_MNK(arg.lengths_,
|
||||
arg.outStridesArray_[I.value],
|
||||
gridSize,
|
||||
arg.blockSize_,
|
||||
num_threads_m,
|
||||
num_threads_n,
|
||||
num_threads_k);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
const auto kernel = kernel_elementwise_3d<GridwiseElementwise,
|
||||
InGrid3dDescTuple,
|
||||
OutGrid3dDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
ElementwiseOperation>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gridSize),
|
||||
dim3(arg.blockSize_),
|
||||
0,
|
||||
in_grid_3d_desc_tuple,
|
||||
out_grid_3d_desc_tuple,
|
||||
arg.in_dev_buffers_,
|
||||
arg.out_dev_buffers_,
|
||||
arg.elementwise_op_,
|
||||
num_threads_m,
|
||||
num_threads_n,
|
||||
num_threads_k);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
if((ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
|
||||
ck::get_device_name() == "gfx942"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(pArg == nullptr)
|
||||
return false;
|
||||
|
||||
if(pArg->lengths_.back() % MPerThread != 0)
|
||||
return false;
|
||||
|
||||
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& strides,
|
||||
index_t scalarPerVector,
|
||||
index_t vectorDim) {
|
||||
if(strides[vectorDim] == 1 &&
|
||||
(lengths[vectorDim] % scalarPerVector == 0 ||
|
||||
lengths[vectorDim] % scalarPerVector == lengths[vectorDim]))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
if(strides[vectorDim] >= scalarPerVector)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
bool valid = true;
|
||||
static_for<0, NumInput, 1>{}([&](auto I) {
|
||||
valid = valid && IsScalarPerVectorValid(pArg->lengths_,
|
||||
pArg->inStridesArray_[I.value],
|
||||
InScalarPerVectorSeq::At(I),
|
||||
NumDim_m - 1);
|
||||
});
|
||||
|
||||
static_for<0, NumOutput, 1>{}([&](auto I) {
|
||||
valid = valid && IsScalarPerVectorValid(pArg->lengths_,
|
||||
pArg->outStridesArray_[I.value],
|
||||
OutScalarPerVectorSeq::At(I),
|
||||
NumDim - 1);
|
||||
});
|
||||
|
||||
return valid;
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(lengths,
|
||||
inStridesArray,
|
||||
outStridesArray,
|
||||
in_dev_buffers,
|
||||
out_dev_buffers,
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
}
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,7 +9,7 @@
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"

@@ -190,7 +190,8 @@ struct DeviceElementwiseImpl
ThreadClusterArrangeOrder,
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
false>;
I1,
I0>;

using GridwiseElementwiseOpSameInOutVectorDim = GridwiseElementwise<InGridDescTuple,
OutGridDescTuple,
@@ -206,7 +207,8 @@ struct DeviceElementwiseImpl
ThreadClusterArrangeOrder,
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
true>;
I1,
I1>;

struct Argument : public BaseArgument
{

@@ -1,327 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/stream_utility.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
index_t NumDim, // The max dim of input tensors
|
||||
// the tensors descs have to be aligned, such that
|
||||
// the innermost dim is the contiguous one.
|
||||
index_t MPerThread, // How many elements per thread to read
|
||||
typename InScalarPerVectorSeq, // Scalar per vec for each Input
|
||||
typename OutScalarPerVectorSeq> // Scalar per vec for each Output
|
||||
struct DeviceElementwiseImpl
|
||||
: public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
|
||||
{
|
||||
static constexpr int NumInput = InDataTypeTuple::Size();
|
||||
static constexpr int NumOutput = OutDataTypeTuple::Size();
|
||||
|
||||
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
|
||||
NumOutput == OutScalarPerVectorSeq::Size(),
|
||||
"Tuple size is inconsistent with the number of in/out!");
|
||||
|
||||
static auto GenerateInDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<const DataType*>(nullptr);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
};
|
||||
|
||||
static auto GenerateOutDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<DataType*>(nullptr);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
};
|
||||
|
||||
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
|
||||
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
|
||||
|
||||
template <typename Desc_M>
|
||||
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
|
||||
const auto m = desc_m.GetLength(I0);
|
||||
const index_t loop_step = gridSize * blockSize * MPerThread;
|
||||
const auto pad = math::integer_least_multiple(m, loop_step) - m;
|
||||
const auto desc_m_pad =
|
||||
transform_tensor_descriptor(desc_m,
|
||||
make_tuple(make_right_pad_transform(m, pad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return desc_m_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M(const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& stride,
|
||||
index_t gridSize,
|
||||
index_t blockSize)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
|
||||
// merge nd to 1d desc - [s0 * s1 * ...]
|
||||
if constexpr(NumDim > 1)
|
||||
{
|
||||
const auto desc_m = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(tupleOfShape)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim>{})),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
|
||||
}
|
||||
else
|
||||
return PadDescriptor_M_1d(desc, gridSize, blockSize);
|
||||
}
|
||||
|
||||
template <index_t TupleSize>
|
||||
static auto GenerateInOutGrid1dDescTuple(Number<TupleSize>)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto) {
|
||||
if constexpr(NumDim > 1)
|
||||
{
|
||||
return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return MakeDescriptor_M({1}, {1}, 1, 1);
|
||||
};
|
||||
},
|
||||
Number<TupleSize>{});
|
||||
};
|
||||
|
||||
using InGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumInput>{}));
|
||||
using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumOutput>{}));
|
||||
|
||||
using GridwiseElementwise = GridwiseElementwise_1D<InGrid1dDescTuple,
|
||||
OutGrid1dDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
ElementwiseOperation,
|
||||
MPerThread,
|
||||
InScalarPerVectorSeq,
|
||||
OutScalarPerVectorSeq>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op)
|
||||
|
||||
: lengths_(lengths),
|
||||
inStridesArray_(inStridesArray),
|
||||
outStridesArray_(outStridesArray),
|
||||
elementwise_op_(elementwise_op),
|
||||
blockSize_(256)
|
||||
{
|
||||
in_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
return static_cast<const DataType*>(in_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
out_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
return static_cast<DataType*>(out_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
}
|
||||
|
||||
InDataTypePointerTuple in_dev_buffers_;
|
||||
OutDataTypePointerTuple out_dev_buffers_;
|
||||
|
||||
std::array<index_t, NumDim> lengths_;
|
||||
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
|
||||
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
|
||||
|
||||
ElementwiseOperation elementwise_op_;
|
||||
index_t blockSize_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
index_t gridSize = getAvailableComputeUnitCount(stream_config);
|
||||
|
||||
auto in_grid_1d_desc_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
return MakeDescriptor_M(
|
||||
arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_grid_1d_desc_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
return MakeDescriptor_M(
|
||||
arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
const auto kernel = kernel_elementwise_1d<GridwiseElementwise,
|
||||
InGrid1dDescTuple,
|
||||
OutGrid1dDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
ElementwiseOperation>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gridSize),
|
||||
dim3(arg.blockSize_),
|
||||
0,
|
||||
in_grid_1d_desc_tuple,
|
||||
out_grid_1d_desc_tuple,
|
||||
arg.in_dev_buffers_,
|
||||
arg.out_dev_buffers_,
|
||||
arg.elementwise_op_);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(arg.lengths_.back() % MPerThread != 0)
|
||||
return false;
|
||||
|
||||
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& strides,
|
||||
index_t scalarPerVector) {
|
||||
if(strides.back() == 1 && lengths.back() % scalarPerVector == 0)
|
||||
return true;
|
||||
|
||||
if(strides.back() != 1 && scalarPerVector == 1)
|
||||
return true;
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
bool valid = true;
|
||||
static_for<0, NumInput, 1>{}([&](auto I) {
|
||||
if(!IsScalarPerVectorValid(
|
||||
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
|
||||
valid = valid && false;
|
||||
});
|
||||
|
||||
static_for<0, NumOutput, 1>{}([&](auto I) {
|
||||
if(!IsScalarPerVectorValid(
|
||||
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
|
||||
valid = valid && false;
|
||||
});
|
||||
|
||||
return valid;
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op)
|
||||
{
|
||||
return Argument{lengths,
|
||||
inStridesArray,
|
||||
outStridesArray,
|
||||
in_dev_buffers,
|
||||
out_dev_buffers,
|
||||
elementwise_op};
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(lengths,
|
||||
inStridesArray,
|
||||
outStridesArray,
|
||||
in_dev_buffers,
|
||||
out_dev_buffers,
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceElementwiseImpl<" ;
|
||||
str << "NumDim_" << NumDim << ",";
|
||||
str << "MPerThread_" << MPerThread << ",";
|
||||
|
||||
str << "InScalarPerVector";
|
||||
static_for<0, InScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << InScalarPerVectorSeq::At(i).value; });
|
||||
str << ",";
|
||||
str << "OutScalarPerVector";
|
||||
static_for<0, OutScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << OutScalarPerVectorSeq::At(i).value; });
|
||||
|
||||
str << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -19,6 +19,10 @@ namespace ck {
namespace tensor_operation {
namespace device {

/**
 * \note This structure is deprecated (left for backwards compatibility). Please use
 * DeviceElementwiseImpl from device_elementwise_dynamic_vector_dims_impl.hpp.
 */
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,

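For new code the note above points to the dynamic-vector-dims implementation; a hedged migration line (only the file name comes from the note, the impl/ directory is assumed from the layout of the other device headers in this change):

#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
// DeviceElementwiseImpl from this header supersedes the deprecated structure declared below.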
@@ -15,7 +15,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
@@ -522,7 +522,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
Sequence<0, 1>,
decltype(MakeElementwiseInputSequence()),
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
true>;
I1,
I1>;

// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =

@@ -814,8 +814,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// check device
if(get_device_name() == "gfx908")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t>))
// FIXME: re-enable fp64 when SWDEV-335738 is fixed
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
return false;
}

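The original condition listed float twice, so double accumulation was already rejected on gfx908; the replacement states the supported set directly and keeps fp64 out until SWDEV-335738 is resolved. The same predicate as a standalone sketch (hypothetical helper, using std::is_same_v in place of the unqualified is_same_v in the device code):

#include <cstdint>
#include <type_traits>

// Accumulator types accepted on gfx908 after this change; fp64 stays disabled for now.
template <typename AccDataType>
constexpr bool gfx908_acc_supported_v =
    std::is_same_v<AccDataType, float> || std::is_same_v<AccDataType, std::int32_t>;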
@@ -19,7 +19,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
@@ -252,7 +252,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
Sequence<0, 1>,
ElementwiseInputSequence,
ck::Sequence<CDEShuffleBlockTransferScalarPerVector_NPerBlock>,
true>;
I1,
I1>;

// Block2CTileMap configuration parameter.
static constexpr index_t B2E_M01 = 8;

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -8,10 +8,13 @@

#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp"

#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
@@ -36,9 +39,10 @@ struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataTyp
using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert;

static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t loop_step)
static auto PadDescriptor_M_1d(Desc_M& desc_m, index_t loop_step)
{
const auto m = desc_m.GetLength(I0);
const auto pad = math::integer_least_multiple(m, loop_step) - m;
@@ -56,7 +60,18 @@ struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataTyp
return PadDescriptor_M_1d(desc_m, loop_step);
}

template <typename Desc_M>
static auto ExpendDescFirstDim(Desc_M desc_m)
{
return transform_tensor_descriptor(
desc_m,
make_tuple(make_unmerge_transform(make_tuple(I1, desc_m.GetLength(I0)))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}

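ExpendDescFirstDim lets the max-pool backward cast reuse the tiled 2D elementwise kernel on what is really a flat buffer: the unmerge transform views the length-M descriptor as a (1, M) one. The index mapping it implies, as a sketch (hypothetical helper):

// A (row, col) coordinate from the 2D kernel maps back to the original 1D offset;
// row is always 0 in the (1, M) view created above.
index_t flat_offset_from_2d(index_t row, index_t col, index_t M) { return row * M + col; }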
using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1));
using InOutGrid2dDesc = decltype(ExpendDescFirstDim(InOutGrid1dDesc{}));

using GridwisePutElementSet = GridwisePutElement_1D<InOutGrid1dDesc,
DOutDataType,
@@ -74,14 +89,30 @@ struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataTyp
InMemoryDataOperationEnum::AtomicAdd,
InOutVectorSize>;

using GridwiseCasting = GridwiseElementwise_1D<Tuple<InOutGrid1dDesc>,
Tuple<InOutGrid1dDesc>,
Tuple<const DInDataType_AutomicAddPreCast*>,
Tuple<DInDataType*>,
UnaryConvert,
InOutVectorSize,
Sequence<InOutVectorSize>,
Sequence<InOutVectorSize>>;
static constexpr index_t BlockSize = 256;
static constexpr index_t MPerThread = 1;
static constexpr index_t NPerThread = InOutVectorSize;
static constexpr index_t MPerBlock = 1;
static constexpr index_t NPerBlock = BlockSize * NPerThread;

using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

using GridwiseCasting = GridwiseElementwise<Tuple<InOutGrid2dDesc>,
Tuple<InOutGrid2dDesc>,
Tuple<const DInDataType_AutomicAddPreCast*>,
Tuple<DInDataType*>,
Block2TileMap,
UnaryConvert,
BlockSize,
MPerBlock,
NPerBlock,
MPerThread,
NPerThread,
Sequence<0, 1>,
Sequence<InOutVectorSize>,
Sequence<InOutVectorSize>,
I1,
I1>;

struct Argument : public BaseArgument
{
@@ -98,7 +129,7 @@ struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataTyp
p_din_{p_din},
dout_length_raw_{dout_length},
din_length_raw_{din_length},
blockSize_{256},
blockSize_{BlockSize},
windowOverlap_{false}
{
for(size_t i = 0; i < window_lengths.size(); ++i)
@@ -195,12 +226,13 @@ struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataTyp
PassThrough>;

const auto cast_kernel =
kernel_elementwise_1d<GridwiseCasting,
Tuple<InOutGrid1dDesc>,
Tuple<InOutGrid1dDesc>,
Tuple<const DInDataType_AutomicAddPreCast*>,
Tuple<DInDataType*>,
UnaryConvert>;
kernel_elementwise<GridwiseCasting,
Tuple<InOutGrid2dDesc>,
Tuple<InOutGrid2dDesc>,
Tuple<const DInDataType_AutomicAddPreCast*>,
Tuple<DInDataType*>,
Block2TileMap,
UnaryConvert>;

float elapsed_time = launch_and_time_kernel(
stream_config,
@@ -214,16 +246,25 @@ struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataTyp
static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
PassThrough{});

InOutGrid2dDesc din_grid_desc_2d = ExpendDescFirstDim(din_grid_desc);
const index_t M = din_grid_desc_2d.GetLength(I0);
const index_t N = din_grid_desc_2d.GetLength(I1);
const auto block_2_tile_map = Block2TileMap(M, N);
const auto cast_kernel_grid_size =
block_2_tile_map.CalculateGridSize(din_grid_desc_2d);

elapsed_time += launch_and_time_kernel(
stream_config,
cast_kernel,
dim3(gridSize),
dim3(cast_kernel_grid_size),
dim3(arg.blockSize_),
0,
ck::make_tuple(din_grid_desc),
ck::make_tuple(din_grid_desc),
static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
arg.p_din_,
ck::make_tuple(din_grid_desc_2d),
ck::make_tuple(din_grid_desc_2d),
ck::make_tuple(
static_cast<const DInDataType_AutomicAddPreCast*>(arg.p_workspace_)),
ck::make_tuple(arg.p_din_),
block_2_tile_map,
UnaryConvert{});

return elapsed_time;

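With the constants above (MPerBlock = 1, NPerBlock = BlockSize * InOutVectorSize = 256 * InOutVectorSize), the cast launch reduces to one workgroup per NPerBlock-wide slice of the padded length; a worked sketch of what CalculateGridSize is expected to return under that assumption (hypothetical helper):

// Example: InOutVectorSize = 4 -> NPerBlock = 1024; N = 1,000,000 -> 977 workgroups.
index_t cast_grid_size(index_t N, index_t NPerBlock) { return (N + NPerBlock - 1) / NPerBlock; }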