mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
delete a typo error
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
` // SPDX-License-Identifier: MIT
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
@@ -7,65 +7,64 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Conv backward data multiple D:
|
||||
// input : output image A[G, N, K, Ho, Wo]
|
||||
// input : weight B[G, K, C, Y, X],
|
||||
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
|
||||
// output : input image E[G, N, C, Hi, Wi],
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AComputeType = ADataType,
|
||||
typename BComputeType = AComputeType>
|
||||
struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
|
||||
{
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// Conv backward data multiple D:
|
||||
// input : output image A[G, N, K, Ho, Wo]
|
||||
// input : weight B[G, K, C, Y, X],
|
||||
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
|
||||
// output : input image E[G, N, C, Hi, Wi],
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AComputeType = ADataType,
|
||||
typename BComputeType = AComputeType>
|
||||
struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
|
||||
|
||||
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_a, // output image
|
||||
const void* p_b, // weight
|
||||
const std::array<const void*, NumDTensor>& p_ds, // bias
|
||||
void* p_e, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths, // bias
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides, // bias
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
const ck::index_t split_k = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_a, // output image
|
||||
const void* p_b, // weight
|
||||
const std::array<const void*, NumDTensor>& p_ds, // bias
|
||||
void* p_e, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths, // bias
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides, // bias
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
const ck::index_t split_k = 1) = 0;
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user