mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add elementwise with dynamic vector dim (#1198)
* Add elementwise with dynamic vector dim
* Reduce number of instaces
* Fixes
* Fixes
[ROCm/composable_kernel commit: 9c052804a7]
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
@@ -20,15 +20,20 @@ using F32 = float;
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using DeviceElementwisePermuteInstance =
|
||||
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
PassThrough, // Elementwise op
|
||||
4, // NumDim
|
||||
8, // MPerThread
|
||||
ck::Sequence<8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<1>>; // OutScalarPerVectorSeq
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
|
||||
ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
PassThrough, // Elementwise
|
||||
4, // NumDim
|
||||
256, // BlockSize
|
||||
128, // M0PerBlock
|
||||
128, // M1PerBlock
|
||||
8, // M0PerThread
|
||||
8, // M1PerThread
|
||||
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
|
||||
ck::Sequence<8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<8>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
@@ -21,26 +21,23 @@ using F32 = float;
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
using DeviceElementwisePermuteInstance =
|
||||
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
PassThrough, // ElementwiseOp
|
||||
UnaryOp, // UnaryOp
|
||||
Scale, // Scalar
|
||||
4, // NumDim
|
||||
8, // MPerThread
|
||||
ck::Sequence<1>, // InScalarPerVectorSeq
|
||||
ck::Sequence<1>>; // OutScalarPerVectorSeq
|
||||
using UnaryOp = ck::tensor_operation::element_wise::Scale;
|
||||
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
|
||||
ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
UnaryOp, // UnaryOp
|
||||
4, // NumDim
|
||||
256, // BlockSize
|
||||
128, // M0PerBlock
|
||||
128, // M1PerBlock
|
||||
8, // M0PerThread
|
||||
8, // M1PerThread
|
||||
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
|
||||
ck::Sequence<8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<8>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc,
|
||||
const HostTensorA& A_nchw,
|
||||
FunctorA functor_a,
|
||||
FunctorB functor_b,
|
||||
float scale)
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
|
||||
{
|
||||
std::size_t N = A_nchw.mDesc.GetLengths()[0];
|
||||
std::size_t C = A_nchw.mDesc.GetLengths()[1];
|
||||
@@ -51,11 +48,8 @@ void host_elementwise4D(HostTensorB& B_nhwc,
|
||||
for(std::size_t c = 0; c < C; ++c)
|
||||
for(std::size_t n = 0; n < N; ++n)
|
||||
{
|
||||
ADataType tmp_val;
|
||||
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
|
||||
functor_b(tmp_val, a_val);
|
||||
functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)],
|
||||
scale * tmp_val);
|
||||
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,14 +98,8 @@ int main()
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths,
|
||||
{a_strides},
|
||||
{b_strides},
|
||||
input,
|
||||
output,
|
||||
PassThrough{},
|
||||
UnaryOp{},
|
||||
Scale{scale});
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
|
||||
|
||||
if(!broadcastPermute.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
@@ -143,7 +131,7 @@ int main()
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(nhwc);
|
||||
host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale);
|
||||
host_elementwise4D(host_b, a, UnaryOp{scale});
|
||||
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
@@ -20,36 +20,31 @@ using F32 = float;
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
using DeviceElementwisePermuteInstance =
|
||||
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
PassThrough, // ElementwiseOp
|
||||
UnaryOp, // UnaryOp
|
||||
Scale, // Scalar
|
||||
4, // NumDim
|
||||
8, // MPerThread
|
||||
ck::Sequence<8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<1>>; // OutScalarPerVectorSeq
|
||||
using UnaryOp = ck::tensor_operation::element_wise::Scale;
|
||||
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
|
||||
ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
UnaryOp, // UnaryOp
|
||||
4, // NumDim
|
||||
256, // BlockSize
|
||||
128, // M0PerBlock
|
||||
128, // M1PerBlock
|
||||
8, // M0PerThread
|
||||
8, // M1PerThread
|
||||
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
|
||||
ck::Sequence<8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<8>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc,
|
||||
const HostTensorA& A_nchw,
|
||||
FunctorA functor_a,
|
||||
FunctorB functor_b,
|
||||
float scale)
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
|
||||
{
|
||||
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
|
||||
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
|
||||
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
|
||||
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
|
||||
{
|
||||
ADataType tmp_val;
|
||||
auto a_val = A_nchw(n, c, h, w);
|
||||
functor_b(tmp_val, a_val);
|
||||
functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
|
||||
functor(B_nhwc(n, h, w, c), a_val);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,14 +81,8 @@ int main()
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths,
|
||||
{a_strides},
|
||||
{b_strides},
|
||||
input,
|
||||
output,
|
||||
PassThrough{},
|
||||
UnaryOp{},
|
||||
Scale{scale});
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
|
||||
|
||||
if(!broadcastPermute.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
@@ -125,7 +114,7 @@ int main()
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(nhwc);
|
||||
host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale);
|
||||
host_elementwise4D(host_b, a, UnaryOp{scale});
|
||||
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
@@ -20,26 +20,23 @@ using F32 = float;
|
||||
using ADataType = F32;
|
||||
using BDataType = F32;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
using DeviceElementwisePermuteInstance =
|
||||
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
PassThrough, // ElementwiseOp
|
||||
UnaryOp, // UnaryOp
|
||||
Scale, // Scalar
|
||||
4, // NumDim
|
||||
1, // MPerThread
|
||||
ck::Sequence<1>, // InScalarPerVectorSeq
|
||||
ck::Sequence<1>>; // OutScalarPerVectorSeq
|
||||
using UnaryOp = ck::tensor_operation::element_wise::Scale;
|
||||
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
|
||||
ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
UnaryOp, // UnaryOp
|
||||
4, // NumDim
|
||||
256, // BlockSize
|
||||
128, // M0PerBlock
|
||||
128, // M1PerBlock
|
||||
8, // M0PerThread
|
||||
8, // M1PerThread
|
||||
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
|
||||
ck::Sequence<1>, // InScalarPerVectorSeq
|
||||
ck::Sequence<1>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc,
|
||||
const HostTensorA& A_nchw,
|
||||
FunctorA functor_a,
|
||||
FunctorB functor_b,
|
||||
float scale)
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
|
||||
{
|
||||
std::size_t N = A_nchw.mDesc.GetLengths()[0];
|
||||
std::size_t C = A_nchw.mDesc.GetLengths()[1];
|
||||
@@ -50,11 +47,8 @@ void host_elementwise4D(HostTensorB& B_nhwc,
|
||||
for(std::size_t c = 0; c < C; ++c)
|
||||
for(std::size_t n = 0; n < N; ++n)
|
||||
{
|
||||
ADataType tmp_val;
|
||||
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
|
||||
functor_b(tmp_val, a_val);
|
||||
functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)],
|
||||
scale * tmp_val);
|
||||
functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,14 +98,8 @@ int main()
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths,
|
||||
{a_strides},
|
||||
{b_strides},
|
||||
input,
|
||||
output,
|
||||
PassThrough{},
|
||||
UnaryOp{},
|
||||
Scale{scale});
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
|
||||
|
||||
if(!broadcastPermute.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
@@ -143,7 +131,7 @@ int main()
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(nhwc);
|
||||
host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale);
|
||||
host_elementwise4D(host_b, a, UnaryOp{scale});
|
||||
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
@@ -20,36 +20,31 @@ using F32 = float;
|
||||
using ADataType = F32;
|
||||
using BDataType = F32;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
using DeviceElementwisePermuteInstance =
|
||||
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
PassThrough, // ElementwiseOp
|
||||
UnaryOp, // UnaryOp
|
||||
Scale, // Scalar
|
||||
4, // NumDim
|
||||
8, // MPerThread
|
||||
ck::Sequence<8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<1>>; // OutScalarPerVectorSeq
|
||||
using UnaryOp = ck::tensor_operation::element_wise::Scale;
|
||||
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
|
||||
ck::Tuple<ADataType>, // InDataTypeTuple
|
||||
ck::Tuple<BDataType>, // OutDataTypeTuple
|
||||
UnaryOp, // UnaryOp
|
||||
4, // NumDim
|
||||
256, // BlockSize
|
||||
128, // M0PerBlock
|
||||
128, // M1PerBlock
|
||||
8, // M0PerThread
|
||||
8, // M1PerThread
|
||||
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
|
||||
ck::Sequence<8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<8>>; // OutScalarPerVectorSeq
|
||||
|
||||
template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc,
|
||||
const HostTensorA& A_nchw,
|
||||
FunctorA functor_a,
|
||||
FunctorB functor_b,
|
||||
float scale)
|
||||
template <typename HostTensorA, typename HostTensorB, typename Functor>
|
||||
void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
|
||||
{
|
||||
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
|
||||
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
|
||||
for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
|
||||
for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
|
||||
{
|
||||
ADataType tmp_val;
|
||||
auto a_val = A_nchw(n, c, h, w);
|
||||
functor_b(tmp_val, a_val);
|
||||
functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
|
||||
functor(B_nhwc(n, h, w, c), a_val);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,14 +81,8 @@ int main()
|
||||
ck::ranges::copy(nchw, ab_lengths.begin());
|
||||
|
||||
auto broadcastPermute = DeviceElementwisePermuteInstance{};
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths,
|
||||
{a_strides},
|
||||
{b_strides},
|
||||
input,
|
||||
output,
|
||||
PassThrough{},
|
||||
UnaryOp{},
|
||||
Scale{scale});
|
||||
auto argument = broadcastPermute.MakeArgumentPointer(
|
||||
ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale});
|
||||
|
||||
if(!broadcastPermute.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
@@ -125,7 +114,7 @@ int main()
|
||||
{
|
||||
b_device_buf.FromDevice(b.mData.data());
|
||||
Tensor<BDataType> host_b(nhwc);
|
||||
host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale);
|
||||
host_elementwise4D(host_b, a, UnaryOp{scale});
|
||||
|
||||
pass &=
|
||||
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
/**
|
||||
* @brief Blockwise data transfer
|
||||
*
|
||||
* This version does following things to avoid scratch memory issue
|
||||
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
*
|
||||
*/
|
||||
template <typename ThreadGroup,
|
||||
typename ElementwiseOperation,
|
||||
typename DstInMemOps, // Sequence
|
||||
typename BlockSliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcDatas,
|
||||
typename DstDatas,
|
||||
typename SrcDescs,
|
||||
typename DstDescs,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
typename SrcsScalarPerVector, // Sequence
|
||||
typename DstsScalarPerVector, // Sequence
|
||||
typename SrcsScalarStrideInVector, // Sequence
|
||||
typename DstsScalarStrideInVector, // Sequence
|
||||
typename ThreadTransferSrcsResetCoordinateAfterRun, // Sequence
|
||||
typename ThreadTransferDstsResetCoordinateAfterRun, // Sequence
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_v4r2
|
||||
{
|
||||
static constexpr index_t nDim =
|
||||
remove_reference_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
|
||||
static constexpr index_t nSrc = SrcDescs::Size();
|
||||
static constexpr index_t nDst = DstDescs::Size();
|
||||
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r2(
|
||||
const SrcDescs& src_descs,
|
||||
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
|
||||
const DstDescs& dst_descs,
|
||||
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
|
||||
const ElementwiseOperation& element_op)
|
||||
: threadwise_transfer_(src_descs,
|
||||
StaticallyIndexedArray<Index, nSrc>{},
|
||||
dst_descs,
|
||||
StaticallyIndexedArray<Index, nDst>{},
|
||||
element_op)
|
||||
|
||||
{
|
||||
static_assert(nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == SrcDimAccessOrder::Size() && nDim == SrcDimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_for<0, nSrc, 1>{}([&](auto src_i) {
|
||||
static_assert(nDim ==
|
||||
remove_cvref_t<tuple_element_t<src_i, SrcDescs>>::GetNumOfDimension(),
|
||||
"wrong! nDim not consistent");
|
||||
});
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto dst_i) {
|
||||
static_assert(nDim ==
|
||||
remove_cvref_t<tuple_element_t<dst_i, DstDescs>>::GetNumOfDimension(),
|
||||
"wrong! nDim not consistent");
|
||||
});
|
||||
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! ThreadGroup::GetNumOfThread() too small");
|
||||
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
const auto src_thread_slice_origins = generate_tuple(
|
||||
[&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
|
||||
Number<nSrc>{});
|
||||
|
||||
const auto dst_thread_slice_origins = generate_tuple(
|
||||
[&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
|
||||
Number<nDst>{});
|
||||
|
||||
threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
|
||||
threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers& dst_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
|
||||
__device__ void Run(const SrcDescs& src_descs,
|
||||
const SrcBuffer& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffer& dst_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, thread_scratch_id);
|
||||
RunWrite(dst_descs, dst_bufs, thread_scratch_id);
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(src_descs, step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_descs, step);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v3r2<decltype(thread_slice_lengths),
|
||||
ElementwiseOperation,
|
||||
DstInMemOps,
|
||||
SrcDatas,
|
||||
DstDatas,
|
||||
SrcDescs,
|
||||
DstDescs,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcsScalarPerVector,
|
||||
DstsScalarPerVector,
|
||||
SrcsScalarStrideInVector,
|
||||
DstsScalarStrideInVector,
|
||||
ThreadTransferSrcsResetCoordinateAfterRun,
|
||||
ThreadTransferDstsResetCoordinateAfterRun,
|
||||
NumThreadScratch>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,422 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/stream_utility.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
index_t NumDim,
|
||||
index_t BlockSize,
|
||||
index_t M0PerBlock,
|
||||
index_t M1PerBlock,
|
||||
index_t M0PerThread,
|
||||
index_t M1PerThread,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename InScalarPerVectorSeq,
|
||||
typename OutScalarPerVectorSeq>
|
||||
struct DeviceElementwiseImpl
|
||||
: public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
|
||||
{
|
||||
static constexpr int NumInput = InDataTypeTuple::Size();
|
||||
static constexpr int NumOutput = OutDataTypeTuple::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
|
||||
NumOutput == OutScalarPerVectorSeq::Size(),
|
||||
"Tuple size is inconsistent with the number of in/out!");
|
||||
|
||||
static auto GenerateInDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<const DataType*>(nullptr);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
};
|
||||
|
||||
static auto GenerateOutDataTypePointerTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
|
||||
return static_cast<DataType*>(nullptr);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
};
|
||||
|
||||
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
|
||||
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
|
||||
|
||||
static index_t GetLowestStrideDim(const std::array<index_t, NumDim>& strides)
|
||||
{
|
||||
index_t most_continous_dim = NumDim - 1;
|
||||
index_t most_continous_dim_stride = strides[most_continous_dim];
|
||||
for(index_t dim = 0; dim < NumDim; dim++)
|
||||
{
|
||||
if(strides[dim] < most_continous_dim_stride)
|
||||
{
|
||||
most_continous_dim_stride = strides[dim];
|
||||
most_continous_dim = dim;
|
||||
}
|
||||
}
|
||||
return most_continous_dim;
|
||||
}
|
||||
|
||||
template <typename InOutDescriptor>
|
||||
static auto PadInputOutputDescriptor(const InOutDescriptor& desc)
|
||||
{
|
||||
const auto M0 = desc.GetLength(I0);
|
||||
const auto M1 = desc.GetLength(I1);
|
||||
const auto pad_M0 = math::integer_divide_ceil(M0, M0PerThread) * M0PerThread - M0;
|
||||
const auto pad_M1 = math::integer_divide_ceil(M1, M1PerThread) * M1PerThread - M1;
|
||||
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_right_pad_transform(M0, pad_M0), make_right_pad_transform(M1, pad_M1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return padded_desc;
|
||||
}
|
||||
|
||||
static auto GenerateBatchDimsLenghtsTuple(const std::array<index_t, NumDim>& lengths,
|
||||
const index_t M0_dim,
|
||||
const index_t M1_dim)
|
||||
{
|
||||
// Generate batch dims, they will be merged to M0
|
||||
// Add one more dim than needed in case that M0 is equal to M1
|
||||
// If M0 is equal to M1, then will be one more batch dim
|
||||
std::array<index_t, NumDim - 1> batch_dims;
|
||||
index_t batch_dim = 0;
|
||||
for(index_t i = 0; i < NumDim; i++)
|
||||
{
|
||||
if(i != M0_dim && i != M1_dim)
|
||||
{
|
||||
batch_dims[batch_dim] = lengths[i];
|
||||
batch_dim++;
|
||||
}
|
||||
}
|
||||
// Add dummy dim if M0_dim is not equal to M1_dim
|
||||
if(M0_dim != M1_dim && NumDim >= 2)
|
||||
batch_dims[NumDim - 2] = 1;
|
||||
return generate_tuple([&](auto I) { return batch_dims[I]; }, Number<NumDim - 1>{});
|
||||
}
|
||||
|
||||
static auto MakeDescriptor(const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& in_strides,
|
||||
const std::array<index_t, NumDim>& out_strides,
|
||||
const std::array<index_t, NumDim>& desc_strides)
|
||||
{
|
||||
const auto M0_dim = GetLowestStrideDim(out_strides);
|
||||
const auto M1_dim = GetLowestStrideDim(in_strides);
|
||||
|
||||
// If M0_dim is equal to M1_dim, then make M0_dim dummy
|
||||
const auto M0 = M0_dim == M1_dim ? I1 : lengths[M0_dim];
|
||||
const auto M1 = lengths[M1_dim];
|
||||
const auto M0_stride = M0_dim == M1_dim ? I1 : desc_strides[M0_dim];
|
||||
const auto M1_stride = desc_strides[M1_dim];
|
||||
|
||||
const auto batch_dims_lenghts = GenerateBatchDimsLenghtsTuple(lengths, M0_dim, M1_dim);
|
||||
const auto batch_dims_strides = GenerateBatchDimsLenghtsTuple(desc_strides, M0_dim, M1_dim);
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
concat_tuple(batch_dims_lenghts, make_tuple(M0), make_tuple(M1)),
|
||||
concat_tuple(batch_dims_strides, make_tuple(M0_stride), make_tuple(M1_stride)));
|
||||
// Merged batch dims with M0
|
||||
const auto transforms =
|
||||
make_tuple(make_merge_transform(concat_tuple(batch_dims_lenghts, make_tuple(M0))),
|
||||
make_pass_through_transform(M1));
|
||||
using BatchElemsSequence =
|
||||
typename arithmetic_sequence_gen<0, decltype(batch_dims_lenghts)::Size() + 1, 1>::type;
|
||||
const auto lower_dims = make_tuple(BatchElemsSequence{}, Sequence<NumDim>{});
|
||||
const auto upper_dims = make_tuple(Sequence<0>{}, Sequence<1>{});
|
||||
// desc: (merged_dims + M0, M1)
|
||||
auto merged_desc = transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
|
||||
return PadInputOutputDescriptor(merged_desc);
|
||||
}
|
||||
|
||||
template <index_t NumTensors>
|
||||
static auto GenerateInOutGridDescTuple()
|
||||
{
|
||||
std::array<index_t, NumDim> ones;
|
||||
for(index_t d = 0; d < NumDim; d++)
|
||||
{
|
||||
ones[d] = 1;
|
||||
}
|
||||
|
||||
return generate_tuple([&](auto) { return MakeDescriptor(ones, ones, ones, ones); },
|
||||
Number<NumTensors>{});
|
||||
};
|
||||
|
||||
using InGridDescTuple = decltype(GenerateInOutGridDescTuple<NumInput>());
|
||||
using OutGridDescTuple = decltype(GenerateInOutGridDescTuple<NumOutput>());
|
||||
|
||||
using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<M0PerBlock, M1PerBlock>;
|
||||
|
||||
using GridwiseElementwiseOp = GridwiseElementwise<InGridDescTuple,
|
||||
OutGridDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
Block2TileMap,
|
||||
ElementwiseOperation,
|
||||
BlockSize,
|
||||
M0PerBlock,
|
||||
M1PerBlock,
|
||||
M0PerThread,
|
||||
M1PerThread,
|
||||
ThreadClusterArrangeOrder,
|
||||
InScalarPerVectorSeq,
|
||||
OutScalarPerVectorSeq,
|
||||
false>;
|
||||
|
||||
using GridwiseElementwiseOpSameInOutVectorDim = GridwiseElementwise<InGridDescTuple,
|
||||
OutGridDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
Block2TileMap,
|
||||
ElementwiseOperation,
|
||||
BlockSize,
|
||||
M0PerBlock,
|
||||
M1PerBlock,
|
||||
M0PerThread,
|
||||
M1PerThread,
|
||||
ThreadClusterArrangeOrder,
|
||||
InScalarPerVectorSeq,
|
||||
OutScalarPerVectorSeq,
|
||||
true>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op)
|
||||
|
||||
: lengths_(lengths),
|
||||
inStridesArray_(inStridesArray),
|
||||
outStridesArray_(outStridesArray),
|
||||
elementwise_op_(elementwise_op)
|
||||
{
|
||||
in_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
|
||||
return static_cast<const DataType*>(in_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
out_dev_buffers_ = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
|
||||
return static_cast<DataType*>(out_dev_buffers[I.value]);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
}
|
||||
|
||||
InDataTypePointerTuple in_dev_buffers_;
|
||||
OutDataTypePointerTuple out_dev_buffers_;
|
||||
|
||||
std::array<index_t, NumDim> lengths_;
|
||||
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
|
||||
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
|
||||
|
||||
ElementwiseOperation elementwise_op_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
auto in_grid_desc_tuple = generate_tuple(
|
||||
[&](auto src_i) {
|
||||
// Use Strides from first tensor to assert that M0 dim and
|
||||
// M1 dim are the same for each tensor.
|
||||
return MakeDescriptor(arg.lengths_,
|
||||
arg.inStridesArray_[I0],
|
||||
arg.outStridesArray_[I0],
|
||||
arg.inStridesArray_[src_i]);
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_grid_desc_tuple = generate_tuple(
|
||||
[&](auto dst_i) {
|
||||
return MakeDescriptor(arg.lengths_,
|
||||
arg.inStridesArray_[I0],
|
||||
arg.outStridesArray_[I0],
|
||||
arg.outStridesArray_[dst_i]);
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
const index_t M0 = in_grid_desc_tuple.At(I0).GetLength(Number<I0>{});
|
||||
const index_t M1 = in_grid_desc_tuple.At(I0).GetLength(Number<I1>{});
|
||||
|
||||
const auto block_2_tile_map = Block2TileMap(M0, M1);
|
||||
const index_t grid_size = block_2_tile_map.CalculateGridSize(M0, M1);
|
||||
|
||||
const bool in_out_same_vector_dim = GetLowestStrideDim(arg.inStridesArray_[I0]) ==
|
||||
GetLowestStrideDim(arg.outStridesArray_[I0]);
|
||||
|
||||
const auto kernel = in_out_same_vector_dim
|
||||
? kernel_elementwise<GridwiseElementwiseOpSameInOutVectorDim,
|
||||
InGridDescTuple,
|
||||
OutGridDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
Block2TileMap,
|
||||
ElementwiseOperation>
|
||||
: kernel_elementwise<GridwiseElementwiseOp,
|
||||
InGridDescTuple,
|
||||
OutGridDescTuple,
|
||||
InDataTypePointerTuple,
|
||||
OutDataTypePointerTuple,
|
||||
Block2TileMap,
|
||||
ElementwiseOperation>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_tuple,
|
||||
out_grid_desc_tuple,
|
||||
arg.in_dev_buffers_,
|
||||
arg.out_dev_buffers_,
|
||||
block_2_tile_map,
|
||||
arg.elementwise_op_);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
const index_t M0_dim = GetLowestStrideDim(arg.inStridesArray_[I0]);
|
||||
const index_t M1_dim = GetLowestStrideDim(arg.outStridesArray_[I0]);
|
||||
|
||||
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
|
||||
const std::array<index_t, NumDim>& strides,
|
||||
index_t scalarPerVector,
|
||||
index_t M_dim) {
|
||||
if(scalarPerVector == 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
if(strides[M_dim] == 1 && lengths[M_dim] % scalarPerVector == 0)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
bool is_valid = true;
|
||||
static_for<0, NumInput, 1>{}([&](auto I) {
|
||||
static_assert(M0PerThread % InScalarPerVectorSeq::At(I) == 0 &&
|
||||
M1PerThread % InScalarPerVectorSeq::At(I) == 0);
|
||||
is_valid &= IsScalarPerVectorValid(
|
||||
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I), M0_dim);
|
||||
});
|
||||
|
||||
static_for<0, NumOutput, 1>{}([&](auto I) {
|
||||
static_assert(M0PerThread % OutScalarPerVectorSeq::At(I) == 0 &&
|
||||
M1PerThread % OutScalarPerVectorSeq::At(I) == 0);
|
||||
is_valid &= IsScalarPerVectorValid(
|
||||
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I), M1_dim);
|
||||
});
|
||||
|
||||
return is_valid;
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op)
|
||||
{
|
||||
return Argument{lengths,
|
||||
inStridesArray,
|
||||
outStridesArray,
|
||||
in_dev_buffers,
|
||||
out_dev_buffers,
|
||||
elementwise_op};
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(lengths,
|
||||
inStridesArray,
|
||||
outStridesArray,
|
||||
in_dev_buffers,
|
||||
out_dev_buffers,
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceElementwiseImpl<";
|
||||
str << NumDim << ", ";
|
||||
str << BlockSize << ", ";
|
||||
str << M0PerBlock << ", ";
|
||||
str << M1PerBlock << ", ";
|
||||
str << M0PerThread << ", ";
|
||||
str << M1PerThread << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseElementwiseFunctor,
|
||||
typename InGridDescTuple,
|
||||
typename OutGridDescTuple,
|
||||
typename InDataTypePointerTuple,
|
||||
typename OutDataTypePointerTuple,
|
||||
typename Block2TileMap,
|
||||
typename ElementwiseOperation>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_elementwise(const InGridDescTuple in_grid_desc_tuple,
|
||||
const OutGridDescTuple out_grid_desc_tuple,
|
||||
const InDataTypePointerTuple p_in_global_tuple,
|
||||
const OutDataTypePointerTuple p_out_global_tuple,
|
||||
const Block2TileMap block_2_tile_map,
|
||||
const ElementwiseOperation elementwise_op)
|
||||
{
|
||||
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
|
||||
out_grid_desc_tuple,
|
||||
p_in_global_tuple,
|
||||
p_out_global_tuple,
|
||||
block_2_tile_map,
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
template <typename InGridDescTuple,
|
||||
typename OutGridDescTuple,
|
||||
typename InDataTypePointerTuple,
|
||||
typename OutDataTypePointerTuple,
|
||||
typename Block2TileMap,
|
||||
typename ElementwiseOperation,
|
||||
index_t BlockSize,
|
||||
index_t M0PerBlock,
|
||||
index_t M1PerBlock,
|
||||
index_t M0PerThread,
|
||||
index_t M1PerThread,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename InScalarPerVectorSeq,
|
||||
typename OutScalarPerVectorSeq,
|
||||
bool InOutSameVectorDim>
|
||||
struct GridwiseElementwise
|
||||
{
|
||||
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
|
||||
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
|
||||
|
||||
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
|
||||
NumOutput == OutScalarPerVectorSeq::Size() &&
|
||||
NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(),
|
||||
"Tuple size is inconsistent with the number of in/out!");
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
__device__ static void Run(const InGridDescTuple& in_grid_desc_tuple,
|
||||
const OutGridDescTuple& out_grid_desc_tuple,
|
||||
const InDataTypePointerTuple& p_in_global_tuple,
|
||||
const OutDataTypePointerTuple& p_out_global_tuple,
|
||||
const Block2TileMap& block_2_tile_map,
|
||||
const ElementwiseOperation& elementwise_op)
|
||||
{
|
||||
|
||||
constexpr auto src_datas = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
|
||||
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
|
||||
|
||||
return DataType{};
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
constexpr auto dst_datas = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
|
||||
using DataType = remove_pointer_t<DataTypePointer>;
|
||||
|
||||
return DataType{};
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
const auto in_global_buf_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumInput>{});
|
||||
|
||||
auto out_global_buf_tuple = generate_tuple(
|
||||
[&](auto I) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumOutput>{});
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const index_t m0_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
|
||||
const index_t m1_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);
|
||||
const auto thread_grid_offset =
|
||||
make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
// If src and dst have same vector dim, then:
|
||||
// M0 dim - for src and dst vector load/store
|
||||
// else:
|
||||
// M0 dim - for dst vector load
|
||||
// M1 dim - for src vector store
|
||||
using SrcDimAccessOrder = Sequence<0, 1>;
|
||||
using DstDimAccessOrder =
|
||||
std::conditional_t<InOutSameVectorDim, Sequence<0, 1>, Sequence<1, 0>>;
|
||||
using SrcVectorDim = Number<1>;
|
||||
using DstVectorDim = std::conditional_t<InOutSameVectorDim, Number<1>, Number<0>>;
|
||||
|
||||
using ThreadClusterLengths =
|
||||
Sequence<Number<M0PerBlock / M0PerThread>{}, Number<M1PerBlock / M1PerThread>{}>;
|
||||
|
||||
auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2<
|
||||
ThisThreadBlock,
|
||||
ElementwiseOperation,
|
||||
uniform_sequence_gen_t<NumOutput, static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
|
||||
Sequence<M0PerBlock, M1PerBlock>,
|
||||
ThreadClusterLengths,
|
||||
ThreadClusterArrangeOrder,
|
||||
decltype(src_datas),
|
||||
decltype(dst_datas),
|
||||
InGridDescTuple,
|
||||
OutGridDescTuple,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim{},
|
||||
DstVectorDim{},
|
||||
InScalarPerVectorSeq,
|
||||
OutScalarPerVectorSeq,
|
||||
uniform_sequence_gen_t<NumInput, 1>,
|
||||
uniform_sequence_gen_t<NumOutput, 1>,
|
||||
uniform_sequence_gen_t<NumInput, false>,
|
||||
uniform_sequence_gen_t<NumOutput, false>>{in_grid_desc_tuple,
|
||||
thread_grid_offset,
|
||||
out_grid_desc_tuple,
|
||||
thread_grid_offset,
|
||||
elementwise_op};
|
||||
global_to_global_transfer.Run(
|
||||
in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,804 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. SrcBuffer and DstBuffer are DynamicBuffer
|
||||
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
|
||||
// 4. Use thread buffer
|
||||
template <typename SliceLengths,
|
||||
typename ElementwiseOperation,
|
||||
typename DstInMemOps, // Sequence
|
||||
typename SrcDatas,
|
||||
typename DstDatas,
|
||||
typename SrcDescs,
|
||||
typename DstDescs,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
typename SrcsScalarPerVector, // Sequence
|
||||
typename DstsScalarPerVector, // Sequence
|
||||
typename SrcsScalarStrideInVector, // Sequence
|
||||
typename DstsScalarStrideInVector, // Sequence
|
||||
typename SrcsResetCoordinateAfterRun, // control whether to move back src coordinate after
|
||||
// each RunRead(), will be fused with
|
||||
// MoveSrcSliceWindow to save addr computation
|
||||
typename DstsResetCoordinateAfterRun, // control whether to move back dst coordinate after
|
||||
// each RunWrite(), will be fused with
|
||||
// MoveDstSliceWindow to save addr computation
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadwiseTensorSliceTransfer_v3r2
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
static constexpr index_t nSrc = SrcDescs::Size();
|
||||
static constexpr index_t nDst = DstDescs::Size();
|
||||
|
||||
// return a tuple of coordiantes for a tuple of tensor
|
||||
template <typename Descs,
|
||||
typename Indices,
|
||||
enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
|
||||
static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
|
||||
{
|
||||
return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
|
||||
Number<Descs::Size()>{});
|
||||
}
|
||||
|
||||
using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray<Index, nSrc>{}));
|
||||
using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray<Index, nDst>{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(
|
||||
const SrcDescs& src_descs,
|
||||
const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
|
||||
const DstDescs& dst_descs,
|
||||
const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
|
||||
const ElementwiseOperation& element_op)
|
||||
: src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
|
||||
dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
|
||||
element_op_(element_op)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
|
||||
__device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
|
||||
const Indices& src_slice_origin_idxs)
|
||||
{
|
||||
static_for<0, nSrc, 1>{}([&](auto src_i) {
|
||||
src_coords_(src_i) =
|
||||
make_tensor_coordinate(src_descs.At(src_i), src_slice_origin_idxs[src_i]);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
|
||||
__device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
|
||||
const Indices& dst_slice_origin_idxs)
|
||||
{
|
||||
static_for<0, nDst, 1>{}([&](auto dst_i) {
|
||||
dst_coords_(dst_i) =
|
||||
make_tensor_coordinate(dst_descs.At(dst_i), dst_slice_origin_idxs[dst_i]);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access_tuple = generate_tuple(
|
||||
[&](auto src_i) {
|
||||
return generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim,
|
||||
SrcsScalarPerVector::At(src_i)>{},
|
||||
Number<nDim>{});
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
constexpr auto src_access_lengths_tuple = generate_tuple(
|
||||
[&](auto src_i) {
|
||||
return SliceLengths{} / src_scalar_per_access_tuple.At(src_i);
|
||||
static_assert(
|
||||
SliceLengths::At(SrcVectorDim) % SrcsScalarPerVector::At(src_i) == 0,
|
||||
"SliceLengths[SrcVectorDim] must be divisible by SrcsScalarPerVector");
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths_tuple = generate_tuple(
|
||||
[&](auto src_i) {
|
||||
return container_reorder_given_new2old(src_access_lengths_tuple.At(src_i),
|
||||
src_dim_access_order);
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// make forward steps
|
||||
const auto src_forward_steps_tuple = generate_tuple(
|
||||
[&](auto src_i) {
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) =
|
||||
(i.value == j.value) ? src_scalar_per_access_tuple.At(src_i)[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(src_descs.At(src_i), forward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// make backward steps
|
||||
const auto src_backward_steps_tuple = generate_tuple(
|
||||
[&](auto src_i) {
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value)
|
||||
? -src_scalar_per_access_tuple.At(src_i)[i]
|
||||
: 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(src_descs.At(src_i), backward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_for<0, nSrc, 1>{}([&](auto src_i) {
|
||||
static_ford<remove_cvref_t<decltype(ordered_src_access_lengths_tuple.At(src_i))>>{}(
|
||||
[&](auto ordered_src_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths_tuple[j] +
|
||||
ordered_src_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i]
|
||||
? ordered_src_access_idx[i]
|
||||
: ordered_src_access_lengths_tuple.At(src_i)[i] -
|
||||
1 - ordered_src_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access_tuple.At(src_i);
|
||||
}();
|
||||
|
||||
constexpr auto src_data_idx_seq =
|
||||
generate_sequence_v2([&](auto i) { return Number<src_data_idx[i]>{}; },
|
||||
Number<src_data_idx.Size()>{});
|
||||
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
src_descs.At(src_i), src_coords_.At(src_i));
|
||||
|
||||
using src_vector_type = vector_type_maker_t<tuple_element_t<src_i, SrcDatas>,
|
||||
SrcsScalarPerVector::At(src_i)>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
|
||||
// copy data from src_buf into src_vector_container
|
||||
auto src_vector_container =
|
||||
src_vector_type{src_bufs.At(src_i).template Get<src_vector_t>(
|
||||
src_coords_.At(src_i).GetOffset(), is_src_valid)};
|
||||
|
||||
// copy data from src_vector_container into src_thread_scratch_
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.At(src_i)
|
||||
.template SetAsType<src_vector_t>(
|
||||
src_data_idx_seq,
|
||||
src_vector_container.template AsType<src_vector_t>()[I0]);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_src_access_idx[i] <
|
||||
ordered_src_access_lengths_tuple.At(src_i)[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_src_access_idx[j] ==
|
||||
ordered_src_access_lengths_tuple.At(src_i)[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move src coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_descs.At(src_i),
|
||||
src_coords_.At(src_i),
|
||||
src_forward_steps_tuple.At(src_i)[src_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_descs.At(src_i),
|
||||
src_coords_.At(src_i),
|
||||
src_backward_steps_tuple.At(src_i)[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, nSrc, 1>{}([&](auto src_i) {
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcsResetCoordinateAfterRun::At(src_i))
|
||||
{
|
||||
const auto src_reset_step = make_tensor_coordinate_step(
|
||||
src_descs.At(src_i), GetSrcCoordinateResetStep<src_i>());
|
||||
|
||||
move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), src_reset_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t ThreadScratchId>
|
||||
__device__ void
|
||||
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
|
||||
{
|
||||
// TODO: Add support for CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
|
||||
// (it requires to add Elementwise support in transpose_vectors)
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
const auto src_data_refs = generate_tie(
|
||||
[&](auto src_i) -> const auto& {
|
||||
return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx];
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
auto dst_data_refs = generate_tie(
|
||||
[&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); },
|
||||
Number<nDst>{});
|
||||
unpack2(element_op_, dst_data_refs, src_data_refs);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename DstBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers& dst_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// if there is transpose, it's done here
|
||||
// TODO move this elsewhere
|
||||
TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
|
||||
|
||||
// src scalar per access on each dim
|
||||
// TODO: don't use this
|
||||
constexpr auto dst_scalar_per_access_tuple = generate_tuple(
|
||||
[&](auto dst_i) {
|
||||
return generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim,
|
||||
DstsScalarPerVector::At(dst_i)>{},
|
||||
Number<nDim>{});
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
constexpr auto dst_access_lengths_tuple = generate_tuple(
|
||||
[&](auto dst_i) { return SliceLengths{} / dst_scalar_per_access_tuple.At(dst_i); },
|
||||
Number<nDst>{});
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths_tuple = generate_tuple(
|
||||
[&](auto dst_i) {
|
||||
return container_reorder_given_new2old(dst_access_lengths_tuple.At(dst_i),
|
||||
dst_dim_access_order);
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// make forward steps
|
||||
const auto dst_forward_steps_tuple = generate_tuple(
|
||||
[&](auto dst_i) {
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) =
|
||||
(i.value == j.value) ? dst_scalar_per_access_tuple.At(dst_i)[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst_descs.At(dst_i), forward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// make backward steps
|
||||
const auto dst_backward_steps_tuple = generate_tuple(
|
||||
[&](auto dst_i) {
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value)
|
||||
? -dst_scalar_per_access_tuple.At(dst_i)[i]
|
||||
: 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst_descs.At(dst_i), backward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_for<0, nDst, 1>{}([&](auto dst_i) {
|
||||
static_ford<remove_cvref_t<decltype(ordered_dst_access_lengths_tuple.At(dst_i))>>{}(
|
||||
[&](auto ordered_dst_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths_tuple.At(dst_i)[j] +
|
||||
ordered_dst_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i]
|
||||
? ordered_dst_access_idx[i]
|
||||
: ordered_dst_access_lengths_tuple.At(dst_i)[i] -
|
||||
1 - ordered_dst_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access_tuple.At(dst_i);
|
||||
}();
|
||||
|
||||
constexpr auto dst_data_idx_seq =
|
||||
generate_sequence_v2([&](auto i) { return Number<dst_data_idx[i]>{}; },
|
||||
Number<dst_data_idx.Size()>{});
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
dst_descs.At(dst_i), dst_coords_.At(dst_i));
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<tuple_element_t<dst_i, DstDatas>,
|
||||
DstsScalarPerVector::At(dst_i)>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
|
||||
// copy data from dst_thread_scratch_ into dst_vector_container
|
||||
auto dst_vector_container = dst_vector_type{
|
||||
dst_thread_scratch_tuple_.At(dst_i).template GetAsType<dst_vector_t>(
|
||||
dst_data_idx_seq)};
|
||||
|
||||
constexpr InMemoryDataOperationEnum DstInMemOp =
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(dst_i.value));
|
||||
|
||||
// copy data from dst_vector_container to dst_buf
|
||||
dst_bufs.At(dst_i).template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_coords_.At(dst_i).GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector_container.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_dst_access_idx[i] <
|
||||
ordered_dst_access_lengths_tuple.At(dst_i)[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_dst_access_idx[j] ==
|
||||
ordered_dst_access_lengths_tuple.At(dst_i)[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move dst coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_descs.At(dst_i),
|
||||
dst_coords_.At(dst_i),
|
||||
dst_forward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_descs.At(dst_i),
|
||||
dst_coords_.At(dst_i),
|
||||
dst_backward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
static_for<0, nDst, 1>{}([&](auto dst_i) {
|
||||
if constexpr(DstsResetCoordinateAfterRun::At(dst_i))
|
||||
{
|
||||
const auto dst_reset_step = make_tensor_coordinate_step(
|
||||
dst_descs.At(dst_i), GetDstCoordinateResetStep<dst_i>());
|
||||
|
||||
move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), dst_reset_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t src_i>
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcsScalarPerVector::At(src_i)>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_lengths[I0] - 1;
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index after last iteration in RunRead(), if it has not being reset by
|
||||
// RunRead()
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_src_data_step = [&]() {
|
||||
Index reset_src_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
|
||||
|
||||
return reset_src_data_step_;
|
||||
}();
|
||||
|
||||
return reset_src_data_step;
|
||||
}
|
||||
|
||||
template <index_t dst_i>
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstsScalarPerVector::At(dst_i)>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_lengths[I0] - 1;
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
|
||||
// RunWrite()
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access.At(dst_i);
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_dst_data_step = [&]() {
|
||||
Index reset_dst_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
|
||||
|
||||
return reset_dst_data_step_;
|
||||
}();
|
||||
|
||||
return reset_dst_data_step;
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
|
||||
const Index& src_slice_origin_step_idx)
|
||||
{
|
||||
static_for<0, nSrc, 1>{}([&](auto src_i) {
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcsResetCoordinateAfterRun::At(src_i)
|
||||
? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep<src_i>();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step =
|
||||
make_tensor_coordinate_step(src_descs.At(src_i), adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), adjusted_step);
|
||||
});
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
static_for<0, nDst, 1>{}([&](auto dst_i) {
|
||||
// if dst coord was not reset by RunWrite(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
DstsResetCoordinateAfterRun::At(dst_i)
|
||||
? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep<dst_i>();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step =
|
||||
make_tensor_coordinate_step(dst_descs.At(dst_i), adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), adjusted_step);
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t src_i>
|
||||
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcsScalarPerVector::At(src_i)>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_access_lengths_and_vector_length =
|
||||
container_push_back(sequence_to_tuple_of_number(src_access_lengths),
|
||||
Number<SrcsScalarPerVector::At(src_i)>{});
|
||||
|
||||
// 1st stage of transforms
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(src_access_lengths_and_vector_length[i],
|
||||
src_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
template <index_t dst_i>
|
||||
__device__ static constexpr auto GetDstThreadScratchDescriptor()
|
||||
{
|
||||
// 1st stage of transforms
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstsScalarPerVector::At(dst_i)>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_access_lengths_and_vector_length =
|
||||
container_push_back(sequence_to_tuple_of_number(dst_access_lengths),
|
||||
Number<DstsScalarPerVector::At(dst_i)>{});
|
||||
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(dst_access_lengths_and_vector_length[i],
|
||||
dst_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto MakeSrcThreadScratchTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto src_i) {
|
||||
constexpr auto src_thread_scratch_desc =
|
||||
decltype(GetSrcThreadScratchDescriptor<src_i>()){};
|
||||
using SrcThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
tuple_element_t<src_i, SrcDatas>,
|
||||
SrcsScalarPerVector::At(src_i),
|
||||
decltype(src_thread_scratch_desc),
|
||||
true>;
|
||||
return SrcThreadScratch{};
|
||||
},
|
||||
Number<nSrc>{});
|
||||
}
|
||||
|
||||
__device__ static constexpr auto MakeDstThreadScratchTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto dst_i) {
|
||||
constexpr auto dst_thread_scratch_desc =
|
||||
decltype(GetDstThreadScratchDescriptor<dst_i>()){};
|
||||
using DstThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
tuple_element_t<dst_i, DstDatas>,
|
||||
DstsScalarPerVector::At(dst_i),
|
||||
decltype(dst_thread_scratch_desc),
|
||||
true>;
|
||||
return DstThreadScratch{};
|
||||
},
|
||||
Number<nDst>{});
|
||||
}
|
||||
|
||||
private:
|
||||
using SrcThreadScratchTuple = decltype(MakeSrcThreadScratchTuple());
|
||||
using DstThreadScratchTuple = decltype(MakeDstThreadScratchTuple());
|
||||
|
||||
StaticallyIndexedArray<SrcThreadScratchTuple, NumThreadScratch> src_thread_scratch_tuple_;
|
||||
|
||||
DstThreadScratchTuple dst_thread_scratch_tuple_;
|
||||
|
||||
SrcCoords src_coords_;
|
||||
DstCoords dst_coords_;
|
||||
const ElementwiseOperation element_op_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
@@ -19,125 +19,67 @@ namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_permute_scale_1d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
|
||||
ck::Tuple<F16>,
|
||||
PassThrough,
|
||||
element_wise::UnarySquare,
|
||||
Scale,
|
||||
1>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, element_wise::Scale, 1>>>&);
|
||||
|
||||
void add_device_permute_scale_2d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
|
||||
ck::Tuple<F16>,
|
||||
PassThrough,
|
||||
element_wise::UnarySquare,
|
||||
Scale,
|
||||
2>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, element_wise::Scale, 2>>>&);
|
||||
|
||||
void add_device_permute_scale_3d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
|
||||
ck::Tuple<F16>,
|
||||
PassThrough,
|
||||
element_wise::UnarySquare,
|
||||
Scale,
|
||||
3>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, element_wise::Scale, 3>>>&);
|
||||
|
||||
void add_device_permute_scale_4d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
|
||||
ck::Tuple<F16>,
|
||||
PassThrough,
|
||||
element_wise::UnarySquare,
|
||||
Scale,
|
||||
4>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, element_wise::Scale, 4>>>&);
|
||||
|
||||
void add_device_permute_scale_5d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
|
||||
ck::Tuple<F16>,
|
||||
PassThrough,
|
||||
element_wise::UnarySquare,
|
||||
Scale,
|
||||
5>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, element_wise::Scale, 5>>>&);
|
||||
|
||||
void add_device_permute_scale_6d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>,
|
||||
ck::Tuple<F16>,
|
||||
PassThrough,
|
||||
element_wise::UnarySquare,
|
||||
Scale,
|
||||
6>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, element_wise::Scale, 6>>>&);
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_permute_scale_1d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
|
||||
ck::Tuple<F32>,
|
||||
PassThrough,
|
||||
element_wise::UnarySquare,
|
||||
Scale,
|
||||
1>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 1>>>&);
|
||||
|
||||
void add_device_permute_scale_2d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
|
||||
ck::Tuple<F32>,
|
||||
PassThrough,
|
||||
element_wise::UnarySquare,
|
||||
Scale,
|
||||
2>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 2>>>&);
|
||||
|
||||
void add_device_permute_scale_3d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
|
||||
ck::Tuple<F32>,
|
||||
PassThrough,
|
||||
element_wise::UnarySquare,
|
||||
Scale,
|
||||
3>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 3>>>&);
|
||||
|
||||
void add_device_permute_scale_4d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
|
||||
ck::Tuple<F32>,
|
||||
PassThrough,
|
||||
element_wise::UnarySquare,
|
||||
Scale,
|
||||
4>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 4>>>&);
|
||||
|
||||
void add_device_permute_scale_5d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
|
||||
ck::Tuple<F32>,
|
||||
PassThrough,
|
||||
element_wise::UnarySquare,
|
||||
Scale,
|
||||
5>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 5>>>&);
|
||||
|
||||
void add_device_permute_scale_6d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>,
|
||||
ck::Tuple<F32>,
|
||||
PassThrough,
|
||||
element_wise::UnarySquare,
|
||||
Scale,
|
||||
6>>>&);
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, element_wise::Scale, 6>>>&);
|
||||
#endif
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
typename UnaryOperation,
|
||||
typename Scale,
|
||||
index_t NumDim>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceElementwise<InDataTypeTuple,
|
||||
OutDataTypeTuple,
|
||||
ElementwiseOperation,
|
||||
UnaryOperation,
|
||||
Scale,
|
||||
NumDim>>
|
||||
ck::tensor_operation::device::
|
||||
DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>>
|
||||
{
|
||||
using DeviceOp = DeviceElementwise<InDataTypeTuple,
|
||||
OutDataTypeTuple,
|
||||
ElementwiseOperation,
|
||||
UnaryOperation,
|
||||
Scale,
|
||||
NumDim>;
|
||||
using DeviceOp =
|
||||
DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -13,26 +13,175 @@ namespace instance {
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
|
||||
// clang-format off
|
||||
template <index_t NDims>
|
||||
template <index_t NDims,
|
||||
typename ElementwiseOp>
|
||||
using device_permute_scale_f16_instances =
|
||||
std::tuple <
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, NDims, 1, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, NDims, 8, ck::Sequence<8>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, NDims, 4, ck::Sequence<4>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, NDims, 2, ck::Sequence<2>, ck::Sequence<1>>
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 128, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 256, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 64, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 64, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 32, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 256, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 64, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 32, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 128, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 64, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 32, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
|
||||
#if 0
|
||||
// Disabled instances to improve compilation time
|
||||
// They listed here to show other possible combinations of parameters
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 32, 512, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 512, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 64, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 256, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 32, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 32, 256, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 16, 256, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 32, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 16, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 64, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 32, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 16, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 128, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 256, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 64, 128, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 64, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 32, 128, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 256, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 64, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 32, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
#endif
|
||||
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F16>, ck::Tuple<F16>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
|
||||
|
||||
>;
|
||||
|
||||
template <index_t NDims>
|
||||
template <index_t NDims,
|
||||
typename ElementwiseOp>
|
||||
using device_permute_scale_f32_instances = std::tuple<
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, NDims, 1, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, NDims, 8, ck::Sequence<8>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, NDims, 4, ck::Sequence<4>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, NDims, 2, ck::Sequence<2>, ck::Sequence<1>>
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 256, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 64, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 32, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 256, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 64, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 32, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>,
|
||||
|
||||
#if 0
|
||||
// Disabled instances to improve compilation time
|
||||
// They listed here to show other possible combinations of parameters
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 32, 512, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 512, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 256, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 64, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>,
|
||||
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 32, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 32, 256, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 16, 256, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 32, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 16, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 64, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 16, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 256, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 64, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 128, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 32, 128, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 256, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 64, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 32, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>,
|
||||
#endif
|
||||
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>,
|
||||
DeviceElementwiseImpl<ck::Tuple<F32>, ck::Tuple<F32>, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
add_instance_library(device_permute_scale_instance
|
||||
device_permute_scale_1d_instances.cpp
|
||||
device_permute_scale_2d_instances.cpp
|
||||
device_permute_scale_3d_instances.cpp
|
||||
device_permute_scale_4d_instances.cpp
|
||||
device_permute_scale_5d_instances.cpp
|
||||
device_permute_scale_6d_instances.cpp)
|
||||
device_permute_scale_1d_fp16_instances.cpp
|
||||
device_permute_scale_2d_fp16_instances.cpp
|
||||
device_permute_scale_3d_fp16_instances.cpp
|
||||
device_permute_scale_4d_fp16_instances.cpp
|
||||
device_permute_scale_5d_fp16_instances.cpp
|
||||
device_permute_scale_6d_fp16_instances.cpp
|
||||
device_permute_scale_1d_fp32_instances.cpp
|
||||
device_permute_scale_2d_fp32_instances.cpp
|
||||
device_permute_scale_3d_fp32_instances.cpp
|
||||
device_permute_scale_4d_fp32_instances.cpp
|
||||
device_permute_scale_5d_fp32_instances.cpp
|
||||
device_permute_scale_6d_fp32_instances.cpp)
|
||||
|
||||
@@ -9,18 +9,13 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_permute_scale_1d_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, 1>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f16_instances<1>{});
|
||||
}
|
||||
using Scale = element_wise::Scale;
|
||||
|
||||
void add_device_permute_scale_1d_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, 1>>>& instances)
|
||||
void add_device_permute_scale_1d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Scale, 1>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f32_instances<1>{});
|
||||
add_device_operation_instances(instances, device_permute_scale_f16_instances<1, Scale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Scale = element_wise::Scale;
|
||||
|
||||
void add_device_permute_scale_1d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Scale, 1>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f32_instances<1, Scale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,18 +9,13 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_permute_scale_2d_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, 2>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f16_instances<2>{});
|
||||
}
|
||||
using Scale = element_wise::Scale;
|
||||
|
||||
void add_device_permute_scale_2d_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, 2>>>& instances)
|
||||
void add_device_permute_scale_2d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Scale, 2>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f32_instances<2>{});
|
||||
add_device_operation_instances(instances, device_permute_scale_f16_instances<2, Scale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Scale = element_wise::Scale;
|
||||
|
||||
void add_device_permute_scale_2d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Scale, 2>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f32_instances<2, Scale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,18 +9,13 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_permute_scale_3d_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, 3>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f16_instances<3>{});
|
||||
}
|
||||
using Scale = element_wise::Scale;
|
||||
|
||||
void add_device_permute_scale_3d_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, 3>>>& instances)
|
||||
void add_device_permute_scale_3d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Scale, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f32_instances<3>{});
|
||||
add_device_operation_instances(instances, device_permute_scale_f16_instances<3, Scale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Scale = element_wise::Scale;
|
||||
|
||||
void add_device_permute_scale_3d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Scale, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f32_instances<3, Scale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,18 +9,13 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_permute_scale_4d_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, 4>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f16_instances<4>{});
|
||||
}
|
||||
using Scale = element_wise::Scale;
|
||||
|
||||
void add_device_permute_scale_4d_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, 4>>>& instances)
|
||||
void add_device_permute_scale_4d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Scale, 4>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f32_instances<4>{});
|
||||
add_device_operation_instances(instances, device_permute_scale_f16_instances<4, Scale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Scale = element_wise::Scale;
|
||||
|
||||
void add_device_permute_scale_4d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Scale, 4>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f32_instances<4, Scale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,18 +9,13 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_permute_scale_5d_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, 5>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f16_instances<5>{});
|
||||
}
|
||||
using Scale = element_wise::Scale;
|
||||
|
||||
void add_device_permute_scale_5d_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, 5>>>& instances)
|
||||
void add_device_permute_scale_5d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Scale, 5>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f32_instances<5>{});
|
||||
add_device_operation_instances(instances, device_permute_scale_f16_instances<5, Scale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Scale = element_wise::Scale;
|
||||
|
||||
void add_device_permute_scale_5d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Scale, 5>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f32_instances<5, Scale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,18 +9,13 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_permute_scale_6d_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Pass, UnaryOp, Scale, 6>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f16_instances<6>{});
|
||||
}
|
||||
using Scale = element_wise::Scale;
|
||||
|
||||
void add_device_permute_scale_6d_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Pass, UnaryOp, Scale, 6>>>& instances)
|
||||
void add_device_permute_scale_6d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F16>, ck::Tuple<F16>, Scale, 6>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f32_instances<6>{});
|
||||
add_device_operation_instances(instances, device_permute_scale_f16_instances<6, Scale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Scale = element_wise::Scale;
|
||||
|
||||
void add_device_permute_scale_6d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceElementwise<ck::Tuple<F32>, ck::Tuple<F32>, Scale, 6>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_permute_scale_f32_instances<6, Scale>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -8,9 +8,9 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp"
|
||||
|
||||
@@ -21,23 +21,12 @@
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
|
||||
namespace ck {
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename ScaleElementOp>
|
||||
template <typename HostTensorA, typename HostTensorB, typename ElementOp>
|
||||
void reference_permute_scale(HostTensorB& b_tensor,
|
||||
const HostTensorA& a_tensor,
|
||||
AElementOp a_tensor_op,
|
||||
BElementOp b_tensor_op,
|
||||
ScaleElementOp scale_op)
|
||||
ElementOp tensor_op)
|
||||
{
|
||||
b_tensor.ForEach([&](auto& self, auto idx) {
|
||||
auto tmp_val = a_tensor(idx);
|
||||
b_tensor_op(tmp_val, tmp_val);
|
||||
scale_op(tmp_val, tmp_val);
|
||||
a_tensor_op(self(idx), tmp_val);
|
||||
});
|
||||
b_tensor.ForEach([&](auto& self, auto idx) { tensor_op(self(idx), a_tensor(idx)); });
|
||||
}
|
||||
|
||||
namespace profiler {
|
||||
@@ -54,9 +43,7 @@ bool profile_permute_scale_impl(int do_verification,
|
||||
bool pass = true;
|
||||
bool instance_found = false;
|
||||
|
||||
using ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
using ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
float scale = 2.f;
|
||||
|
||||
Tensor<ADataType> a(lengths_vector, input_strides_vector);
|
||||
@@ -80,12 +67,8 @@ bool profile_permute_scale_impl(int do_verification,
|
||||
|
||||
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ADataType>,
|
||||
ck::Tuple<BDataType>,
|
||||
ElementOp,
|
||||
UnaryOp,
|
||||
Scale,
|
||||
NumDim>;
|
||||
using DeviceOp = ck::tensor_operation::device::
|
||||
DeviceElementwise<ck::Tuple<ADataType>, ck::Tuple<BDataType>, ElementOp, NumDim>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
@@ -100,7 +83,7 @@ bool profile_permute_scale_impl(int do_verification,
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
reference_permute_scale(host_b, a, ElementOp{}, UnaryOp{}, Scale{scale});
|
||||
reference_permute_scale(host_b, a, ElementOp{scale});
|
||||
}
|
||||
|
||||
auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
|
||||
@@ -113,14 +96,8 @@ bool profile_permute_scale_impl(int do_verification,
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(lengths,
|
||||
{input_strides},
|
||||
{output_strides},
|
||||
input,
|
||||
output,
|
||||
ElementOp{},
|
||||
UnaryOp{},
|
||||
Scale{scale});
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
lengths, {input_strides}, {output_strides}, input, output, ElementOp{scale});
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
@@ -141,6 +118,7 @@ bool profile_permute_scale_impl(int do_verification,
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "host_b: ", host_b.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,20 @@ static void print_helper_msg()
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void init_strides(const std::vector<ck::index_t>& lengths,
|
||||
const std::vector<ck::index_t>& dims_order,
|
||||
std::vector<ck::index_t>& strides)
|
||||
{
|
||||
|
||||
ck::index_t stride = 1;
|
||||
for(ck::index_t d = lengths.size() - 1; d >= 0; d--)
|
||||
{
|
||||
ck::index_t dim = dims_order[d];
|
||||
strides[dim] = stride;
|
||||
stride *= lengths[dim];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int profile_permute_scale(int argc, char* argv[])
|
||||
@@ -58,16 +72,21 @@ int profile_permute_scale(int argc, char* argv[])
|
||||
const int num_dims = dims_argc / 3;
|
||||
|
||||
std::vector<ck::index_t> lengths(num_dims);
|
||||
std::vector<ck::index_t> input_strides(num_dims);
|
||||
std::vector<ck::index_t> output_strides(num_dims);
|
||||
std::vector<ck::index_t> input_dims_order(num_dims);
|
||||
std::vector<ck::index_t> output_dims_order(num_dims);
|
||||
|
||||
for(int i = 0; i < num_dims; i++)
|
||||
{
|
||||
lengths[i] = std::stoi(argv[control_argc + i]);
|
||||
input_strides[i] = std::stoi(argv[control_argc + num_dims + i]);
|
||||
output_strides[i] = std::stoi(argv[control_argc + 2 * num_dims + i]);
|
||||
lengths[i] = std::stoi(argv[control_argc + i]);
|
||||
input_dims_order[i] = std::stoi(argv[control_argc + num_dims + i]);
|
||||
output_dims_order[i] = std::stoi(argv[control_argc + 2 * num_dims + i]);
|
||||
}
|
||||
|
||||
std::vector<ck::index_t> input_strides(num_dims);
|
||||
std::vector<ck::index_t> output_strides(num_dims);
|
||||
init_strides(lengths, input_dims_order, input_strides);
|
||||
init_strides(lengths, output_dims_order, output_strides);
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
|
||||
|
||||
43
script/profile_permute_scale.sh
Executable file
43
script/profile_permute_scale.sh
Executable file
@@ -0,0 +1,43 @@
|
||||
#!/bin/bash
|
||||
|
||||
## GPU visibility
|
||||
export HIP_VISIBLE_DEVICES=0
|
||||
DRIVER="../build/bin/ckProfiler"
|
||||
echo $DRIVER
|
||||
OP=$1
|
||||
DATATYPE=$2
|
||||
VERIFY=$3
|
||||
INIT=$4
|
||||
LOG=$5
|
||||
TIME=$6
|
||||
|
||||
|
||||
# 1D
|
||||
######## op datatype verify init log time dims in_strides_order out_strides_order
|
||||
$DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 67108864 0 0
|
||||
|
||||
# # 2D
|
||||
# ######## op datatype verify init log time dims in_strides_order out_strides_order
|
||||
$DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8192 8192 0 1 1 0
|
||||
$DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8192 8192 1 0 0 1
|
||||
|
||||
# 3D
|
||||
######## op datatype verify init log time dims in_strides_order out_strides_order
|
||||
$DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 1024 8192 0 1 2 2 1 0
|
||||
$DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 1024 8192 2 1 0 0 1 2
|
||||
|
||||
# 4D
|
||||
######## op datatype verify init log time dims in_strides_order out_strides_order
|
||||
$DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 512 8192 0 1 2 3 3 2 1 0
|
||||
$DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 512 8192 3 2 1 0 0 1 2 3
|
||||
|
||||
# 5D
|
||||
######## op datatype verify init log time dims in_strides_order out_strides_order
|
||||
$DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 2 256 8192 0 1 2 3 4 4 3 2 1 0
|
||||
$DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 2 256 8192 4 3 2 1 0 0 1 2 3 4
|
||||
|
||||
# 6D
|
||||
######## op datatype verify init log time dims in_strides_order out_strides_order
|
||||
$DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 2 2 128 8192 0 1 2 3 4 5 5 4 3 2 1 0
|
||||
$DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 2 2 128 8192 5 4 3 2 1 0 0 1 2 3 4 5
|
||||
|
||||
@@ -52,40 +52,40 @@ TYPED_TEST_SUITE(TestPermute, KernelTypes);
|
||||
TYPED_TEST(TestPermute, Test1D)
|
||||
{
|
||||
constexpr ck::index_t NumDims = 1;
|
||||
this->template Run<NumDims>({8}, {1}, {2});
|
||||
this->template Run<NumDims>({8}, {2}, {1});
|
||||
this->template Run<NumDims>({16}, {1}, {1});
|
||||
this->template Run<NumDims>({16}, {1}, {2});
|
||||
this->template Run<NumDims>({1}, {1}, {1});
|
||||
}
|
||||
|
||||
TYPED_TEST(TestPermute, Test2D)
|
||||
{
|
||||
constexpr ck::index_t NumDims = 2;
|
||||
this->template Run<NumDims>({8, 4}, {4, 1}, {1, 8});
|
||||
this->template Run<NumDims>({8, 4}, {1, 8}, {4, 1});
|
||||
this->template Run<NumDims>({8, 16}, {16, 1}, {1, 8});
|
||||
this->template Run<NumDims>({8, 16}, {1, 8}, {16, 1});
|
||||
this->template Run<NumDims>({1, 1}, {1, 1}, {1, 1});
|
||||
}
|
||||
|
||||
TYPED_TEST(TestPermute, Test3D)
|
||||
{
|
||||
constexpr ck::index_t NumDims = 3;
|
||||
this->template Run<NumDims>({2, 4, 4}, {16, 4, 1}, {1, 2, 8});
|
||||
this->template Run<NumDims>({2, 4, 4}, {1, 2, 8}, {16, 4, 1});
|
||||
this->template Run<NumDims>({8, 2, 8}, {16, 8, 1}, {1, 8, 16});
|
||||
this->template Run<NumDims>({8, 2, 8}, {1, 8, 16}, {16, 8, 1});
|
||||
this->template Run<NumDims>({1, 1, 1}, {1, 1, 1}, {1, 1, 1});
|
||||
}
|
||||
|
||||
TYPED_TEST(TestPermute, Test4D)
|
||||
{
|
||||
constexpr ck::index_t NumDims = 4;
|
||||
this->template Run<NumDims>({2, 4, 4, 4}, {64, 16, 4, 1}, {1, 2, 8, 32});
|
||||
this->template Run<NumDims>({2, 4, 4, 4}, {1, 2, 8, 32}, {64, 16, 4, 1});
|
||||
this->template Run<NumDims>({8, 2, 3, 8}, {48, 24, 8, 1}, {1, 8, 16, 48});
|
||||
this->template Run<NumDims>({8, 2, 3, 8}, {1, 8, 16, 48}, {48, 24, 8, 1});
|
||||
this->template Run<NumDims>({1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1});
|
||||
}
|
||||
|
||||
TYPED_TEST(TestPermute, Test5D)
|
||||
{
|
||||
constexpr ck::index_t NumDims = 5;
|
||||
this->template Run<NumDims>({2, 4, 4, 4, 4}, {256, 64, 16, 4, 1}, {1, 2, 8, 32, 128});
|
||||
this->template Run<NumDims>({2, 4, 4, 4, 4}, {1, 2, 8, 32, 128}, {256, 64, 16, 4, 1});
|
||||
this->template Run<NumDims>({8, 2, 3, 4, 8}, {192, 96, 32, 8, 1}, {1, 8, 16, 48, 192});
|
||||
this->template Run<NumDims>({8, 2, 3, 4, 8}, {1, 8, 16, 48, 192}, {192, 96, 32, 8, 1});
|
||||
this->template Run<NumDims>({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1});
|
||||
}
|
||||
|
||||
@@ -93,8 +93,8 @@ TYPED_TEST(TestPermute, Test6D)
|
||||
{
|
||||
constexpr ck::index_t NumDims = 6;
|
||||
this->template Run<NumDims>(
|
||||
{2, 4, 4, 4, 4, 4}, {1024, 256, 64, 16, 4, 1}, {1, 2, 8, 32, 128, 512});
|
||||
{8, 2, 3, 4, 5, 8}, {960, 480, 160, 40, 8, 1}, {1, 8, 16, 48, 192, 960});
|
||||
this->template Run<NumDims>(
|
||||
{2, 4, 4, 4, 4, 4}, {1, 2, 8, 32, 128, 512}, {1024, 256, 64, 16, 4, 1});
|
||||
{8, 2, 3, 4, 5, 8}, {1, 8, 16, 48, 192, 960}, {960, 480, 160, 40, 8, 1});
|
||||
this->template Run<NumDims>({1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user