mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-27 00:14:35 +00:00
Reorganize project folders (#6)
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace conv_tensor_rearrange_op {
|
||||
|
||||
struct BaseConvTensorRearrangeOp
|
||||
{
|
||||
};
|
||||
|
||||
struct ImageToColumn : public BaseConvTensorRearrangeOp
|
||||
{
|
||||
static constexpr const char* name = "Image to Column";
|
||||
};
|
||||
|
||||
struct ColumnToImage : public BaseConvTensorRearrangeOp
|
||||
{
|
||||
static constexpr const char* name = "Column to Image";
|
||||
};
|
||||
|
||||
template <typename Op,
|
||||
typename std::enable_if<std::is_base_of<BaseConvTensorRearrangeOp, Op>::value,
|
||||
bool>::type = false>
|
||||
std::ostream& operator<<(std::ostream& os, const BaseConvTensorRearrangeOp&)
|
||||
{
|
||||
os << Op::name;
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace conv_tensor_rearrange_op
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
enum struct ConvolutionBackwardDataSpecialization
|
||||
{
|
||||
Default,
|
||||
Filter1x1Stride1Pad0,
|
||||
};
|
||||
|
||||
inline std::string
|
||||
getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecialization& s)
|
||||
{
|
||||
switch(s)
|
||||
{
|
||||
case ConvolutionBackwardDataSpecialization::Default: return "Default";
|
||||
case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
default: return "Unrecognized specialization!";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
enum struct ConvolutionBackwardWeightSpecialization
|
||||
{
|
||||
Default,
|
||||
Filter1x1Stride1Pad0,
|
||||
Filter1x1Pad0,
|
||||
OddC,
|
||||
};
|
||||
|
||||
inline std::string
|
||||
getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization& s)
|
||||
{
|
||||
switch(s)
|
||||
{
|
||||
case ConvolutionBackwardWeightSpecialization::Default: return "Default";
|
||||
case ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0:
|
||||
return "Filter1x1Stride1Pad0";
|
||||
case ConvolutionBackwardWeightSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case ConvolutionBackwardWeightSpecialization::OddC: return "OddC";
|
||||
default: return "Unrecognized specialization!";
|
||||
}
|
||||
}
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
enum struct ConvolutionForwardSpecialization
|
||||
{
|
||||
Default,
|
||||
Filter1x1Pad0,
|
||||
Filter1x1Stride1Pad0,
|
||||
OddC,
|
||||
Filter3x3,
|
||||
};
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s)
|
||||
{
|
||||
switch(s)
|
||||
{
|
||||
case ConvolutionForwardSpecialization::Default: return "Default";
|
||||
case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
case ConvolutionForwardSpecialization::OddC: return "OddC";
|
||||
case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3";
|
||||
default: return "Unrecognized specialization!";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename DOutDataType,
|
||||
typename DInDataType,
|
||||
typename DOutLayout,
|
||||
typename DInLayout>
|
||||
struct DeviceAvgPoolBwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_dout,
|
||||
void* p_din,
|
||||
std::vector<ck::index_t> dout_n_k_wos_lengths,
|
||||
std::vector<ck::index_t> dout_n_k_wos_strides,
|
||||
std::vector<ck::index_t> din_n_k_wos_length,
|
||||
std::vector<ck::index_t> din_n_k_wos_strides,
|
||||
std::vector<ck::index_t> window_k_c_xs_lengths,
|
||||
std::vector<ck::index_t> window_strides,
|
||||
std::vector<ck::index_t> window_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
115
include/ck/tensor_operation/gpu/device/device_base.hpp
Normal file
115
include/ck/tensor_operation/gpu/device/device_base.hpp
Normal file
@@ -0,0 +1,115 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <regex>
|
||||
#include <optional>
|
||||
|
||||
#include "ck/stream_config.hpp"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
#define GET_OBJECT_NAME_IMLP \
|
||||
std::optional<std::string> GetObjectName() const override \
|
||||
{ \
|
||||
std::string str = __PRETTY_FUNCTION__; \
|
||||
static std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"}; \
|
||||
std::smatch match; \
|
||||
if(!std::regex_search(str, match, obj_name_expr)) \
|
||||
{ \
|
||||
return str; \
|
||||
} \
|
||||
return std::string(match[1]) + ';'; \
|
||||
}
|
||||
|
||||
#define GET_TEMPLATE_INFO_IMPL \
|
||||
std::optional<std::string> GetTemplateInfo() const override \
|
||||
{ \
|
||||
std::string str = __PRETTY_FUNCTION__; \
|
||||
static std::regex template_expr{"\\[(.*)\\]"}; \
|
||||
std::smatch match; \
|
||||
if(!std::regex_search(str, match, template_expr)) \
|
||||
{ \
|
||||
return std::nullopt; \
|
||||
} \
|
||||
return std::string(match[1]); \
|
||||
}
|
||||
|
||||
#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
|
||||
#endif
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
struct BaseArgument
|
||||
{
|
||||
BaseArgument() = default;
|
||||
BaseArgument(const BaseArgument&) = default;
|
||||
BaseArgument& operator=(const BaseArgument&) = default;
|
||||
|
||||
virtual ~BaseArgument() {}
|
||||
|
||||
void* p_workspace_ = nullptr;
|
||||
};
|
||||
|
||||
struct BaseInvoker
|
||||
{
|
||||
BaseInvoker() = default;
|
||||
BaseInvoker(const BaseInvoker&) = default;
|
||||
BaseInvoker& operator=(const BaseInvoker&) = default;
|
||||
|
||||
virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
|
||||
{
|
||||
return float{0};
|
||||
}
|
||||
|
||||
virtual ~BaseInvoker() {}
|
||||
};
|
||||
#endif
|
||||
|
||||
struct BaseOperator
|
||||
{
|
||||
BaseOperator() = default;
|
||||
BaseOperator(const BaseOperator&) = default;
|
||||
BaseOperator& operator=(const BaseOperator&) = default;
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
|
||||
virtual std::string GetTypeString() const { return ""; }
|
||||
|
||||
virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
|
||||
|
||||
virtual std::optional<std::string> GetObjectName() const { return std::nullopt; }
|
||||
|
||||
virtual std::optional<std::string> GetTemplateInfo() const { return std::nullopt; }
|
||||
|
||||
virtual std::string GetTypeIdHashCode() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
oss << std::hex << typeid(*this).hash_code();
|
||||
|
||||
return oss.str();
|
||||
};
|
||||
|
||||
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
|
||||
|
||||
virtual void SetWorkSpacePointer(BaseArgument* p_arg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const
|
||||
{
|
||||
assert(p_arg);
|
||||
p_arg->p_workspace_ = p_workspace;
|
||||
}
|
||||
#endif
|
||||
virtual ~BaseOperator() {}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,64 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Tensor Contraction:
|
||||
// input : A
|
||||
// input : B
|
||||
// input : D0, D1, ...
|
||||
// output : E
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
template <index_t NumDimG,
|
||||
index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceBatchedContractionMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const std::vector<index_t>& a_gs_ms_ns_lengths,
|
||||
const std::vector<index_t>& a_gs_ms_ks_strides,
|
||||
const std::vector<index_t>& b_gs_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_gs_ns_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
|
||||
const std::vector<index_t>& e_gs_ms_ns_lengths,
|
||||
const std::vector<index_t>& e_gs_ms_ns_strides,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
110
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
Normal file
110
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
Normal file
@@ -0,0 +1,110 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceBatchedGemm : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t BatchStrideA,
|
||||
ck::index_t BatchStrideB,
|
||||
ck::index_t BatchStrideC,
|
||||
ck::index_t Batch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename CDataType,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceBatchedGemmV2BScale : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t StrideScaleB,
|
||||
ck::index_t BatchStrideA,
|
||||
ck::index_t BatchStrideB,
|
||||
ck::index_t BatchStrideC,
|
||||
ck::index_t BatchStrideScaleB,
|
||||
const void* p_b_scale,
|
||||
ck::index_t Batch,
|
||||
ck::index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual bool GetPermuteB() = 0;
|
||||
virtual ck::index_t GetKPerBlock() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
using DeviceBatchedGemmPtr = std::unique_ptr<DeviceBatchedGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,50 @@
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
struct BatchedGemmEPermuteDesc
|
||||
{
|
||||
ck::index_t G0_, G1_, M_, N_;
|
||||
ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceBatchedGemmEPermute : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_A,
|
||||
index_t stride_B,
|
||||
index_t batch_stride_A,
|
||||
index_t batch_stride_B,
|
||||
BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
|
||||
index_t BatchCount,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename B0Layout,
|
||||
typename B1Layout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename Acc0ElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceBatchedGemmGemm : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b0,
|
||||
const void* p_b1,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t O,
|
||||
ck::index_t Batch,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB0,
|
||||
ck::index_t StrideB1,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t BatchStrideA,
|
||||
ck::index_t BatchStrideB0,
|
||||
ck::index_t BatchStrideB1,
|
||||
ck::index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
Acc0ElementwiseOperation acc0_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,100 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceBatchedGemmMultiD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor");
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& StrideDs,
|
||||
index_t StrideE,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
|
||||
index_t BatchStrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceBatchedGemmV2MultiD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor");
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& StrideDs,
|
||||
index_t StrideE,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
|
||||
index_t BatchStrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
index_t KBatch) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename A0Layout,
|
||||
typename B0Layout,
|
||||
typename D0sLayout,
|
||||
typename B1Layout,
|
||||
typename D1sLayout,
|
||||
typename E1Layout,
|
||||
typename A0DataType,
|
||||
typename B0DataType,
|
||||
typename D0sDataType,
|
||||
typename B1DataType,
|
||||
typename D1sDataType,
|
||||
typename E1DataType,
|
||||
typename A0ElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename CDE0ElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CDE1ElementwiseOperation>
|
||||
struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumD0Tensor = D0sDataType::Size();
|
||||
static constexpr index_t NumD1Tensor = D1sDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a0,
|
||||
const void* p_b0,
|
||||
std::array<const void*, NumD0Tensor> p_d0s,
|
||||
const void* p_b1,
|
||||
std::array<const void*, NumD1Tensor> p_d1s,
|
||||
void* p_e1,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t O,
|
||||
ck::index_t Batch,
|
||||
ck::index_t StrideA0,
|
||||
ck::index_t StrideB0,
|
||||
std::array<ck::index_t, NumD0Tensor> StrideD0s,
|
||||
ck::index_t StrideB1,
|
||||
std::array<ck::index_t, NumD1Tensor> StrideD1s,
|
||||
ck::index_t StrideE1,
|
||||
ck::index_t BatchStrideA0,
|
||||
ck::index_t BatchStrideB0,
|
||||
std::array<ck::index_t, NumD0Tensor> BatchStrideD0s,
|
||||
ck::index_t BatchStrideB1,
|
||||
std::array<ck::index_t, NumD1Tensor> BatchStrideD1s,
|
||||
ck::index_t BatchStrideE1,
|
||||
A0ElementwiseOperation a0_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
CDE0ElementwiseOperation cde0_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CDE1ElementwiseOperation cde1_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#endif
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename B0Layout,
|
||||
typename B1Layout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename Acc0ElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
bool MaskOutUpperTriangle> // TODO: enum for mask type
|
||||
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
|
||||
{
|
||||
#ifndef __HIPCC_RTC__
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b0,
|
||||
const void* p_b1,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t O,
|
||||
ck::index_t Batch,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB0,
|
||||
ck::index_t StrideB1,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t BatchStrideA,
|
||||
ck::index_t BatchStrideB0,
|
||||
ck::index_t BatchStrideB1,
|
||||
ck::index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
Acc0ElementwiseOperation acc0_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,70 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t NumDimG,
|
||||
index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
index_t NumDimO,
|
||||
typename ADataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename CDataType,
|
||||
typename Acc0BiasDataType,
|
||||
typename Acc1BiasDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename C0DEElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename C1DEElementwiseOperation,
|
||||
MaskingSpecialization MaskingSpec>
|
||||
struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
|
||||
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_a,
|
||||
const void* p_b0,
|
||||
const void* p_b1,
|
||||
void* p_c,
|
||||
const std::array<void*, NumAcc0Bias> p_acc0_biases,
|
||||
const std::array<void*, NumAcc1Bias> p_acc1_biases,
|
||||
const std::vector<index_t>& a_gs_ms_ks_lengths,
|
||||
const std::vector<index_t>& a_gs_ms_ks_strides,
|
||||
const std::vector<index_t>& b_gs_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_gs_ns_ks_strides,
|
||||
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
|
||||
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
|
||||
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
|
||||
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
|
||||
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
|
||||
const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
|
||||
const std::array<std::vector<index_t>, NumAcc1Bias>
|
||||
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
|
||||
const std::array<std::vector<index_t>, NumAcc1Bias>
|
||||
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
|
||||
AElementwiseOperation a_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
C0DEElementwiseOperation c0de_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
C1DEElementwiseOperation c1de_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename XDataType,
|
||||
typename DxDataType,
|
||||
typename DyDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
struct DeviceBatchNormBwd : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> dyStrides,
|
||||
const std::array<index_t, Rank> dxStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* p_dy,
|
||||
const void* p_scale,
|
||||
const void* p_savedMean,
|
||||
const void* p_savedInvVar,
|
||||
double epsilon,
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
void* p_dx,
|
||||
void* p_dscale,
|
||||
void* p_dbias) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename XDataType,
|
||||
typename DxDataType,
|
||||
typename DyDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<XDataType,
|
||||
DxDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
struct DeviceBatchNormFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> yStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* bnScale,
|
||||
const void* bnBias,
|
||||
double epsilon,
|
||||
const YElementwiseOp y_elementwise_op,
|
||||
void* p_y,
|
||||
void* resultSaveMean,
|
||||
void* resultSaveInvVariance,
|
||||
double exponentialAverageFactor,
|
||||
void* resultRunningMean,
|
||||
void* resultRunningVariance) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
struct DeviceBatchNormInfer : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> yStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* bnScale,
|
||||
const void* bnBias,
|
||||
double epsilon,
|
||||
const YElementwiseOp y_elementwise_op,
|
||||
const void* estimatedMean,
|
||||
const void* estimatedInvVariance,
|
||||
void* p_y) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
51
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
Normal file
51
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
Normal file
@@ -0,0 +1,51 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceCGemm : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
|
||||
const void* p_a_imag,
|
||||
const void* p_b_real,
|
||||
const void* p_b_imag,
|
||||
void* p_c_real,
|
||||
void* p_c_imag,
|
||||
void* p_workspace,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t KBatch = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
virtual std::size_t GetWorkspaceSize(index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC) const = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
using DeviceCGemmPtr = std::unique_ptr<
|
||||
DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,61 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// GEMM:
|
||||
// input : A0[M0, M1, ... K0, K1, ...], ...
|
||||
// input : B0[N0, N1, ... K0, K1, ...], ...
|
||||
// input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ...
|
||||
// output : E[M0, M1, ... N0, N1, ...]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceContractionMultipleABD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumATensor = AsDataType::Size();
|
||||
static constexpr index_t NumBTensor = BsDataType::Size();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
|
||||
std::array<const void*, NumBTensor> p_bs,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
|
||||
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
|
||||
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
|
||||
const std::vector<index_t>& e_ms_ns_length,
|
||||
const std::vector<index_t>& e_ms_ns_stride,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,64 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Tensor Contraction:
|
||||
// input : A
|
||||
// input : B
|
||||
// input : D0, D1, ...
|
||||
// output : E
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// A[M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
// B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
// D[M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
// E[M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
template <index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename ComputeDataType = ADataType>
|
||||
struct DeviceContractionMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const std::vector<index_t>& a_ms_ns_lengths,
|
||||
const std::vector<index_t>& a_ms_ks_strides,
|
||||
const std::vector<index_t>& b_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_ns_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
|
||||
const std::vector<index_t>& e_ms_ns_lengths,
|
||||
const std::vector<index_t>& e_ms_ns_strides,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvBwdData : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(void* p_in,
|
||||
const void* p_wei,
|
||||
const void* p_out,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
49
include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp
Normal file
49
include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
const void* p_wei,
|
||||
void* p_out,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvFwdBiasActivation : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
const void* p_wei,
|
||||
void* p_out,
|
||||
const void* p_bias,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceConvFwdBiasActivationPtr =
|
||||
std::unique_ptr<DeviceConvFwdBiasActivation<InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvFwdBiasActivationAdd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
const void* p_wei,
|
||||
void* p_out,
|
||||
const void* p_bias,
|
||||
const void* p_resi,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceConvFwdBiasActivationAddPtr =
|
||||
std::unique_ptr<DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/**
|
||||
* \brief Convolution Tensor Rearrange.
|
||||
*
|
||||
* This Device operator supports converting an image to
|
||||
* the GEMM representation (Image to Column) and
|
||||
* converting a GEMM form to the image (Column to Image).
|
||||
* Supported layouts:
|
||||
* [G, N, Di, Hi, Wi, C] <-> [G, N * Do * Ho * Wo, Z * Y * X * C]
|
||||
* [N, Di, Hi, Wi, G, C] <-> [N * Do * Ho * Wo, G, Z * Y * X * C]
|
||||
*
|
||||
* \tparam NDimSpatial Number of spatial dimensions.
|
||||
* \tparam ImageLayout Input Layout.
|
||||
* \tparam InputDataType Input Data Type.
|
||||
* \tparam OutputDataType Output Data Type.
|
||||
* \tparam ConvTensorRearrangeOp Operation type: ImageToColumn, ColumnToImage.
|
||||
*/
|
||||
template <index_t NDimSpatial,
|
||||
typename ImageLayout,
|
||||
typename InputDataType,
|
||||
typename OutputDataType,
|
||||
typename ConvTensorRearrangeOp>
|
||||
struct DeviceConvTensorRearrange : public BaseOperator
|
||||
{
|
||||
|
||||
/**
|
||||
* \brief Make argument pointer for image to column.
|
||||
*
|
||||
* \param p_in A pointer to the device memory of the input image.
|
||||
* \param p_out A pointer to the device memory of the output.
|
||||
* \param G Convolution number of groups.
|
||||
* \param N Convolution batch size.
|
||||
* \param C Convolution number of channels.
|
||||
* \param input_spatial_lengths Input spatial lengths.
|
||||
* \param filter_spatial_lengths Filter spatial lengths.
|
||||
* \param output_spatial_lengths Output spatial lengths.
|
||||
* \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W].
|
||||
* \param gemm_g_m_k_strides Gemm form strides.
|
||||
* \param conv_filter_strides Convolution filter strides.
|
||||
* \param conv_filter_dilations Convolution filter dilations.
|
||||
* \param input_left_pads Convolution left pads.
|
||||
* \param input_right_pads Convolution right pads.
|
||||
* \return Pointer to the argument.
|
||||
*/
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
void* p_out,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <array>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
index_t NumDim>
|
||||
struct DeviceElementwise : public BaseOperator
|
||||
{
|
||||
static constexpr int NumInput = InDataTypeTuple::Size();
|
||||
static constexpr int NumOutput = OutDataTypeTuple::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
}; // namespace device
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
index_t NumDim>
|
||||
using DeviceElementwisePtr = std::unique_ptr<
|
||||
DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,68 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename YDataType,
|
||||
typename XElementwiseOperation,
|
||||
typename YElementwiseOperation,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceElementwiseNormalization : public BaseOperator
|
||||
{
|
||||
static constexpr int NumInput = InDataTypeTuple::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<index_t> lengths,
|
||||
const std::array<std::vector<index_t>, NumInput> inStridesArray,
|
||||
const std::vector<index_t> gammaStrides,
|
||||
const std::vector<index_t> betaStrides,
|
||||
const std::vector<index_t> yStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
double epsilon,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
void* p_y,
|
||||
XElementwiseOperation x_elementwise_op,
|
||||
YElementwiseOperation y_elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename YDataType,
|
||||
typename XElementwiseOperation,
|
||||
typename YElementwiseOperation,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
using DeviceElementwiseNormalizationPtr =
|
||||
std::unique_ptr<DeviceElementwiseNormalization<InDataTypeTuple,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
XElementwiseOperation,
|
||||
YElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <array>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/**
|
||||
* \note This structure is deprecated (left for backwards compatibility). Please use
|
||||
* DeviceElementwise from device_elementwise.hpp.
|
||||
*/
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
typename UnaryOperation,
|
||||
typename Scale,
|
||||
index_t NumDim>
|
||||
struct DeviceElementwise : public BaseOperator
|
||||
{
|
||||
static constexpr int NumInput = InDataTypeTuple::Size();
|
||||
static constexpr int NumOutput = OutDataTypeTuple::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
|
||||
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
|
||||
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
|
||||
const std::array<const void*, NumInput> in_dev_buffers,
|
||||
const std::array<void*, NumOutput> out_dev_buffers,
|
||||
ElementwiseOperation elementwise_op,
|
||||
UnaryOperation unary_op,
|
||||
Scale scale_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
}; // namespace device
|
||||
|
||||
template <typename InDataTypeTuple,
|
||||
typename OutDataTypeTuple,
|
||||
typename ElementwiseOperation,
|
||||
typename UnaryOperation,
|
||||
typename Scale,
|
||||
index_t NumDim>
|
||||
using DeviceElementwisePtr = std::unique_ptr<DeviceElementwise<InDataTypeTuple,
|
||||
OutDataTypeTuple,
|
||||
ElementwiseOperation,
|
||||
UnaryOperation,
|
||||
Scale,
|
||||
NumDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
42
include/ck/tensor_operation/gpu/device/device_gemm.hpp
Normal file
42
include/ck/tensor_operation/gpu/device/device_gemm.hpp
Normal file
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemm : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,51 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
struct DEGridDesc_M0_M1_M2_N0_N1
|
||||
{
|
||||
ck::index_t M0_, M1_, M2_, N0_, N1_;
|
||||
ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_;
|
||||
};
|
||||
|
||||
// input : A[M, K], B[K, N],
|
||||
// input : D[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D)
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmBiasCPermute : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_d,
|
||||
void* p_e,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc,
|
||||
DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,46 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Dequantization of input tensor could not be decoupled from gridwisegemm pipeline
|
||||
// As input tensor thread buffer declared inside blockwise-gemm pipeline.
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemm_dequantB : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_scale,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// GEMM:
|
||||
// input : A0[M, K], B0[K, N],
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleABD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumATensor = AsDataType::Size();
|
||||
static constexpr index_t NumBTensor = BsDataType::Size();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
|
||||
std::array<const void*, NumBTensor> p_bs,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
std::array<ck::index_t, NumATensor> StrideAs,
|
||||
std::array<ck::index_t, NumBTensor> StrideBs,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
ck::index_t StrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,154 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <array>
|
||||
#endif
|
||||
|
||||
#include "ck/utility/array.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// GEMM:
|
||||
// input : A[M, K], B[K, N],
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
ck::index_t StrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
// GEMM:
|
||||
// input : A[M, K], B[K, N],
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleDSplitK : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
ck::index_t StrideE,
|
||||
ck::index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
// GEMM:
|
||||
// input : A[M, K], B[K, N],
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
ck::index_t StrideE,
|
||||
ck::index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// GEMM:
|
||||
// input : A[M, K], B[K, N],
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename AScaleType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleD_ABScale : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t StrideA,
|
||||
const ck::index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
const ck::index_t StrideE,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// GEMM:
|
||||
// input : A[M, K]
|
||||
// input : B[N, K]
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// output : H[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// H = layernorm(E)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
// Calculate mean & variance along N dimension in layernorm(E)
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename HLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename HDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename HElementwiseOperation>
|
||||
struct DeviceGemmMultipleDLayernorm : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
void* p_h,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideH,
|
||||
double epsilon,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
HElementwiseOperation h_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,97 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// FIXME: DeviceGemmReduce type need to well define the problem
|
||||
// GEMM:
|
||||
// input : A[AK0, M, AK1]
|
||||
// input : B[AK0, N, AK1]
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// output : R0[M], R1[M], ...
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ...
|
||||
// R0 = r_op0(Q0), R1 = r_op1(Q1), ...
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename RsDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename QsElementwiseOperation,
|
||||
typename RsElementwiseOperation>
|
||||
struct DeviceGemmMultipleDMultipleR : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr index_t NumRTensor = RsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
std::array<void*, NumRTensor> p_rs,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
ck::index_t StrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
QsElementwiseOperation qs_element_op,
|
||||
RsElementwiseOperation rs_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename RsDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename QsElementwiseOperation,
|
||||
typename RsElementwiseOperation>
|
||||
using DeviceGemmMultipleDMultipleRPtr =
|
||||
std::unique_ptr<DeviceGemmMultipleDMultipleR<ALayout,
|
||||
BLayout,
|
||||
DELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
RsDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
QsElementwiseOperation,
|
||||
RsElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
50
include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp
Normal file
50
include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename CDataType,
|
||||
index_t ScaleBlockSize,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmMX : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideAScale,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideBScale,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,46 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// FIXME: DeviceGemmReduce type need to well define the problem
|
||||
template <ck::index_t NumDTensor, ck::index_t NumReduce>
|
||||
struct DeviceGemmReduce : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_bias,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
std::array<void*, NumReduce> p_reduces,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
std::array<void*, 3> gemm_element_ops,
|
||||
std::array<void*, NumDTensor> d_element_ops,
|
||||
std::array<void*, NumReduce> reduce_in_element_ops,
|
||||
std::array<void*, NumReduce> reduce_out_element_ops,
|
||||
ck::index_t BatchCount = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <ck::index_t NumDTensor, ck::index_t NumReduce>
|
||||
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename ComputeType = CDataType>
|
||||
struct DeviceGemmSplitK : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t KBatch) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename ComputeType = CDataType>
|
||||
using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
ComputeType>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,64 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmStreamK : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t NumSKBlocks = 0) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
using DeviceGemmStreamKPtr = std::unique_ptr<DeviceGemmStreamK<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,44 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemm_Streamk_V2 : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t Streamk_sel,
|
||||
ck::index_t Grid_size,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
153
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
Normal file
153
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
Normal file
@@ -0,0 +1,153 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmV2 : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t KSplit,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual bool GetPermuteA() = 0;
|
||||
virtual bool GetPermuteB() = 0;
|
||||
virtual ck::index_t GetKPerBlock() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmV2R1 : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> DsStrides,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t KSplit,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename CDataType,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmV2BScale : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t StrideScaleB,
|
||||
const void* p_b_scale,
|
||||
ck::index_t KSplit,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual bool GetPermuteB() = 0;
|
||||
virtual ck::index_t GetKPerBlock() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmV2BPreshuffle : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t KSplit,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual bool GetPermuteA() = 0;
|
||||
virtual bool GetPermuteB() = 0;
|
||||
virtual ck::index_t GetKPerBlock() = 0;
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t NumDTensor>
|
||||
struct ContractionDesc
|
||||
{
|
||||
std::vector<index_t> a_ms_ks_lengths;
|
||||
std::vector<index_t> a_ms_ks_strides;
|
||||
|
||||
std::vector<index_t> b_ns_ks_lengths;
|
||||
std::vector<index_t> b_ns_ks_strides;
|
||||
|
||||
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths;
|
||||
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides;
|
||||
|
||||
std::vector<index_t> e_ms_ns_lengths;
|
||||
std::vector<index_t> e_ms_ns_strides;
|
||||
};
|
||||
|
||||
// Tensor Contraction:
|
||||
// input : A
|
||||
// input : B
|
||||
// input : D0, D1, ...
|
||||
// output : E
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// A[M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
// B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
// D[M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
// E[M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
template <index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGroupedContractionMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<const void*> p_a_vec,
|
||||
std::vector<const void*> p_b_vec,
|
||||
std::vector<std::array<const void*, NumDTensor>> p_ds_vec,
|
||||
std::vector<void*> p_e_vec,
|
||||
std::vector<ContractionDesc<NumDTensor>> contraction_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,70 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Conv backward data multiple D:
|
||||
// input : output image A[G, N, K, Ho, Wo]
|
||||
// input : weight B[G, K, C, Y, X],
|
||||
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
|
||||
// output : input image E[G, N, C, Hi, Wi],
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AComputeType = ADataType,
|
||||
typename BComputeType = AComputeType>
|
||||
struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
|
||||
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_a, // output image
|
||||
const void* p_b, // weight
|
||||
const std::array<const void*, NumDTensor>& p_ds, // bias
|
||||
void* p_e, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths, // bias
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides, // bias
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
const ck::index_t split_k = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
typename ComputeTypeA = InDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct DeviceGroupedConvBwdWeight : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
void* p_wei,
|
||||
const void* p_out,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
ck::index_t split_k) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename DsDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
typename ComputeTypeA = InDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct DeviceGroupedConvBwdWeightMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsLayout::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_lengths,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
const ck::index_t split_k) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,55 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Convolution Forward:
|
||||
// input : input image A[G, N, C, Hi, Wi],
|
||||
// input : weight B[G, K, C, Y, X],
|
||||
// output : output image E[G, N, K, Ho, Wo]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
template <index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceGroupedConvFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in, // input image
|
||||
const void* p_wei, // weight
|
||||
void* p_out, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const InElementwiseOperation& in_element_op,
|
||||
const WeiElementwiseOperation& wei_element_op,
|
||||
const OutElementwiseOperation& out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,171 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#include <array>
|
||||
#endif
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
#ifdef CK_CODE_GEN_RTC
|
||||
template <typename T>
|
||||
using is_tuple = decltype(ck::declval<T&>().IsTuple());
|
||||
#else
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \brief Grouped Convolution Forward
|
||||
*
|
||||
* \details
|
||||
* input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]...
|
||||
* input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]...
|
||||
* input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
|
||||
* output : output image E[G, N, K, Ho, Wo]
|
||||
*
|
||||
* C = a_op(A, A1...) * b_op(B, B1...)
|
||||
* E = cde_op(C, D0, D1, ...)
|
||||
*
|
||||
* \tparam NDimSpatial Number of spatial dimensions.
|
||||
* \tparam ALayout Input layout (also for a1, a2...).
|
||||
* \tparam BLayout Weight layout (also for b1, b2...).
|
||||
* \tparam DsLayout Ds layouts.
|
||||
* \tparam ELayout Output layout.
|
||||
* \tparam ADataType Input data type. Pass tuple if there is multiple A.
|
||||
* \tparam BDataType Weight data type. Pass tuple if there is multiple B.
|
||||
* \tparam DsDataType D data types.
|
||||
* \tparam EDataType Output data type.
|
||||
* \tparam AElementwiseOperation A elementwise operation.
|
||||
* \tparam BElementwiseOperation B elementwise operation.
|
||||
* \tparam CDEElementwiseOperation CDE elementwise operation.
|
||||
* \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed).
|
||||
* \tparam BComputeType Compute data type for B tensor (default: AComputeType).
|
||||
*/
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AComputeType =
|
||||
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
|
||||
Number<0>,
|
||||
ADataType>()), // AComputeType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
typename BComputeType = AComputeType>
|
||||
struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
|
||||
{
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
|
||||
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
|
||||
#ifdef CK_CODE_GEN_RTC
|
||||
using APointers = ck::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
|
||||
using BPointers = ck::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
|
||||
#else
|
||||
// If DataType is tuple, user has to pass std::array with pointers.
|
||||
using APointers =
|
||||
ck::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
|
||||
using BPointers =
|
||||
ck::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
|
||||
#endif
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
|
||||
/**
|
||||
* \brief Make argument pointer for grouped conv fwd.
|
||||
*
|
||||
* \param p_a A pointer to the input (std::array<const void*, NumA> with
|
||||
pointers for multiple A).
|
||||
* \param p_b A pointer to the weight (std::array<const void*, NumA> with
|
||||
pointers for multiple B).
|
||||
* \param p_ds A pointers to the Ds.
|
||||
* \param p_e A pointers to the output.
|
||||
* \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d).
|
||||
* \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d).
|
||||
* \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d).
|
||||
* \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d).
|
||||
* \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d).
|
||||
* \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d).
|
||||
* \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d).
|
||||
* \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d).
|
||||
* \param conv_filter_strides Convolution filter strides.
|
||||
* \param conv_filter_dilations Convolution filter dilations.
|
||||
* \param input_left_pads Input left paddings.
|
||||
* \param input_right_pads Input right paddings.
|
||||
* \param a_element_op A elementwise operation object.
|
||||
* \param b_element_op B elementwise operation object.
|
||||
* \param cde_element_op CDE elementwise operation object.
|
||||
* \return Pointer to the argument.
|
||||
*/
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
APointers p_a,
|
||||
BPointers p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(APointers p_a,
|
||||
BPointers p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<long_index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<long_index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/**
|
||||
* \brief Grouped Convolution Forward
|
||||
*
|
||||
* \note This structure is deprecated (left for backwards compatibility). Please use
|
||||
* DeviceGroupedConvFwdMultipleABD.
|
||||
*
|
||||
* \tparam NDimSpatial Number of spatial dimensions.
|
||||
* \tparam ALayout Input layout (also for a1, a2...).
|
||||
* \tparam BLayout Weight layout (also for b1, b2...).
|
||||
* \tparam DsLayout Ds layouts.
|
||||
* \tparam ELayout Output layout.
|
||||
* \tparam ADataType Input data type. Pass tuple if there is multiple A.
|
||||
* \tparam BDataType Weight data type. Pass tuple if there is multiple B.
|
||||
* \tparam DsDataType D data types.
|
||||
* \tparam EDataType Output data type.
|
||||
* \tparam AElementwiseOperation A elementwise operation.
|
||||
* \tparam BElementwiseOperation B elementwise operation.
|
||||
* \tparam CDEElementwiseOperation CDE elementwise operation.
|
||||
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
|
||||
*/
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename ComputeType =
|
||||
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
|
||||
Number<0>,
|
||||
ADataType>())> // ComputeType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
using DeviceGroupedConvFwdMultipleD = DeviceGroupedConvFwdMultipleABD<NDimSpatial,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
ComputeType>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
185
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
Normal file
185
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
Normal file
@@ -0,0 +1,185 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
#include "ck/utility/ignore.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
///
|
||||
/// @brief Structure representing single GEMM problem arguments.
|
||||
///
|
||||
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
|
||||
/// point kernel.
|
||||
///
|
||||
/// @tparam NumDTensor The number of D input tensors.
|
||||
///
|
||||
template <index_t NumDTensor = 0>
|
||||
struct GroupedGemmKernelArgument
|
||||
{
|
||||
__host__ __device__ GroupedGemmKernelArgument(const void* p_a_grid_,
|
||||
const void* p_b_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
void* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_)
|
||||
: p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_ds_grid{p_ds_grid_},
|
||||
p_e_grid{p_e_grid_},
|
||||
M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
StrideA{StrideA_},
|
||||
StrideB{StrideB_},
|
||||
StrideDs{StrideDs_},
|
||||
StrideE{StrideE_}
|
||||
{
|
||||
}
|
||||
|
||||
const void* p_a_grid;
|
||||
const void* p_b_grid;
|
||||
std::array<const void*, NumDTensor> p_ds_grid;
|
||||
void* p_e_grid;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t StrideA;
|
||||
index_t StrideB;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::stringstream str;
|
||||
for(auto sd : StrideDs)
|
||||
str << sd << ",";
|
||||
|
||||
std::cout << "arg {"
|
||||
<< "M:" << M << ", "
|
||||
<< "N:" << N << ", "
|
||||
<< "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", "
|
||||
<< "SB:" << StrideB << ", "
|
||||
<< "SE:" << StrideE << ", "
|
||||
<< "SDs: {" << str.str() << "}"
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
struct GemmDesc
|
||||
{
|
||||
ck::index_t M_, N_, K_;
|
||||
ck::index_t stride_A_, stride_B_, stride_C_;
|
||||
|
||||
std::vector<ck::index_t> stride_Ds_;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGroupedGemm : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<const void*>& p_a,
|
||||
std::vector<const void*>& p_b,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_ds,
|
||||
std::vector<void*>& p_e,
|
||||
std::vector<GemmDesc>& gemm_desc,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
//---------------------------------------------------------------------------------------------
|
||||
/// @brief Sets the device kernel arguments pointer and may copy data to device.
|
||||
///
|
||||
/// TODO: Add which kernels are using this (TileLoop * FixedNK ??)
|
||||
///
|
||||
/// @param p_arg The pointer to the Argument we're going to update.
|
||||
/// @param[in] p_dev_kernel_args The pointer to the device memory which will contain kernel
|
||||
/// arguments.
|
||||
/// @param[in] p_host_kernel_args The pointer to the host memory which contains kernel
|
||||
/// arguments that should be copied to device memory.
|
||||
///
|
||||
virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
|
||||
void* p_dev_kernel_args,
|
||||
const void* p_host_kernel_args) const
|
||||
{
|
||||
ignore = p_arg;
|
||||
ignore = p_dev_kernel_args;
|
||||
ignore = p_host_kernel_args;
|
||||
|
||||
std::ostringstream err;
|
||||
err << "This function is not implemented by the kernel: " << this->GetTypeString()
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Sets the device kernel arguments pointer and may copy data to device.
|
||||
///
|
||||
/// @param p_arg The pointer to the Argument we're going to update.
|
||||
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
|
||||
/// arguments.
|
||||
///
|
||||
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const
|
||||
{
|
||||
ignore = p_arg;
|
||||
ignore = p_dev_kernel_args;
|
||||
|
||||
std::ostringstream err;
|
||||
err << "This function is not implemented by the kernel: " << this->GetTypeString()
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Gets the device kernel argument size.
|
||||
///
|
||||
/// @param[in] p_arg The pointer to the Device op Argument.
|
||||
///
|
||||
/// @return The device kernel argument size.
|
||||
///
|
||||
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const
|
||||
{
|
||||
ignore = p_arg;
|
||||
|
||||
std::ostringstream err;
|
||||
err << "This function is not implemented by the kernel: " << this->GetTypeString()
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "device_grouped_gemm_splitk.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGroupedGemmFixedNK : DeviceGroupedGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,98 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
struct GemmMultiABDDesc
|
||||
{
|
||||
ck::index_t M_, N_, K_;
|
||||
|
||||
std::vector<ck::index_t> stride_As_;
|
||||
std::vector<ck::index_t> stride_Bs_;
|
||||
std::vector<ck::index_t> stride_Ds_;
|
||||
|
||||
ck::index_t stride_C_;
|
||||
};
|
||||
|
||||
/*
|
||||
* \brief Grouped Gemm Multi ABD
|
||||
*
|
||||
* C = a_op(A, A1...) * b_op(B, B1...)
|
||||
* E = cde_op(C, D0, D1, ...)
|
||||
*
|
||||
* \tparam AsLayout A layouts (tuple).
|
||||
* \tparam BsLayout B layouts (tuple).
|
||||
* \tparam DsLayout Ds layouts (tuple).
|
||||
* \tparam ELayout Output layout.
|
||||
* \tparam AsDataType A data types (tuple).
|
||||
* \tparam BsDataType B data types (tuple).
|
||||
* \tparam DsDataType D data types (tuple).
|
||||
* \tparam EDataType Output data type.
|
||||
* \tparam AElementwiseOperation A elementwise operation.
|
||||
* \tparam BElementwiseOperation B elementwise operation.
|
||||
* \tparam CDEElementwiseOperation C elementwise operation.
|
||||
*/
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGroupedGemmMultiABD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumATensor = AsDataType::Size();
|
||||
static constexpr index_t NumBTensor = BsDataType::Size();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(AsLayout::Size() == AsDataType::Size(), "wrong! inconsistent NumATensor");
|
||||
static_assert(BsLayout::Size() == BsDataType::Size(), "wrong! inconsistent NumBTensor");
|
||||
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
|
||||
|
||||
/*
|
||||
* \brief Make argument pointer for grouped gemm multi abd.
|
||||
*
|
||||
* \param p_as A pointers to the A.
|
||||
* \param p_bs A pointers to the B.
|
||||
* \param p_ds A pointers to the Ds.
|
||||
* \param p_e A pointers to the E.
|
||||
* \param gemm_desc Gemm descriptors for each group.
|
||||
* \param a_element_op A elementwise operation object.
|
||||
* \param b_element_op B elementwise operation object.
|
||||
* \param cde_element_op CDE elementwise operation object.
|
||||
* \return Pointer to the argument.
|
||||
*/
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<std::array<const void*, NumATensor>>& p_as,
|
||||
std::vector<std::array<const void*, NumBTensor>>& p_bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_ds,
|
||||
std::vector<void*>& p_e,
|
||||
std::vector<GemmMultiABDDesc>& gemm_desc,
|
||||
AElementwiseOperation a_element_op = AElementwiseOperation{},
|
||||
BElementwiseOperation b_element_op = BElementwiseOperation{},
|
||||
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual void SetElementwiseOps(BaseArgument* p_arg,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) const = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,81 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <array>
|
||||
|
||||
#include "device_grouped_gemm_multi_abd.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
|
||||
struct GroupedGemmMultiABDKernelArgument
|
||||
{
|
||||
std::array<const void*, NumATensor> p_as_grid;
|
||||
std::array<const void*, NumBTensor> p_bs_grid;
|
||||
std::array<const void*, NumDTensor> p_ds_grid;
|
||||
void* p_e_grid;
|
||||
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
|
||||
std::array<index_t, NumATensor> StrideAs;
|
||||
std::array<index_t, NumBTensor> StrideBs;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
};
|
||||
|
||||
/*
|
||||
* \brief Grouped Gemm Multi ABD Fixed NK
|
||||
*
|
||||
* C = a_op(A, A1...) * b_op(B, B1...)
|
||||
* E = cde_op(C, D0, D1, ...)
|
||||
*
|
||||
* \tparam AsLayout A layouts (tuple).
|
||||
* \tparam BsLayout B layouts (tuple).
|
||||
* \tparam DsLayout Ds layouts (tuple).
|
||||
* \tparam ELayout Output layout.
|
||||
* \tparam AsDataType A data types (tuple).
|
||||
* \tparam BsDataType B data types (tuple).
|
||||
* \tparam DsDataType D data types (tuple).
|
||||
* \tparam EDataType Output data type.
|
||||
* \tparam AElementwiseOperation A elementwise operation.
|
||||
* \tparam BElementwiseOperation B elementwise operation.
|
||||
* \tparam CDEElementwiseOperation C elementwise operation.
|
||||
*/
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGroupedGemmMultiABDFixedNK : DeviceGroupedGemmMultiABD<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
|
||||
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
|
||||
virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,75 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t NumDimG,
|
||||
index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
index_t NumDimO,
|
||||
typename ADataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename CDataType,
|
||||
typename Acc0BiasDataType,
|
||||
typename Acc1BiasDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename Acc0ElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
MaskingSpecialization MaskingSpec>
|
||||
struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
|
||||
{
|
||||
struct ProblemDesc
|
||||
{
|
||||
std::vector<index_t> a_gs_ms_ks_lengths;
|
||||
std::vector<index_t> a_gs_ms_ks_strides;
|
||||
|
||||
std::vector<index_t> b0_gs_ns_ks_lengths;
|
||||
std::vector<index_t> b0_gs_ns_ks_strides;
|
||||
|
||||
std::vector<index_t> b1_gs_os_ns_lengths;
|
||||
std::vector<index_t> b1_gs_os_ns_strides;
|
||||
|
||||
std::vector<index_t> c_gs_ms_os_lengths;
|
||||
std::vector<index_t> c_gs_ms_os_strides;
|
||||
|
||||
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
|
||||
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;
|
||||
|
||||
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
|
||||
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
|
||||
};
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<const void*> p_a_vec,
|
||||
std::vector<const void*> p_b0_vec,
|
||||
std::vector<const void*> p_b1_vec,
|
||||
std::vector<void*> p_c_vec,
|
||||
std::vector<std::vector<const void*>> p_acc0_biases_vec,
|
||||
std::vector<std::vector<const void*>> p_acc1_biases_vec,
|
||||
std::vector<ProblemDesc> problem_desc_vec,
|
||||
AElementwiseOperation a_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
Acc0ElementwiseOperation acc0_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,55 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include "device_grouped_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Sets the k batch size.
|
||||
///
|
||||
/// @param p_arg Pointer to the Argument we're going to change.
|
||||
/// @param[in] kbatch The kbatch value.
|
||||
///
|
||||
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Sets the k batch size.
|
||||
///
|
||||
/// @param p_arg Pointer to the Argument we're going to change.
|
||||
/// @param[in] kbatch The kbatch value.
|
||||
///
|
||||
virtual void SetKBatch(BaseArgument* p_arg, index_t kbatch) const
|
||||
{
|
||||
this->SetKBatchSize(p_arg, kbatch);
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "device_grouped_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/// @brief Grouped GEMM kernel using output Tile Looping algorithm
|
||||
///
|
||||
/// @par This kernel does not require any knowledge about input data sizes (GEMM M/N/K)
|
||||
/// It requires only the number of groups to launch. Other information like
|
||||
/// data pointers and GEMM sizes, packed into gemm kernel args may be all dynamic
|
||||
/// (known only at kernel run-time).
|
||||
///
|
||||
/// @note This kernel does not support SplitK.
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// For pooling which used indexable operation, such as MaxPool, MinPool...etc
|
||||
template <typename DOutDataType, typename IndexDataType, typename DInDataType>
|
||||
struct DeviceMaxPoolBwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_dout,
|
||||
const void* p_indices,
|
||||
void* p_din,
|
||||
index_t dout_length,
|
||||
index_t din_length,
|
||||
std::vector<ck::index_t> window_lengths,
|
||||
std::vector<ck::index_t> window_strides,
|
||||
std::vector<ck::index_t> window_dilations) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
index_t NumReduction,
|
||||
typename InElementwiseOperationTuple,
|
||||
typename AccElementwiseOperationTuple>
|
||||
struct DeviceMultipleReduce : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumInputDim = Rank;
|
||||
static constexpr index_t NumOutputDim = (Rank - NumReduceDim > 1) ? Rank - NumReduceDim : 1;
|
||||
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const std::array<index_t, NumInputDim> inLengths,
|
||||
const std::array<index_t, NumInputDim> inStrides,
|
||||
const std::array<index_t, NumOutputDim> outLengths,
|
||||
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
const std::array<double, NumReduction> alphas,
|
||||
const std::array<double, NumReduction> betas,
|
||||
const void* in_dev,
|
||||
const std::array<void*, NumReduction> out_dev_buffers,
|
||||
const InElementwiseOperationTuple in_elementwise_op_tuple,
|
||||
const AccElementwiseOperationTuple acc_elementwise_op_tuple) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
index_t NumReduction,
|
||||
typename InElementwiseOperationTuple,
|
||||
typename AccElementwiseOperationTuple>
|
||||
using DeviceMultipleReducePtr = std::unique_ptr<DeviceMultipleReduce<Rank,
|
||||
NumReduceDim,
|
||||
NumReduction,
|
||||
InElementwiseOperationTuple,
|
||||
AccElementwiseOperationTuple>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
template <typename DYDataType,
|
||||
typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename MeanInvStdDataType,
|
||||
typename DXDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceNormalizationBwdData : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<index_t> lengths,
|
||||
const std::vector<index_t> dyStrides,
|
||||
const std::vector<index_t> xStrides,
|
||||
const std::vector<index_t> gammaStrides,
|
||||
const std::vector<index_t> meanStrides,
|
||||
const std::vector<index_t> invStdStrides,
|
||||
const std::vector<index_t> dxStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
const void* p_dy,
|
||||
const void* p_x,
|
||||
const void* p_gamma,
|
||||
const void* p_mean,
|
||||
const void* p_invStd,
|
||||
void* p_dx) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename DYDataType,
|
||||
typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename MeanInvStdDataType,
|
||||
typename DXDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
using DeviceNormalizationBwdDataPtr = std::unique_ptr<DeviceNormalizationBwdData<DYDataType,
|
||||
XDataType,
|
||||
GammaDataType,
|
||||
MeanInvStdDataType,
|
||||
DXDataType,
|
||||
Rank,
|
||||
NumReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,61 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
template <typename DYDataType,
|
||||
typename XDataType,
|
||||
typename MeanInvStdDataType,
|
||||
typename DGammaDataType,
|
||||
typename DBetaDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceNormalizationBwdGammaBeta : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> dyStrides,
|
||||
const std::vector<index_t> xStrides,
|
||||
const std::vector<index_t> meanStrides,
|
||||
const std::vector<index_t> invStdStrides,
|
||||
const std::vector<index_t> outLengths,
|
||||
const std::vector<index_t> dgammaStrides,
|
||||
const std::vector<index_t> dbetaStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
const void* p_dy,
|
||||
const void* p_x,
|
||||
const void* p_mean,
|
||||
const void* p_invStd,
|
||||
void* p_dgamma,
|
||||
void* p_dbeta) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename DYDataType,
|
||||
typename XDataType,
|
||||
typename MeanInvStdDataType,
|
||||
typename DGammaDataType,
|
||||
typename DBetaDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
using DeviceNormalizationBwdGammaBetaPtr =
|
||||
std::unique_ptr<DeviceNormalizationBwdGammaBeta<DYDataType,
|
||||
XDataType,
|
||||
MeanInvStdDataType,
|
||||
DGammaDataType,
|
||||
DBetaDataType,
|
||||
Rank,
|
||||
NumReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,64 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename YDataType,
|
||||
typename SaveMeanInvStdDataType,
|
||||
typename YElementwiseOperation,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceNormalizationFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<index_t> lengths,
|
||||
const std::vector<index_t> xStrides,
|
||||
const std::vector<index_t> gammaStrides,
|
||||
const std::vector<index_t> betaStrides,
|
||||
const std::vector<index_t> yStrides,
|
||||
const std::vector<index_t> saveMeanStrides,
|
||||
const std::vector<index_t> saveInvStdStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
double epsilon,
|
||||
const void* p_x,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
void* p_y,
|
||||
void* p_savedMean,
|
||||
void* p_savedInvVar,
|
||||
YElementwiseOperation y_elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename YDataType,
|
||||
typename SaveMeanInvStdDataType,
|
||||
typename YElementwiseOperation,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
using DeviceNormalizationFwdPtr = std::unique_ptr<DeviceNormalizationFwd<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
SaveMeanInvStdDataType,
|
||||
YElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
36
include/ck/tensor_operation/gpu/device/device_permute.hpp
Normal file
36
include/ck/tensor_operation/gpu/device/device_permute.hpp
Normal file
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t NumDim, typename InDataType, typename OutDataType, typename ElementwiseOperation>
|
||||
struct DevicePermute : BaseOperator
|
||||
{
|
||||
using Lengths = std::array<index_t, NumDim>;
|
||||
using Strides = Lengths;
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const Lengths& in_lengths,
|
||||
const Strides& in_strides,
|
||||
const Lengths& out_lengths,
|
||||
const Strides& out_strides,
|
||||
const void* in_dev_buffer,
|
||||
void* out_dev_buffer,
|
||||
ElementwiseOperation elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
47
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
Normal file
47
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
Normal file
@@ -0,0 +1,47 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t InOutRank,
|
||||
index_t WindowRank,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename InLayout,
|
||||
typename OutLayout,
|
||||
ReduceTensorOp ReduceOpId,
|
||||
bool OutputIndex>
|
||||
struct DevicePoolFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_dev,
|
||||
void* p_out_dev,
|
||||
void* p_out_indices_dev,
|
||||
std::vector<ck::index_t> input_n_c_wis_lengths,
|
||||
std::vector<ck::index_t> window_xs_lengths,
|
||||
std::vector<ck::index_t> output_n_c_wos_lengths,
|
||||
std::vector<ck::index_t> input_n_c_wis_stride,
|
||||
std::vector<ck::index_t> output_n_c_wis_stride,
|
||||
std::vector<ck::index_t> indices_n_c_wis_stride,
|
||||
std::vector<ck::index_t> window_xs_strides,
|
||||
std::vector<ck::index_t> window_xs_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::vector<ck::index_t> pooling_dims) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// output[indices] = input
|
||||
template <typename InDataType,
|
||||
typename IndexDataType,
|
||||
typename OutDataType,
|
||||
typename ElementwiseOperation,
|
||||
InMemoryDataOperationEnum Op>
|
||||
struct DevicePutElement : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_input,
|
||||
const void* p_indices,
|
||||
void* p_output,
|
||||
index_t input_length,
|
||||
index_t output_length,
|
||||
ElementwiseOperation elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
71
include/ck/tensor_operation/gpu/device/device_reduce.hpp
Normal file
71
include/ck/tensor_operation/gpu/device/device_reduce.hpp
Normal file
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool OutputIndex>
|
||||
struct DeviceReduce : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
|
||||
const std::array<index_t, Rank> inStrides,
|
||||
const std::array<index_t, NumOutDim> outLengths,
|
||||
const std::array<index_t, NumOutDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
double alpha,
|
||||
double beta,
|
||||
const void* in_dev,
|
||||
const void* in_index_dev,
|
||||
void* out_dev,
|
||||
void* out_index_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const AccElementwiseOperation acc_elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
bool PropagateNan,
|
||||
bool OutputIndex>
|
||||
using DeviceReducePtr = std::unique_ptr<DeviceReduce<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
OutputIndex>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceReduceMultiD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
|
||||
const std::array<index_t, Rank> inStrides,
|
||||
const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths,
|
||||
const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides,
|
||||
const std::array<index_t, NumOutDim> outLengths,
|
||||
const std::array<index_t, NumOutDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
const void* in_dev,
|
||||
const std::array<const void*, NumDTensor> ds_dev,
|
||||
void* out_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const OutElementwiseOperation out_elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceReduceMultiDPtr = std::unique_ptr<DeviceReduceMultiD<InDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
OutElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
72
include/ck/tensor_operation/gpu/device/device_softmax.hpp
Normal file
72
include/ck/tensor_operation/gpu/device/device_softmax.hpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InElementwiseOp,
|
||||
typename AccElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceSoftmax : public BaseOperator
|
||||
{
|
||||
//
|
||||
// @brief Makes a pointer to Argument class.
|
||||
//
|
||||
// @param[in] inLengths Input tensor extent(s) from high to low dimension
|
||||
// @param[in] inStrides Input tensor stride(s) from high to low dimension
|
||||
// @param[in] reduceDims The dimension(s) the normalization operation is applied
|
||||
// @param[in] alpha double type value
|
||||
// @param[in] beta double type value
|
||||
// @param[in] in_dev Typeless const pointer in device memory storing the input
|
||||
// tensor
|
||||
// @param out_dev Typeless pointer in device memory storing the output tensor
|
||||
// @param[in] in_elementwise_op The input elementwise operation.
|
||||
// @param[in] acc_elementwise_op The accumulation elementwise operation.
|
||||
//
|
||||
// @return Unique pointer to the Argument class.
|
||||
//
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
double alpha,
|
||||
double beta,
|
||||
const void* in_dev,
|
||||
void* out_dev,
|
||||
InElementwiseOp in_elementwise_op,
|
||||
AccElementwiseOp acc_elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InElementwiseOp,
|
||||
typename AccElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
using DeviceSoftmaxPtr = std::unique_ptr<DeviceSoftmax<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InElementwiseOp,
|
||||
AccElementwiseOp,
|
||||
Rank,
|
||||
NumReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Tensor Contraction:
|
||||
// input : A
|
||||
// input : B
|
||||
// input : D0, D1, ...
|
||||
// output : E
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
template <index_t NumDimG,
|
||||
index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceSplitKContractionMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const std::vector<index_t>& a_gs_ms_ns_lengths,
|
||||
const std::vector<index_t>& a_gs_ms_ks_strides,
|
||||
const std::vector<index_t>& b_gs_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_gs_ns_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
|
||||
const std::vector<index_t>& e_gs_ms_ns_lengths,
|
||||
const std::vector<index_t>& e_gs_ms_ns_strides,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
index_t split_k) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
enum struct GemmSpecialization
|
||||
{
|
||||
// Gemm
|
||||
Default,
|
||||
MPadding,
|
||||
NPadding,
|
||||
KPadding,
|
||||
MNPadding,
|
||||
MKPadding,
|
||||
NKPadding,
|
||||
MNKPadding,
|
||||
// Gemm + Gemm
|
||||
OPadding,
|
||||
MOPadding,
|
||||
NOPadding,
|
||||
KOPadding,
|
||||
MNOPadding,
|
||||
MKOPadding,
|
||||
NKOPadding,
|
||||
MNKOPadding,
|
||||
};
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
inline std::string getGemmSpecializationString(const GemmSpecialization& s)
|
||||
{
|
||||
switch(s)
|
||||
{
|
||||
case GemmSpecialization::Default: return "Default";
|
||||
case GemmSpecialization::MPadding: return "MPadding";
|
||||
case GemmSpecialization::NPadding: return "NPadding";
|
||||
case GemmSpecialization::KPadding: return "KPadding";
|
||||
case GemmSpecialization::MNPadding: return "MNPadding";
|
||||
case GemmSpecialization::MKPadding: return "MKPadding";
|
||||
case GemmSpecialization::NKPadding: return "NKPadding";
|
||||
case GemmSpecialization::MNKPadding: return "MNKPadding";
|
||||
case GemmSpecialization::OPadding: return "OPadding";
|
||||
case GemmSpecialization::MOPadding: return "MOPadding";
|
||||
case GemmSpecialization::NOPadding: return "NOPadding";
|
||||
case GemmSpecialization::KOPadding: return "KOPadding";
|
||||
case GemmSpecialization::MNOPadding: return "MNOPadding";
|
||||
case GemmSpecialization::MKOPadding: return "MKOPadding";
|
||||
case GemmSpecialization::NKOPadding: return "NKOPadding";
|
||||
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
|
||||
default: return "Unrecognized specialization!";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
478
include/ck/tensor_operation/gpu/device/helper.hpp
Normal file
478
include/ck/tensor_operation/gpu/device/helper.hpp
Normal file
@@ -0,0 +1,478 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include <fstream>
|
||||
#include <variant>
|
||||
|
||||
// functions to return the corresponding structs based on generated template parameters
|
||||
|
||||
using layouts = std::variant<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::NHWGK,
|
||||
ck::tensor_layout::convolution::GNDHWK,
|
||||
ck::tensor_layout::convolution::NDHWGK>;
|
||||
// return the layout type: currently this is the only type supported in MIOpen
|
||||
auto layout_type(std::string type)
|
||||
{
|
||||
if(type == "ck::tensor_layout::convolution::NHWGK")
|
||||
{
|
||||
return ck::tensor_layout::convolution::NHWGK{};
|
||||
}
|
||||
throw std::runtime_error("Incorrect layout");
|
||||
}
|
||||
// return the right gemm spec based on the generated template parameters
|
||||
ck::tensor_operation::device::GemmSpecialization gemm_type(std::string type)
|
||||
{
|
||||
if(type == "ck::tensor_operation::device::GemmSpecialization::Default")
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
}
|
||||
if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding")
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
}
|
||||
throw std::runtime_error("Incorrect gemm spec: " + type);
|
||||
}
|
||||
|
||||
// return the type of convolution
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization conv_type(std::string type)
|
||||
{
|
||||
if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default")
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
}
|
||||
if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0")
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
}
|
||||
if(type ==
|
||||
"ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0")
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
}
|
||||
if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC")
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
|
||||
}
|
||||
throw std::runtime_error("Incorrect conv spec: " + type);
|
||||
}
|
||||
|
||||
// Function to call on MatrixPadder via a wrapper struct
|
||||
// NOTE: CK only uses MNKPadding for forward convolution
|
||||
template <typename CDesc_MRaw_NRaw>
|
||||
auto pad(ck::index_t mpb,
|
||||
ck::index_t npb,
|
||||
ck::index_t kpb,
|
||||
ck::tensor_operation::device::GemmSpecialization gemm,
|
||||
CDesc_MRaw_NRaw conv)
|
||||
{
|
||||
if(gemm == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
|
||||
{
|
||||
ck::tensor_operation::device::MatrixPadder<
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
ck::index_t,
|
||||
ck::index_t,
|
||||
ck::index_t>
|
||||
a;
|
||||
a.MPerTile_ = mpb;
|
||||
a.NPerTile_ = npb;
|
||||
a.KPerTile_ = kpb;
|
||||
auto tmp = grid_desc(a, conv);
|
||||
return tmp;
|
||||
}
|
||||
throw std::runtime_error("Incorrect template parameters, check gemm spec");
|
||||
}
|
||||
|
||||
// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num
|
||||
// dims
|
||||
// FIXME: add a way to properly pass in the layout
|
||||
auto transform_conv(ck::index_t num_dim,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
|
||||
ck::Array<ck::index_t, 5> out_lengths,
|
||||
ck::Array<ck::index_t, 5> out_strides)
|
||||
{
|
||||
ck::Array<ck::index_t, 5> dummy_dims;
|
||||
ck::Array<ck::index_t, 2> dummy_spatial_dims;
|
||||
if(num_dim == 2 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 2 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 2 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
throw std::runtime_error("Incorrect conv spec");
|
||||
}
|
||||
|
||||
auto transform_conv_3d(ck::index_t num_dim,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
|
||||
ck::Array<ck::index_t, 6> out_lengths,
|
||||
ck::Array<ck::index_t, 6> out_strides)
|
||||
{
|
||||
ck::Array<ck::index_t, 6> dummy_dims;
|
||||
ck::Array<ck::index_t, 3> dummy_spatial_dims;
|
||||
|
||||
if(num_dim == 3 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 3 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 3 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
throw std::runtime_error("Incorrect conv spec");
|
||||
}
|
||||
|
||||
auto transform_conv_1d(ck::index_t num_dim,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
|
||||
ck::Array<ck::index_t, 4> out_lengths,
|
||||
ck::Array<ck::index_t, 4> out_strides)
|
||||
{
|
||||
ck::Array<ck::index_t, 4> dummy_dims;
|
||||
ck::Array<ck::index_t, 1> dummy_spatial_dims;
|
||||
|
||||
if(num_dim == 1 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 1 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 1 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
throw std::runtime_error("Incorrect dims or conv spec");
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder)
|
||||
{
|
||||
if(m_per_block == 32 && n_per_block == 64)
|
||||
{
|
||||
auto b2e = ck::BlockToCTileMap_M00_N0_M01Adapt<32, 64, CGridDesc_M_N>(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 32 && n_per_block == 128)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<32, 128, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 64 && n_per_block == 32)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<64, 32, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 64 && n_per_block == 64)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<64, 64, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 64 && n_per_block == 128)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<64, 128, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 128 && n_per_block == 32)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 32, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 128 && n_per_block == 64)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 64, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 128 && n_per_block == 128)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 128, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 128 && n_per_block == 256)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 256, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 256 && n_per_block == 128)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<256, 128, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
throw std::runtime_error("Incorrect template parameters");
|
||||
}
|
||||
|
||||
// wrapper functions by dims to get grid size - uses above 3 functions
|
||||
// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions
|
||||
auto get_launch_params_1d(ck::host::Solution solution,
|
||||
ck::Array<ck::index_t, 4> out_lengths,
|
||||
ck::Array<ck::index_t, 4> out_strides)
|
||||
{
|
||||
auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
|
||||
auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
|
||||
auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
|
||||
auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
|
||||
auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
|
||||
auto conv_to_gemm_transformer = transform_conv_1d(num_dim, ConvSpec, out_lengths, out_strides);
|
||||
auto matrix_padder =
|
||||
pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
|
||||
auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
|
||||
return b2e;
|
||||
}
|
||||
|
||||
auto get_launch_params(ck::host::Solution solution,
|
||||
ck::Array<ck::index_t, 5> out_lengths,
|
||||
ck::Array<ck::index_t, 5> out_strides)
|
||||
{
|
||||
auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
|
||||
auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
|
||||
auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
|
||||
auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
|
||||
auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
|
||||
auto conv_to_gemm_transformer = transform_conv(num_dim, ConvSpec, out_lengths, out_strides);
|
||||
auto matrix_padder =
|
||||
pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
|
||||
auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
|
||||
return b2e;
|
||||
}
|
||||
|
||||
auto get_launch_params_3d(ck::host::Solution solution,
|
||||
ck::Array<ck::index_t, 6> out_lengths,
|
||||
ck::Array<ck::index_t, 6> out_strides)
|
||||
{
|
||||
auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
|
||||
auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
|
||||
auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
|
||||
auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
|
||||
auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
|
||||
auto conv_to_gemm_transformer = transform_conv_3d(num_dim, ConvSpec, out_lengths, out_strides);
|
||||
auto matrix_padder =
|
||||
pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
|
||||
auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
|
||||
return b2e;
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,523 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// In and Din = [N, C, Hi, Wi]
|
||||
// Out and Dout = [N, C, Ho, Wo]
|
||||
// Out = AvgPool2dFwd(In)
|
||||
// Din = AvgPool2dBwd(Dout)
|
||||
// Pooling dimension = H, W
|
||||
template <typename DOutDataType,
|
||||
typename DInDataType,
|
||||
typename ComputeDataType,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MThreadClusterSize,
|
||||
ck::index_t KThreadClusterSize,
|
||||
ck::index_t MThreadSliceSize,
|
||||
ck::index_t KThreadSliceSize,
|
||||
ck::index_t InSrcOutDstVectorSize>
|
||||
struct DeviceAvgPool2dBwd_NHWC_NHWC : public DeviceAvgPoolBwd<2,
|
||||
DOutDataType,
|
||||
DInDataType,
|
||||
tensor_layout::convolution::NHWC,
|
||||
tensor_layout::convolution::NHWC>
|
||||
{
|
||||
|
||||
static constexpr ck::index_t NDimSpatial = 2;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto
|
||||
Make2DGridDescriptor_Out_M_K_In_M(const std::vector<ck::index_t>& dout_n_c_wos_lengths,
|
||||
const std::vector<ck::index_t>& din_n_c_wos_length,
|
||||
const std::vector<ck::index_t>& dout_n_c_wos_strides,
|
||||
const std::vector<ck::index_t>& din_n_c_wos_strides,
|
||||
const std::vector<ck::index_t>& window_lengths,
|
||||
const std::vector<ck::index_t>& window_strides,
|
||||
const std::vector<ck::index_t>& window_dilations,
|
||||
const std::vector<ck::index_t>& input_left_pads,
|
||||
const std::vector<ck::index_t>& input_right_pads,
|
||||
const std::vector<ck::index_t>& tildes)
|
||||
{
|
||||
index_t i_ytilde = tildes[0];
|
||||
index_t i_xtilde = tildes[1];
|
||||
|
||||
const index_t N = dout_n_c_wos_lengths[0];
|
||||
const index_t C = dout_n_c_wos_lengths[1];
|
||||
const index_t Ho = dout_n_c_wos_lengths[2];
|
||||
const index_t Wo = dout_n_c_wos_lengths[3];
|
||||
|
||||
const index_t Hi = din_n_c_wos_length[2];
|
||||
const index_t Wi = din_n_c_wos_length[3];
|
||||
|
||||
const index_t Y = window_lengths[0];
|
||||
const index_t X = window_lengths[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t ConvStrideH = window_strides[0];
|
||||
const index_t ConvStrideW = window_strides[1];
|
||||
|
||||
const index_t ConvDilationH = window_dilations[0];
|
||||
const index_t ConvDilationW = window_dilations[1];
|
||||
|
||||
const index_t Ni_stride = dout_n_c_wos_strides[0];
|
||||
const index_t Ci_stride = dout_n_c_wos_strides[1];
|
||||
const index_t Ho_stride = dout_n_c_wos_strides[2];
|
||||
const index_t Wo_stride = dout_n_c_wos_strides[3];
|
||||
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilde);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilde);
|
||||
|
||||
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on Tildes that contribute to non-padding area of input tensor
|
||||
const auto IHTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
|
||||
const auto IWTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
|
||||
|
||||
const auto IHTildeSliceEnd =
|
||||
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildeSliceEnd =
|
||||
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
|
||||
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
|
||||
|
||||
// ReduceK is different for each Reduce
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
// Problem size of reduction kernel
|
||||
const index_t MRaw = N * HTildeSlice * WTildeSlice * C;
|
||||
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
|
||||
|
||||
const index_t KRaw = YDotSlice * XDotSlice;
|
||||
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
|
||||
|
||||
const auto out_n_ho_wo_c_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, Ho, Wo, C), make_tuple(Ni_stride, Ho_stride, Wo_stride, Ci_stride));
|
||||
|
||||
// Out[ReduceM, ReduceK]
|
||||
const auto out_n_hop_wop_c_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ho_wo_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, I0),
|
||||
make_pad_transform(Wo, I0, I0),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_ydot_htilde_xdot_wtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
out_n_hop_wop_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YDot, HTilde),
|
||||
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, WTilde),
|
||||
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_ydot_htilde_xdot_wtilde_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}));
|
||||
|
||||
const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C)),
|
||||
make_merge_transform(make_tuple(YDotSlice, XDotSlice))),
|
||||
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor(
|
||||
out_grid_desc_reducemraw_reducekraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// In[ReduceM]
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
|
||||
make_tuple(din_n_c_wos_strides[0],
|
||||
din_n_c_wos_strides[2],
|
||||
din_n_c_wos_strides[3],
|
||||
din_n_c_wos_strides[1]));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YTilde, HTilde),
|
||||
make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(XTilde, WTilde),
|
||||
make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<>{},
|
||||
Sequence<1>{},
|
||||
Sequence<>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{}));
|
||||
|
||||
const auto in_grid_desc_reducemraw = transform_tensor_descriptor(
|
||||
in_n_htildeslice_wtildeslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto in_grid_desc_reducem =
|
||||
transform_tensor_descriptor(in_grid_desc_reducemraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem);
|
||||
}
|
||||
|
||||
using DoutDinGridDesc = decltype(Make2DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0},
|
||||
{0, 0, 0, 0},
|
||||
{0, 0, 0, 0},
|
||||
{0, 0, 0, 0},
|
||||
{0, 0},
|
||||
{0, 0},
|
||||
{0, 0},
|
||||
{0, 0},
|
||||
{0, 0},
|
||||
{0, 0}));
|
||||
|
||||
using DoutGridDesc_M_K = remove_cvref_t<tuple_element_t<0, DoutDinGridDesc>>;
|
||||
using DinGridDesc_M = remove_cvref_t<tuple_element_t<1, DoutDinGridDesc>>;
|
||||
|
||||
// FIXME
|
||||
// for NHWC, the dim C is the fastest dimension, and is not reduced.
|
||||
// Hence, it is in M dimension for reduction kernel.
|
||||
static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K
|
||||
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
using Div = tensor_operation::element_wise::UnaryDivide;
|
||||
|
||||
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<DOutDataType,
|
||||
DInDataType,
|
||||
ComputeDataType,
|
||||
int,
|
||||
DoutGridDesc_M_K,
|
||||
DinGridDesc_M,
|
||||
reduce::Add,
|
||||
PassThrough,
|
||||
Div,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
false, // propagate_nan
|
||||
BlockSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
OutSrcInDstVectorDim,
|
||||
InSrcOutDstVectorSize,
|
||||
InSrcOutDstVectorSize>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const DOutDataType* p_dout,
|
||||
DInDataType* p_din,
|
||||
std::vector<ck::index_t> dout_n_c_wos_lengths,
|
||||
std::vector<ck::index_t> din_n_c_wos_length,
|
||||
std::vector<ck::index_t> dout_n_c_wos_strides,
|
||||
std::vector<ck::index_t> din_n_c_wos_strides,
|
||||
std::vector<ck::index_t> window_lengths,
|
||||
std::vector<ck::index_t> window_strides,
|
||||
std::vector<ck::index_t> window_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
: p_dout_grid_{p_dout},
|
||||
p_din_grid_{p_din},
|
||||
dout_n_c_wos_lengths_{dout_n_c_wos_lengths},
|
||||
din_n_c_wos_length_{din_n_c_wos_length},
|
||||
dout_n_c_wos_strides_{dout_n_c_wos_strides},
|
||||
din_n_c_wos_strides_{din_n_c_wos_strides},
|
||||
num_reduce_{1},
|
||||
div_element_op_{window_lengths[0] * window_lengths[1]}
|
||||
{
|
||||
std::vector<ck::index_t> Tildes(NDimSpatial);
|
||||
for(int i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]);
|
||||
Tildes[i] = window_strides[i] / GcdStrideDilation;
|
||||
num_reduce_ *= Tildes[i];
|
||||
}
|
||||
|
||||
for(index_t i_ytilde = 0; i_ytilde < Tildes[0]; ++i_ytilde)
|
||||
{
|
||||
for(index_t i_xtilde = 0; i_xtilde < Tildes[1]; ++i_xtilde)
|
||||
{
|
||||
const auto YDotSlice =
|
||||
math::integer_divide_ceil(window_lengths[0] - i_ytilde, Tildes[0]);
|
||||
const auto XDotSlice =
|
||||
math::integer_divide_ceil(window_lengths[1] - i_xtilde, Tildes[1]);
|
||||
|
||||
if(YDotSlice * XDotSlice <= 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto dout_din_grid_desc =
|
||||
Make2DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths,
|
||||
din_n_c_wos_length,
|
||||
dout_n_c_wos_strides,
|
||||
din_n_c_wos_strides,
|
||||
window_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{i_ytilde, i_xtilde});
|
||||
|
||||
dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]);
|
||||
din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const DOutDataType* p_dout_grid_;
|
||||
DInDataType* p_din_grid_;
|
||||
std::vector<ck::index_t> dout_n_c_wos_lengths_;
|
||||
std::vector<ck::index_t> din_n_c_wos_length_;
|
||||
std::vector<ck::index_t> dout_n_c_wos_strides_;
|
||||
std::vector<ck::index_t> din_n_c_wos_strides_;
|
||||
|
||||
int num_reduce_;
|
||||
std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
|
||||
std::vector<DinGridDesc_M> din_grid_desc_m_container_;
|
||||
|
||||
Div div_element_op_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float ave_time = 0;
|
||||
|
||||
for(index_t i = 0; i < arg.num_reduce_; i++)
|
||||
{
|
||||
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
|
||||
false,
|
||||
false,
|
||||
false, // don't have index input
|
||||
DOutDataType,
|
||||
DInDataType,
|
||||
ComputeDataType,
|
||||
int,
|
||||
DoutGridDesc_M_K,
|
||||
DinGridDesc_M,
|
||||
PassThrough,
|
||||
Div>;
|
||||
|
||||
ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0);
|
||||
const index_t grid_size = (M / M_BlockTileSize);
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.dout_grid_desc_m_k_container_[i],
|
||||
arg.din_grid_desc_m_container_[i],
|
||||
PassThrough{},
|
||||
arg.div_element_op_,
|
||||
float(1),
|
||||
arg.p_dout_grid_,
|
||||
nullptr,
|
||||
float(0),
|
||||
arg.p_din_grid_,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
constexpr index_t Rank = NDimSpatial + 2;
|
||||
int doutFastestDim = -1;
|
||||
int dinFastestDim = -1;
|
||||
|
||||
for(int i = 0; i < Rank; ++i)
|
||||
{
|
||||
if(arg.dout_n_c_wos_strides_[i] == 1)
|
||||
doutFastestDim = i;
|
||||
if(arg.din_n_c_wos_strides_[i] == 1)
|
||||
dinFastestDim = i;
|
||||
}
|
||||
if(InSrcOutDstVectorSize != 1 && (dinFastestDim != 1 || doutFastestDim != 1))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(doutFastestDim == -1 || dinFastestDim == -1)
|
||||
{
|
||||
if constexpr(InSrcOutDstVectorSize != 1)
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0)
|
||||
return false;
|
||||
if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_dout,
|
||||
void* p_din,
|
||||
std::vector<ck::index_t> dout_n_c_wos_lengths,
|
||||
std::vector<ck::index_t> din_n_c_wos_length,
|
||||
std::vector<ck::index_t> dout_n_c_wos_strides,
|
||||
std::vector<ck::index_t> din_n_c_wos_strides,
|
||||
std::vector<ck::index_t> window_lengths,
|
||||
std::vector<ck::index_t> window_strides,
|
||||
std::vector<ck::index_t> window_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads) override
|
||||
{
|
||||
constexpr index_t Rank = NDimSpatial + 2;
|
||||
|
||||
if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank ||
|
||||
dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank)
|
||||
{
|
||||
throw std::runtime_error("dimension of [dout|din]_n_c_wos_strides or "
|
||||
"[dout|din]_n_c_wos_lengths is not equal to Rank");
|
||||
}
|
||||
|
||||
if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
|
||||
window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
|
||||
input_right_pads.size() != NDimSpatial)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"dimension of [window_lengths, window_strides, window_dilations, input_left_pads, "
|
||||
"input_right_pads] is not equal to Rank");
|
||||
}
|
||||
return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
|
||||
static_cast<DInDataType*>(p_din),
|
||||
dout_n_c_wos_lengths,
|
||||
din_n_c_wos_length,
|
||||
dout_n_c_wos_strides,
|
||||
din_n_c_wos_strides,
|
||||
window_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceAvgPool2dBwd<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,575 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// In and Din = [N, C, Di, Hi, Wi]
|
||||
// Out and Dout = [N, C, Do, Ho, Wo]
|
||||
// Out = AvgPoolFwd(In)
|
||||
// Din = AvgPoolBwd(Dout)
|
||||
// Pooling dimension = D, H, W
|
||||
template <typename DOutDataType,
|
||||
typename DInDataType,
|
||||
typename ComputeDataType,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MThreadClusterSize,
|
||||
ck::index_t KThreadClusterSize,
|
||||
ck::index_t MThreadSliceSize,
|
||||
ck::index_t KThreadSliceSize,
|
||||
ck::index_t InSrcOutDstVectorSize>
|
||||
struct DeviceAvgPool3dBwd_NDHWC_NDHWC : public DeviceAvgPoolBwd<3,
|
||||
DOutDataType,
|
||||
DInDataType,
|
||||
tensor_layout::convolution::NDHWC,
|
||||
tensor_layout::convolution::NDHWC>
|
||||
{
|
||||
static constexpr ck::index_t NDimSpatial = 3;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto
|
||||
Make3DGridDescriptor_Out_M_K_In_M(const std::vector<ck::index_t>& dout_n_c_wos_lengths,
|
||||
const std::vector<ck::index_t>& din_n_c_wos_length,
|
||||
const std::vector<ck::index_t>& dout_n_c_wos_strides,
|
||||
const std::vector<ck::index_t>& din_n_c_wos_strides,
|
||||
const std::vector<ck::index_t>& window_lengths,
|
||||
const std::vector<ck::index_t>& window_strides,
|
||||
const std::vector<ck::index_t>& window_dilations,
|
||||
const std::vector<ck::index_t>& input_left_pads,
|
||||
const std::vector<ck::index_t>& input_right_pads,
|
||||
const std::vector<ck::index_t>& tildes)
|
||||
{
|
||||
index_t i_ztilde = tildes[0];
|
||||
index_t i_ytilde = tildes[1];
|
||||
index_t i_xtilde = tildes[2];
|
||||
|
||||
const index_t N = dout_n_c_wos_lengths[0];
|
||||
const index_t C = dout_n_c_wos_lengths[1];
|
||||
|
||||
const index_t Di = din_n_c_wos_length[2];
|
||||
const index_t Hi = din_n_c_wos_length[3];
|
||||
const index_t Wi = din_n_c_wos_length[4];
|
||||
|
||||
const index_t Do = dout_n_c_wos_lengths[2];
|
||||
const index_t Ho = dout_n_c_wos_lengths[3];
|
||||
const index_t Wo = dout_n_c_wos_lengths[4];
|
||||
|
||||
const index_t Z = window_lengths[0];
|
||||
const index_t Y = window_lengths[1];
|
||||
const index_t X = window_lengths[2];
|
||||
|
||||
const index_t InLeftPadD = input_left_pads[0];
|
||||
const index_t InLeftPadH = input_left_pads[1];
|
||||
const index_t InLeftPadW = input_left_pads[2];
|
||||
|
||||
const index_t InRightPadD = input_right_pads[0];
|
||||
const index_t InRightPadH = input_right_pads[1];
|
||||
const index_t InRightPadW = input_right_pads[2];
|
||||
|
||||
const index_t ConvStrideD = window_strides[0];
|
||||
const index_t ConvStrideH = window_strides[1];
|
||||
const index_t ConvStrideW = window_strides[2];
|
||||
|
||||
const index_t ConvDilationD = window_dilations[0];
|
||||
const index_t ConvDilationH = window_dilations[1];
|
||||
const index_t ConvDilationW = window_dilations[2];
|
||||
|
||||
const auto out_n_do_ho_wo_c_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N, Do, Ho, Wo, C),
|
||||
make_tuple(dout_n_c_wos_strides[0],
|
||||
dout_n_c_wos_strides[2],
|
||||
dout_n_c_wos_strides[3],
|
||||
dout_n_c_wos_strides[4],
|
||||
dout_n_c_wos_strides[1]));
|
||||
|
||||
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilde);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilde);
|
||||
|
||||
const auto DTilde = Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
|
||||
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on Tildes that contribute to non-padding area of input tensor
|
||||
const auto IDTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
|
||||
const auto IHTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
|
||||
const auto IWTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
|
||||
|
||||
const auto IDTildeSliceEnd =
|
||||
math::min(DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
|
||||
const auto IHTildeSliceEnd =
|
||||
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildeSliceEnd =
|
||||
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
|
||||
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
|
||||
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
|
||||
|
||||
// ReduceK is different for each Reduce
|
||||
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
// Problem size of reduction kernel
|
||||
const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C;
|
||||
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
|
||||
|
||||
const index_t KRaw = ZDotSlice * YDotSlice * XDotSlice;
|
||||
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
|
||||
|
||||
// Out[ReduceM, ReduceK]
|
||||
const auto out_n_dop_hop_wop_c_grid_desc = transform_tensor_descriptor(
|
||||
out_n_do_ho_wo_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Do, I0, I0),
|
||||
make_pad_transform(Ho, I0, I0),
|
||||
make_pad_transform(Wo, I0, I0),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_dop_hop_wop_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(ZDot, DTilde),
|
||||
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
|
||||
make_embed_transform(make_tuple(YDot, HTilde),
|
||||
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, WTilde),
|
||||
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1, 2>{},
|
||||
Sequence<3, 4>{},
|
||||
Sequence<5, 6>{},
|
||||
Sequence<7>{}));
|
||||
|
||||
const auto
|
||||
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_slice_transform(ZDot, I0, ZDotSlice),
|
||||
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{},
|
||||
Sequence<7>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{},
|
||||
Sequence<7>{}));
|
||||
|
||||
const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
|
||||
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C)),
|
||||
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice))),
|
||||
make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor(
|
||||
out_grid_desc_reducemraw_reducekraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// In[ReduceM]
|
||||
const auto in_n_di_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C),
|
||||
make_tuple(din_n_c_wos_strides[0],
|
||||
din_n_c_wos_strides[2],
|
||||
din_n_c_wos_strides[3],
|
||||
din_n_c_wos_strides[4],
|
||||
din_n_c_wos_strides[1]));
|
||||
|
||||
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Di, InLeftPadD, InRightPadD),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_n_dip_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(XTilde, DTilde),
|
||||
make_tuple(ConvDilationD, ConvStrideD)),
|
||||
make_embed_transform(make_tuple(YTilde, HTilde),
|
||||
make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(XTilde, WTilde),
|
||||
make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1, 2>{},
|
||||
Sequence<3, 4>{},
|
||||
Sequence<5, 6>{},
|
||||
Sequence<7>{}));
|
||||
|
||||
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(i_ztilde),
|
||||
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{},
|
||||
Sequence<7>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<>{},
|
||||
Sequence<1>{},
|
||||
Sequence<>{},
|
||||
Sequence<2>{},
|
||||
Sequence<>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{}));
|
||||
|
||||
const auto in_grid_desc_reducemraw = transform_tensor_descriptor(
|
||||
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto in_grid_desc_reducem =
|
||||
transform_tensor_descriptor(in_grid_desc_reducemraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem);
|
||||
}
|
||||
|
||||
using DoutDinGridDesc = decltype(Make3DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0, 0},
|
||||
{0, 0, 0, 0, 0},
|
||||
{0, 0, 0, 0, 0},
|
||||
{0, 0, 0, 0, 0},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0}));
|
||||
|
||||
using DoutGridDesc_M_K = remove_cvref_t<tuple_element_t<0, DoutDinGridDesc>>;
|
||||
using DinGridDesc_M = remove_cvref_t<tuple_element_t<1, DoutDinGridDesc>>;
|
||||
|
||||
// FIXME
|
||||
// for NDHWC, the dim C is the fastest dimension, and is not reduced.
|
||||
// Hence, it is in M dimension for reduction kernel.
|
||||
static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K
|
||||
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
using Div = tensor_operation::element_wise::UnaryDivide;
|
||||
|
||||
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<DOutDataType,
|
||||
DInDataType,
|
||||
ComputeDataType,
|
||||
int,
|
||||
DoutGridDesc_M_K,
|
||||
DinGridDesc_M,
|
||||
reduce::Add,
|
||||
PassThrough,
|
||||
Div,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
false, // propagate_nan
|
||||
BlockSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
OutSrcInDstVectorDim,
|
||||
InSrcOutDstVectorSize,
|
||||
InSrcOutDstVectorSize>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const DOutDataType* p_dout,
|
||||
DInDataType* p_din,
|
||||
std::vector<ck::index_t> dout_n_c_wos_lengths,
|
||||
std::vector<ck::index_t> din_n_c_wos_length,
|
||||
std::vector<ck::index_t> dout_n_c_wos_strides,
|
||||
std::vector<ck::index_t> din_n_c_wos_strides,
|
||||
std::vector<ck::index_t> window_lengths,
|
||||
std::vector<ck::index_t> window_strides,
|
||||
std::vector<ck::index_t> window_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
: p_dout_grid_{p_dout},
|
||||
p_din_grid_{p_din},
|
||||
dout_n_c_wos_lengths_{dout_n_c_wos_lengths},
|
||||
din_n_c_wos_length_{din_n_c_wos_length},
|
||||
dout_n_c_wos_strides_{dout_n_c_wos_strides},
|
||||
din_n_c_wos_strides_{din_n_c_wos_strides},
|
||||
num_reduce_{1},
|
||||
div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]}
|
||||
{
|
||||
std::vector<ck::index_t> Tildes(NDimSpatial);
|
||||
for(int i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]);
|
||||
Tildes[i] = window_strides[i] / GcdStrideDilation;
|
||||
num_reduce_ *= Tildes[i];
|
||||
}
|
||||
|
||||
for(index_t i_ztilde = 0; i_ztilde < Tildes[0]; ++i_ztilde)
|
||||
{
|
||||
for(index_t i_ytilde = 0; i_ytilde < Tildes[1]; ++i_ytilde)
|
||||
{
|
||||
for(index_t i_xtilde = 0; i_xtilde < Tildes[2]; ++i_xtilde)
|
||||
{
|
||||
// check slice is valid
|
||||
const auto ZDotSlice =
|
||||
math::integer_divide_ceil(window_lengths[0] - i_ztilde, Tildes[0]);
|
||||
const auto YDotSlice =
|
||||
math::integer_divide_ceil(window_lengths[1] - i_ytilde, Tildes[1]);
|
||||
const auto XDotSlice =
|
||||
math::integer_divide_ceil(window_lengths[2] - i_xtilde, Tildes[2]);
|
||||
|
||||
if(ZDotSlice * YDotSlice * XDotSlice <= 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto dout_din_grid_desc =
|
||||
Make3DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths,
|
||||
din_n_c_wos_length,
|
||||
dout_n_c_wos_strides,
|
||||
din_n_c_wos_strides,
|
||||
window_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{i_ztilde, i_ytilde, i_xtilde});
|
||||
|
||||
dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]);
|
||||
din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const DOutDataType* p_dout_grid_;
|
||||
DInDataType* p_din_grid_;
|
||||
std::vector<ck::index_t> dout_n_c_wos_lengths_;
|
||||
std::vector<ck::index_t> din_n_c_wos_length_;
|
||||
std::vector<ck::index_t> dout_n_c_wos_strides_;
|
||||
std::vector<ck::index_t> din_n_c_wos_strides_;
|
||||
|
||||
int num_reduce_;
|
||||
std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
|
||||
std::vector<DinGridDesc_M> din_grid_desc_m_container_;
|
||||
|
||||
Div div_element_op_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float ave_time = 0;
|
||||
|
||||
for(index_t i = 0; i < arg.num_reduce_; i++)
|
||||
{
|
||||
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
|
||||
false,
|
||||
false,
|
||||
false, // don't have index input
|
||||
DOutDataType,
|
||||
DInDataType,
|
||||
ComputeDataType,
|
||||
int,
|
||||
DoutGridDesc_M_K,
|
||||
DinGridDesc_M,
|
||||
PassThrough,
|
||||
Div>;
|
||||
|
||||
ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0);
|
||||
const index_t grid_size = (M / M_BlockTileSize);
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.dout_grid_desc_m_k_container_[i],
|
||||
arg.din_grid_desc_m_container_[i],
|
||||
PassThrough{},
|
||||
arg.div_element_op_,
|
||||
float(1),
|
||||
arg.p_dout_grid_,
|
||||
nullptr,
|
||||
float(0),
|
||||
arg.p_din_grid_,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
constexpr index_t Rank = NDimSpatial + 2;
|
||||
int doutFastestDim = -1;
|
||||
int dinFastestDim = -1;
|
||||
|
||||
for(int i = 0; i < Rank; ++i)
|
||||
{
|
||||
if(arg.dout_n_c_wos_strides_[i] == 1)
|
||||
doutFastestDim = i;
|
||||
if(arg.din_n_c_wos_strides_[i] == 1)
|
||||
dinFastestDim = i;
|
||||
}
|
||||
|
||||
if(doutFastestDim == -1 || dinFastestDim == -1)
|
||||
{
|
||||
if constexpr(InSrcOutDstVectorSize != 1)
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0)
|
||||
return false;
|
||||
if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_dout,
|
||||
void* p_din,
|
||||
std::vector<ck::index_t> dout_n_c_wos_lengths,
|
||||
std::vector<ck::index_t> din_n_c_wos_length,
|
||||
std::vector<ck::index_t> dout_n_c_wos_strides,
|
||||
std::vector<ck::index_t> din_n_c_wos_strides,
|
||||
std::vector<ck::index_t> window_lengths,
|
||||
std::vector<ck::index_t> window_strides,
|
||||
std::vector<ck::index_t> window_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads) override
|
||||
{
|
||||
constexpr index_t Rank = NDimSpatial + 2;
|
||||
|
||||
if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank ||
|
||||
dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank)
|
||||
throw std::runtime_error("dimension is incorrect");
|
||||
|
||||
if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
|
||||
window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
|
||||
input_right_pads.size() != NDimSpatial)
|
||||
throw std::runtime_error("dimension is incorrect");
|
||||
|
||||
return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
|
||||
static_cast<DInDataType*>(p_din),
|
||||
dout_n_c_wos_lengths,
|
||||
din_n_c_wos_length,
|
||||
dout_n_c_wos_strides,
|
||||
din_n_c_wos_strides,
|
||||
window_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceAvgPool3dBwd<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,696 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/*
|
||||
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
|
||||
*
|
||||
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
|
||||
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
|
||||
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
|
||||
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
* limitations.
|
||||
*
|
||||
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
|
||||
* returns the 2D index of the tile that it computes. \see
|
||||
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
|
||||
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
|
||||
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
|
||||
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
|
||||
* impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
|
||||
\link
|
||||
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
|
||||
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
|
||||
*
|
||||
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
|
||||
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
|
||||
* realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
|
||||
*
|
||||
*/
|
||||
template <typename GridwiseGemm,
|
||||
typename ABDataType,
|
||||
typename EDataType,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename Block2ETileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
ck::Tuple<>{},
|
||||
p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ck::Tuple<>{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = batch_count;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = block_2_etile_map;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumPrefetch,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K1,
|
||||
index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_K1,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmEPermuteXdl;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
|
||||
{
|
||||
const auto e_grid_desc_mraw_nraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(stride_M, stride_N));
|
||||
|
||||
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
|
||||
index_t G1,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t stride_G0,
|
||||
index_t stride_G1,
|
||||
index_t stride_M,
|
||||
index_t stride_N)
|
||||
{
|
||||
const auto e_grid_desc_g0_g1_mraw_nraw = [&]() {
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(G0, G1, MRaw, NRaw),
|
||||
make_tuple(stride_G0, stride_G1, stride_M, stride_N));
|
||||
}();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(
|
||||
e_grid_desc_g0_g1_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(G0),
|
||||
make_pass_through_transform(G1),
|
||||
make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad M, but not N
|
||||
return transform_tensor_descriptor(
|
||||
e_grid_desc_g0_g1_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(G0),
|
||||
make_pass_through_transform(G1),
|
||||
make_right_pad_transform(MRaw, MPad),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad N, but not M
|
||||
return transform_tensor_descriptor(
|
||||
e_grid_desc_g0_g1_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(G0),
|
||||
make_pass_through_transform(G1),
|
||||
make_pass_through_transform(MRaw),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or N
|
||||
return e_grid_desc_g0_g1_mraw_nraw;
|
||||
}
|
||||
}
|
||||
|
||||
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
|
||||
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1));
|
||||
using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch(index_t Batchstride_A,
|
||||
index_t Batchstride_B,
|
||||
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
|
||||
: Batchstride_A_(Batchstride_A),
|
||||
Batchstride_B_(Batchstride_B),
|
||||
e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(Batchstride_A_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(Batchstride_B_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1);
|
||||
index_t b0 = g_idx / G1;
|
||||
index_t b1 = g_idx - b0 * G1; // g_idx % G1
|
||||
return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0));
|
||||
}
|
||||
|
||||
private:
|
||||
index_t Batchstride_A_;
|
||||
index_t Batchstride_B_;
|
||||
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
|
||||
};
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<>, // DsDataType,
|
||||
EDataType, // EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_N_K,
|
||||
Tuple<>,
|
||||
EGridDesc_M_N,
|
||||
NumPrefetch,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}));
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
EDataType* p_e_grid,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_A,
|
||||
index_t stride_B,
|
||||
index_t batch_stride_A,
|
||||
index_t batch_stride_B,
|
||||
BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
|
||||
index_t BatchCount,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_e_grid_{p_e_grid},
|
||||
BatchCount_(BatchCount),
|
||||
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(M, K, stride_A)},
|
||||
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(K, N, stride_B)},
|
||||
e_grid_desc_m_n_{
|
||||
DeviceOp::MakeEGridDescriptor_M_N(batched_gemm_e_permute_desc.M_,
|
||||
batched_gemm_e_permute_desc.N_,
|
||||
batched_gemm_e_permute_desc.stride_M_,
|
||||
batched_gemm_e_permute_desc.stride_N_)},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
|
||||
b_grid_desc_bk0_n_bk1_{
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock{},
|
||||
e_grid_desc_g0_g1_m_n_{
|
||||
DeviceOp::MakeEGridDescriptor_G0_G1_M_N(batched_gemm_e_permute_desc.G0_,
|
||||
batched_gemm_e_permute_desc.G1_,
|
||||
batched_gemm_e_permute_desc.M_,
|
||||
batched_gemm_e_permute_desc.N_,
|
||||
batched_gemm_e_permute_desc.stride_G0_,
|
||||
batched_gemm_e_permute_desc.stride_G1_,
|
||||
batched_gemm_e_permute_desc.stride_M_,
|
||||
batched_gemm_e_permute_desc.stride_N_)},
|
||||
compute_ptr_offset_of_batch_{batch_stride_A, batch_stride_B, e_grid_desc_g0_g1_m_n_},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ck::Tuple<>{},
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
|
||||
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
|
||||
std::cout << "C[M, N]: " << e_grid_desc_m_n_ << std::endl;
|
||||
}
|
||||
|
||||
// private:
|
||||
// pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// batch count
|
||||
index_t BatchCount_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
|
||||
|
||||
// for calculating Batch offset
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
ck::Tuple<>{},
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
|
||||
"setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_;
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop_) {
|
||||
const auto kernel = kernel_batched_gemm_e_permute_xdl<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
EDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_AK0_M_AK1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_BK0_N_BK1>,
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
remove_reference_t<Block2ETileMap>,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.block_2_etile_map_);
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
ck::Tuple<>{},
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
EDataType* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_A,
|
||||
index_t stride_B,
|
||||
index_t batch_stride_A,
|
||||
index_t batch_stride_B,
|
||||
BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
|
||||
index_t BatchCount,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batched_gemm_e_permute_desc,
|
||||
BatchCount,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_A,
|
||||
index_t stride_B,
|
||||
index_t batch_stride_A,
|
||||
index_t batch_stride_B,
|
||||
BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
|
||||
index_t BatchCount,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<EDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batched_gemm_e_permute_desc,
|
||||
BatchCount,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemmEPermuteXdl"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,747 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename B1GridDesc_BK0_N_BK1,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2CTileMap,
|
||||
typename ComputeBasePtrOfStridedBatch,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_gemm_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatAB* __restrict__ p_b1_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const AccElementwiseOperation acc_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
|
||||
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_b1_grid + b1_batch_offset,
|
||||
p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
acc_element_op,
|
||||
b1_element_op,
|
||||
c_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b1_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_ctile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_b1_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = acc_element_op;
|
||||
ignore = b1_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = b1_grid_desc_bk0_n_bk1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = block_2_ctile_map;
|
||||
ignore = batch_count;
|
||||
ignore = compute_base_ptr_of_batch;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
// Computes C = A * B0 * B1
|
||||
// ^^^^^^ (Acc0)
|
||||
// ^^^^^^^^^^^ (Acc1)
|
||||
template <typename ALayout,
|
||||
typename BLayout, // B0Layout
|
||||
typename B1Layout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename B1DataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock, // Gemm0NPerBlock
|
||||
index_t KPerBlock, // Gemm0KPerBlock
|
||||
index_t Gemm1NPerBlock,
|
||||
index_t Gemm1KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t B1K1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
index_t Gemm1NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename B1BlockTransferThreadClusterArrangeOrder,
|
||||
typename B1BlockTransferSrcAccessOrder,
|
||||
index_t B1BlockTransferSrcVectorDim,
|
||||
index_t B1BlockTransferSrcScalarPerVector,
|
||||
index_t B1BlockTransferDstScalarPerVector_BK1,
|
||||
bool B1BlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = LoopScheduler::Default>
|
||||
struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout,
|
||||
BLayout,
|
||||
B1Layout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
|
||||
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
return transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
return transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
|
||||
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b1_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
|
||||
|
||||
const auto N = b1_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b1_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto B1K0 = K / B1K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b1_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(StrideC, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(I1, StrideC));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideC_(BatchStrideC)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
index_t BatchStrideB1_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
|
||||
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
|
||||
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseBatchedGemmGemm_Xdl_CShuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
B1GridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
Gemm1NPerBlock,
|
||||
Gemm1KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
B1K1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
Gemm1NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
true,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
true,
|
||||
BBlockLdsExtraN,
|
||||
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
B1BlockTransferThreadClusterArrangeOrder,
|
||||
B1BlockTransferSrcAccessOrder,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
B1BlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
const B1DataType* p_b1_grid,
|
||||
CDataType* p_c_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t Gemm1NRaw, // = ORaw
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideB1,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
AccElementwiseOperation acc_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_b1_grid_{p_b1_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
b1_grid_desc_bk0_n_bk1_{
|
||||
DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)},
|
||||
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, Gemm1NRaw, StrideC)},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
acc_element_op_{acc_element_op},
|
||||
b1_element_op_{b1_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
batch_count_(Batch),
|
||||
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
|
||||
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
b1_grid_desc_bk0_n_bk1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
|
||||
std::cout << "B0[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
|
||||
std::cout << "B1[BK0, N, BK1]: " << b1_grid_desc_bk0_n_bk1_ << std::endl;
|
||||
std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl;
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
const B1DataType* p_b1_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
AccElementwiseOperation acc_element_op_;
|
||||
B1ElementwiseOperation b1_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
index_t batch_count_;
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
|
||||
|
||||
// For robust IsSupportedArgument() check
|
||||
std::vector<index_t> raw_lengths_m_n_k_o_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!DeviceOp::IsSupportedArgument(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! unsupported argument");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
|
||||
|
||||
// Gemm0_K
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop_) {
|
||||
const auto kernel = kernel_gemm_gemm_xdl_cshuffle_v1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::B1GridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
ComputeBasePtrOfStridedBatch,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_b1_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.acc_element_op_,
|
||||
arg.b1_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.b1_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.batch_count_,
|
||||
arg.compute_base_ptr_of_batch_);
|
||||
};
|
||||
|
||||
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
|
||||
// to concern Gemm0's loop
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
ave_time = launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
|
||||
// vector is out of bounds
|
||||
const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
|
||||
const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
|
||||
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
|
||||
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
|
||||
const auto b_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
|
||||
const auto b1_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
|
||||
const auto c_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.b1_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
const B1DataType* p_b1,
|
||||
CDataType* p_c,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t Gemm1NRaw,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideB1,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
AccElementwiseOperation acc_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{p_a, p_b, p_b1, p_c, MRaw,
|
||||
NRaw, KRaw, Gemm1NRaw, Batch, StrideA,
|
||||
StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB,
|
||||
BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op,
|
||||
b1_element_op, c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_b1,
|
||||
void* p_c,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t Gemm1NRaw,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideB1,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
AccElementwiseOperation acc_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
static_cast<CDataType*>(p_c),
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
Gemm1NRaw,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideB1,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideB1,
|
||||
BatchStrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
acc_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemmGemm_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< Gemm1NPerBlock << ", "
|
||||
<< Gemm1KPerBlock << ", "
|
||||
<< B1K1 << ", "
|
||||
<< getGemmSpecializationString(GemmSpec) << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,724 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/*
|
||||
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
|
||||
*
|
||||
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
|
||||
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
|
||||
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
|
||||
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
|
||||
* limitations.
|
||||
*
|
||||
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
|
||||
* returns the 2D index of the tile that it computes. \see
|
||||
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
|
||||
*
|
||||
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
|
||||
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
|
||||
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
|
||||
* impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
|
||||
* \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
|
||||
* computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
|
||||
*
|
||||
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
|
||||
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
|
||||
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
|
||||
*
|
||||
*/
|
||||
template <typename GridwiseGemm,
|
||||
typename ABDataType,
|
||||
typename DsPointer,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename Block2ETileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
DsPointer p_ds_grid_grp;
|
||||
|
||||
static constexpr index_t NumDTensor =
|
||||
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
|
||||
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
block_2_etile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = batch_count;
|
||||
ignore = a_grid_desc_k0_m_k1;
|
||||
ignore = b_grid_desc_k0_n_k1;
|
||||
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = block_2_etile_map;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmMultiD_Xdl;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
|
||||
{
|
||||
const auto e_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(StrideE, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(I1, StrideE));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
|
||||
const std::array<index_t, NumDTensor>& NRaws,
|
||||
const std::array<index_t, NumDTensor>& DsStride)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
|
||||
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
std::array<ck::index_t, NumDTensor> BatchStrideDs,
|
||||
index_t BatchStrideE)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideDs_(BatchStrideDs),
|
||||
BatchStrideE_(BatchStrideE)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
|
||||
{
|
||||
std::array<long_index_t, NumDTensor> ds_offset;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
ds_offset[i] = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]);
|
||||
});
|
||||
return ds_offset;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
std::array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideE_;
|
||||
};
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_a_grid,
|
||||
const void* p_b_grid,
|
||||
std::array<const void*, NumDTensor> p_ds_grid,
|
||||
void* p_e_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& StrideDs,
|
||||
index_t StrideE,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
|
||||
index_t BatchStrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
|
||||
Batch_(Batch),
|
||||
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
|
||||
ds_grid_desc_m_n_{},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
|
||||
b_grid_desc_bk0_n_bk1_{
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
// populate pointer, desc for Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
// D pointer
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
|
||||
|
||||
// D desc
|
||||
ds_grid_desc_m_n_(i) =
|
||||
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
|
||||
});
|
||||
|
||||
// populate desc for Ds/E
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
{
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_);
|
||||
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
|
||||
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
|
||||
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
|
||||
}
|
||||
|
||||
// private:
|
||||
// pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// Batch
|
||||
index_t Batch_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// for calculating batch offset
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceBatchedGemmMultiD_Xdl::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.Batch_;
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel =
|
||||
kernel_batched_gemm_xdl<GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
Block2ETileMap,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.Batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.block_2_etile_map_);
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<index_t, NumDTensor>& StrideDs,
|
||||
index_t StrideE,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
|
||||
index_t BatchStrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideDs,
|
||||
BatchStrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& StrideDs,
|
||||
index_t StrideE,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
|
||||
index_t BatchStrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideDs,
|
||||
BatchStrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemmMultiD_Xdl"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< getGemmSpecializationString(GemmSpec)
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,792 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/*
|
||||
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
|
||||
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
|
||||
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
|
||||
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
|
||||
* limitations.
|
||||
*
|
||||
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
|
||||
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
|
||||
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
|
||||
* impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
|
||||
* \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
|
||||
* computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
|
||||
*/
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename ABDataType,
|
||||
typename DsPointer,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AGridDesc_K0_M0_M1_K1,
|
||||
typename BGridDesc_K0_N0_N1_K1,
|
||||
typename DsGridDesc_M0_M10_M11_N0_N10_N11,
|
||||
typename CGridDesc_M0_M10_M11_N0_N10_N11,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_dl_multiple_d(
|
||||
const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
|
||||
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
|
||||
const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
DsPointer p_ds_grid_grp;
|
||||
|
||||
static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size();
|
||||
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
|
||||
|
||||
GridwiseGemm::Run(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_k0_m0_m1_k1,
|
||||
b_grid_desc_k0_n0_n1_k1,
|
||||
ds_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
e_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
block_2_ctile_map,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = batch_count;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = a_grid_desc_k0_m0_m1_k1;
|
||||
ignore = b_grid_desc_k0_n0_n1_k1;
|
||||
ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
|
||||
ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = block_2_ctile_map;
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t K0PerBlock,
|
||||
index_t K1,
|
||||
index_t M1PerThread,
|
||||
index_t N1PerThread,
|
||||
index_t KPerThread,
|
||||
typename M1N1ThreadClusterM1Xs,
|
||||
typename M1N1ThreadClusterN1Xs,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
enable_if_t<
|
||||
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
|
||||
bool> = false>
|
||||
struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmMultipleD_Dl;
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_right_pad_transform(M, PadM)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
const auto b_grid_desc_k_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_k_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_right_pad_transform(N, PadN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_k_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
|
||||
{
|
||||
const auto c_grid_desc_m_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
|
||||
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
|
||||
const std::array<index_t, NumDTensor>& NRaws,
|
||||
const std::array<index_t, NumDTensor>& DsStride)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
|
||||
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
|
||||
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
std::array<ck::index_t, NumDTensor> BatchStrideDs,
|
||||
index_t BatchStrideE)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideDs_(BatchStrideDs),
|
||||
BatchStrideE_(BatchStrideE)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
|
||||
{
|
||||
std::array<long_index_t, NumDTensor> ds_offset;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
ds_offset[i] = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]);
|
||||
});
|
||||
return ds_offset;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
std::array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideE_;
|
||||
};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm =
|
||||
GridwiseGemmDlMultipleD_km_kn_mn<BlockSize,
|
||||
ADataType,
|
||||
AccDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
EGridDesc_M_N,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM1Xs,
|
||||
M1N1ThreadClusterN1Xs,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
|
||||
using AGridDesc_K0_M0_M1_K1 =
|
||||
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
|
||||
using BGridDesc_K0_N0_N1_K1 =
|
||||
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
|
||||
using DsGridDesc_M0_M10_M11_N0_N10_N11 =
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{}));
|
||||
using EGridDesc_M0_M10_M11_N0_N10_N11 =
|
||||
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{}));
|
||||
using DefaultBlock2CTileMap =
|
||||
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{}));
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_a_grid,
|
||||
const void* p_b_grid,
|
||||
std::array<const void*, NumDTensor> p_ds_grid,
|
||||
void* p_e_grid,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideE,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
|
||||
index_t BatchStrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
|
||||
K_(K),
|
||||
Batch_(Batch),
|
||||
a_grid_desc_k0_m0_m1_k1_{},
|
||||
b_grid_desc_k0_n0_n1_k1_{},
|
||||
e_grid_desc_m0_m10_m11_n0_n10_n11_{},
|
||||
compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE},
|
||||
block_2_ctile_map_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
a_grid_desc_k0_m_k1_ =
|
||||
DeviceBatchedGemmMultipleD_Dl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
|
||||
b_grid_desc_k0_n_k1_ =
|
||||
DeviceBatchedGemmMultipleD_Dl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
// D pointer
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
|
||||
|
||||
// D desc
|
||||
ds_grid_desc_m_n_(i) =
|
||||
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[i]);
|
||||
});
|
||||
e_grid_desc_m_n_ =
|
||||
DeviceBatchedGemmMultipleD_Dl::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, e_grid_desc_m_n_))
|
||||
{
|
||||
a_grid_desc_k0_m0_m1_k1_ =
|
||||
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
|
||||
b_grid_desc_k0_n0_n1_k1_ =
|
||||
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
|
||||
|
||||
ds_grid_desc_m0_m10_m11_n0_n10_n11_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n_);
|
||||
|
||||
e_grid_desc_m0_m10_m11_n0_n10_n11_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n_);
|
||||
|
||||
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
index_t K_;
|
||||
|
||||
// Batch
|
||||
index_t Batch_;
|
||||
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
|
||||
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
|
||||
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
|
||||
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_;
|
||||
EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_;
|
||||
|
||||
// for calculating batch offset
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
|
||||
|
||||
DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
|
||||
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceBatchedGemmMultipleD_Dl::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_.GetLength(I0),
|
||||
arg.e_grid_desc_m_n_.GetLength(I1)) *
|
||||
arg.Batch_;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop,
|
||||
auto has_double_tail_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
|
||||
|
||||
const auto kernel =
|
||||
kernel_gemm_dl_multiple_d<GridwiseGemm,
|
||||
ADataType,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_K0_M0_M1_K1,
|
||||
DeviceOp::BGridDesc_K0_N0_N1_K1,
|
||||
DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11,
|
||||
DeviceOp::EGridDesc_M0_M10_M11_N0_N10_N11,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
DefaultBlock2CTileMap,
|
||||
has_main_loop,
|
||||
has_double_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.Batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_k0_m0_m1_k1_,
|
||||
arg.b_grid_desc_k0_n0_n1_k1_,
|
||||
arg.ds_grid_desc_m0_m10_m11_n0_n10_n11_,
|
||||
arg.e_grid_desc_m0_m10_m11_n0_n10_n11_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
|
||||
const bool has_double_tail_k_block_loop =
|
||||
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
bool pass = true;
|
||||
pass = pass && arg.K_ % K1 == 0;
|
||||
|
||||
pass = pass && GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.e_grid_desc_m_n_);
|
||||
|
||||
return pass;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideE,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
|
||||
index_t BatchStrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideDs,
|
||||
BatchStrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& StrideDs,
|
||||
index_t StrideE,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
|
||||
index_t BatchStrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideDs,
|
||||
BatchStrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemmMultipleD_Dl"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< M1PerThread << ", "
|
||||
<< N1PerThread << ", "
|
||||
<< KPerThread
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,996 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename A0B0B1DataType,
|
||||
typename D0sPointer,
|
||||
typename D1sPointer,
|
||||
typename E1DataType,
|
||||
typename A0ElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename CDE0ElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CDE1ElementwiseOperation,
|
||||
typename A0GridDesc_AK0_M_AK1,
|
||||
typename B0GridDesc_BK0_N_BK1,
|
||||
typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
|
||||
typename B1GridDesc_BK0_N_BK1,
|
||||
typename D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2E1TileMap,
|
||||
typename ComputeBasePtrOfStridedBatch,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_gemm_xdl_cshuffle_v1(
|
||||
const A0B0B1DataType* __restrict__ p_a0_grid,
|
||||
const A0B0B1DataType* __restrict__ p_b0_grid,
|
||||
D0sPointer p_d0s_grid,
|
||||
const A0B0B1DataType* __restrict__ p_b1_grid,
|
||||
D1sPointer p_d1s_grid,
|
||||
E1DataType* __restrict__ p_e1_grid,
|
||||
const A0ElementwiseOperation a0_element_op,
|
||||
const B0ElementwiseOperation b0_element_op,
|
||||
const CDE0ElementwiseOperation cde0_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CDE1ElementwiseOperation cde1_element_op,
|
||||
const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
|
||||
const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
|
||||
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2E1TileMap block_2_e1tile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
|
||||
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
|
||||
|
||||
static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
|
||||
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
|
||||
p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
|
||||
});
|
||||
|
||||
static_for<0, p_d1s_grid.Size(), 1>{}([&](auto In) {
|
||||
const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In)));
|
||||
p_d1s_grid(In) = p_d1s_grid(In) + d1_batch_offset;
|
||||
});
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a0_grid + a_batch_offset,
|
||||
p_b0_grid + b_batch_offset,
|
||||
p_d0s_grid,
|
||||
p_b1_grid + b1_batch_offset,
|
||||
p_d1s_grid,
|
||||
p_e1_grid + c_batch_offset,
|
||||
p_shared,
|
||||
a0_element_op,
|
||||
b0_element_op,
|
||||
cde0_element_op,
|
||||
b1_element_op,
|
||||
cde1_element_op,
|
||||
a0_grid_desc_ak0_m_ak1,
|
||||
b0_grid_desc_bk0_n_bk1,
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
b1_grid_desc_bk0_n_bk1,
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_e1tile_map);
|
||||
#else
|
||||
ignore = p_a0_grid;
|
||||
ignore = p_b0_grid;
|
||||
ignore = p_d0s_grid;
|
||||
ignore = p_b1_grid;
|
||||
ignore = p_d1s_grid;
|
||||
ignore = p_e1_grid;
|
||||
ignore = a0_element_op;
|
||||
ignore = b0_element_op;
|
||||
ignore = cde0_element_op;
|
||||
ignore = b1_element_op;
|
||||
ignore = cde1_element_op;
|
||||
ignore = a0_grid_desc_ak0_m_ak1;
|
||||
ignore = b0_grid_desc_bk0_n_bk1;
|
||||
ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
|
||||
ignore = b1_grid_desc_bk0_n_bk1;
|
||||
ignore = d1s_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = e1_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = block_2_e1tile_map;
|
||||
ignore = batch_count;
|
||||
ignore = compute_base_ptr_of_batch;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Computes C = A * B0 * B1
|
||||
// ^^^^^^ (Acc0)
|
||||
// ^^^^^^^^^^^ (Acc1)
|
||||
template <typename A0Layout,
|
||||
typename B0Layout, // B0Layout
|
||||
typename D0sLayout,
|
||||
typename B1Layout,
|
||||
typename D1sLayout,
|
||||
typename E1Layout,
|
||||
typename A0DataType,
|
||||
typename B0DataType,
|
||||
typename Acc0DataType,
|
||||
typename D0sDataType,
|
||||
typename B1DataType,
|
||||
typename Acc1DataType,
|
||||
typename C1ShuffleDataType,
|
||||
typename D1sDataType,
|
||||
typename E1DataType,
|
||||
typename A0ElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename CDE0ElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CDE1ElementwiseOperation,
|
||||
bool PadGemm0M,
|
||||
bool PadGemm0N,
|
||||
bool PadGemm0K,
|
||||
bool PadGemm1N,
|
||||
bool PadGemm1K,
|
||||
index_t NumGemm0KPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t Gemm0MPerBlock,
|
||||
index_t Gemm0NPerBlock,
|
||||
index_t Gemm0KPerBlock,
|
||||
index_t Gemm1NPerBlock,
|
||||
index_t Gemm1KPerBlock,
|
||||
index_t A0K1,
|
||||
index_t B0K1,
|
||||
index_t B1K1,
|
||||
index_t Gemm0MPerXdl,
|
||||
index_t Gemm0NPerXdl,
|
||||
index_t Gemm0MXdlPerWave,
|
||||
index_t Gemm0NXdlPerWave,
|
||||
index_t Gemm1NXdlPerWave,
|
||||
typename A0BlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename A0BlockTransferThreadClusterArrangeOrder,
|
||||
typename A0BlockTransferSrcAccessOrder,
|
||||
index_t A0BlockTransferSrcVectorDim,
|
||||
index_t A0BlockTransferSrcScalarPerVector,
|
||||
index_t A0BlockTransferDstScalarPerVector_AK1,
|
||||
bool A0BlockLdsExtraM,
|
||||
typename B0BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename B0BlockTransferThreadClusterArrangeOrder,
|
||||
typename B0BlockTransferSrcAccessOrder,
|
||||
index_t B0BlockTransferSrcVectorDim,
|
||||
index_t B0BlockTransferSrcScalarPerVector,
|
||||
index_t B0BlockTransferDstScalarPerVector_BK1,
|
||||
bool B0BlockLdsExtraN,
|
||||
index_t CDE0BlockTransferSrcVectorDim,
|
||||
index_t CDE0BlockTransferSrcScalaerPerVector,
|
||||
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename B1BlockTransferThreadClusterArrangeOrder,
|
||||
typename B1BlockTransferSrcAccessOrder,
|
||||
index_t B1BlockTransferSrcVectorDim,
|
||||
index_t B1BlockTransferSrcScalarPerVector,
|
||||
index_t B1BlockTransferDstScalarPerVector_BK1,
|
||||
bool B1BlockLdsExtraN,
|
||||
index_t C1ShuffleMXdlPerWavePerShuffle,
|
||||
index_t C1ShuffleGemm0NXdlPerWavePerShuffle,
|
||||
typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = LoopScheduler::Default>
|
||||
struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
: public DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
|
||||
B0Layout,
|
||||
D0sLayout,
|
||||
B1Layout,
|
||||
D1sLayout,
|
||||
E1Layout,
|
||||
A0DataType,
|
||||
B0DataType,
|
||||
D0sDataType,
|
||||
B1DataType,
|
||||
D1sDataType,
|
||||
E1DataType,
|
||||
A0ElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
CDE0ElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CDE1ElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle;
|
||||
|
||||
static constexpr index_t NumD0Tensor = D0sDataType::Size();
|
||||
static constexpr index_t NumD1Tensor = D1sDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
static constexpr auto I8 = Number<8>{};
|
||||
static constexpr auto I9 = Number<9>{};
|
||||
|
||||
static constexpr auto gemm0_padder =
|
||||
GemmPadder_v2<PadGemm0M, PadGemm0N, PadGemm0K, index_t, index_t, index_t>{
|
||||
Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock};
|
||||
|
||||
static constexpr auto gemm1_padder =
|
||||
GemmPadder_v2<PadGemm0M, PadGemm1N, PadGemm1K, index_t, index_t, index_t>{
|
||||
Gemm0MPerBlock, Gemm1NPerBlock, Gemm1KPerBlock};
|
||||
|
||||
// for Gemm0
|
||||
static auto MakeA0GridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA0)
|
||||
{
|
||||
const auto a0_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, A0Layout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA0, I1));
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, A0Layout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA0));
|
||||
}
|
||||
}();
|
||||
|
||||
return gemm0_padder.PadADescriptor_M_K(a0_grid_desc_mraw_kraw);
|
||||
}
|
||||
|
||||
// for Gemm0
|
||||
static auto MakeB0GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b0_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, B0Layout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B0Layout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
return gemm0_padder.PadBDescriptor_N_K(b0_grid_desc_nraw_kraw);
|
||||
}
|
||||
|
||||
// for Gemm0
|
||||
template <typename DLay>
|
||||
static auto MakeD0GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideD0)
|
||||
{
|
||||
const auto d0_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, DLay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(StrideD0, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DLay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(I1, StrideD0));
|
||||
}
|
||||
}();
|
||||
|
||||
return gemm0_padder.PadCDescriptor_M_N(d0_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
// for Gemm1
|
||||
static auto MakeB1GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b1_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
return gemm1_padder.PadBDescriptor_N_K(b1_grid_desc_nraw_kraw);
|
||||
}
|
||||
|
||||
// for Gemm1
|
||||
template <typename ELay>
|
||||
static auto MakeE1GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE1)
|
||||
{
|
||||
const auto e1_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(StrideE1, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(I1, StrideE1));
|
||||
}
|
||||
}();
|
||||
|
||||
return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
static auto MakeD0sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws,
|
||||
const std::array<index_t, NumD1Tensor>& NRaws,
|
||||
const std::array<index_t, NumD1Tensor>& DsStride)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
|
||||
|
||||
return DeviceOp::MakeD0GridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
}
|
||||
|
||||
static auto MakeD1sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws,
|
||||
const std::array<index_t, NumD1Tensor>& NRaws,
|
||||
const std::array<index_t, NumD1Tensor>& DsStride)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, D1sLayout>>;
|
||||
|
||||
return DeviceOp::MakeE1GridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
|
||||
},
|
||||
Number<NumD1Tensor>{});
|
||||
}
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1)
|
||||
: BatchStrideA0_(BatchStrideA0),
|
||||
BatchStrideB0_(BatchStrideB0),
|
||||
BatchStrideD0s_(BatchStrideD0s),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideD1s_(BatchStrideD1s),
|
||||
BatchStrideE1_(BatchStrideE1)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
|
||||
Number<I> d1_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d1_idx]);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE1_);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number<I> d1_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA0_;
|
||||
index_t BatchStrideB0_;
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s_;
|
||||
index_t BatchStrideB1_;
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s_;
|
||||
index_t BatchStrideE1_;
|
||||
};
|
||||
|
||||
using A0GridDesc_M_K = decltype(MakeA0GridDescriptor_M_K(1, 1, 1));
|
||||
using B0GridDesc_N_K = decltype(MakeB0GridDescriptor_N_K(1, 1, 1));
|
||||
using D0sGridDesc_M_N = remove_cvref_t<decltype(MakeD0sGridDescriptor_M_N({}, {}, {}))>;
|
||||
using B1GridDesc_N_K = decltype(MakeB1GridDescriptor_N_K(1, 1, 1));
|
||||
using D1sGridDesc_M_N = remove_cvref_t<decltype(MakeD1sGridDescriptor_M_N({}, {}, {}))>;
|
||||
using E1GridDesc_M_N = decltype(MakeE1GridDescriptor_M_N<E1Layout>(1, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<
|
||||
A0DataType, // TODO: distinguish A/B datatype
|
||||
Acc0DataType,
|
||||
D0sDataType,
|
||||
Acc1DataType,
|
||||
C1ShuffleDataType,
|
||||
D1sDataType,
|
||||
E1DataType,
|
||||
A0ElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
CDE0ElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CDE1ElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
A0GridDesc_M_K,
|
||||
B0GridDesc_N_K,
|
||||
D0sGridDesc_M_N,
|
||||
B1GridDesc_N_K,
|
||||
D1sGridDesc_M_N,
|
||||
E1GridDesc_M_N,
|
||||
NumGemm0KPrefetchStage,
|
||||
BlockSize,
|
||||
Gemm0MPerBlock,
|
||||
Gemm0NPerBlock,
|
||||
Gemm0KPerBlock,
|
||||
Gemm1NPerBlock,
|
||||
Gemm1KPerBlock,
|
||||
A0K1,
|
||||
B0K1,
|
||||
B1K1,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NPerXdl,
|
||||
Gemm0MXdlPerWave,
|
||||
Gemm0NXdlPerWave,
|
||||
Gemm1NXdlPerWave,
|
||||
A0BlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
A0BlockTransferThreadClusterArrangeOrder,
|
||||
A0BlockTransferSrcAccessOrder,
|
||||
A0BlockTransferSrcVectorDim,
|
||||
A0BlockTransferSrcScalarPerVector,
|
||||
A0BlockTransferDstScalarPerVector_AK1,
|
||||
true,
|
||||
A0BlockLdsExtraM,
|
||||
B0BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
B0BlockTransferThreadClusterArrangeOrder,
|
||||
B0BlockTransferSrcAccessOrder,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferDstScalarPerVector_BK1,
|
||||
true,
|
||||
B0BlockLdsExtraN,
|
||||
CDE0BlockTransferSrcVectorDim,
|
||||
CDE0BlockTransferSrcScalaerPerVector,
|
||||
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
B1BlockTransferThreadClusterArrangeOrder,
|
||||
B1BlockTransferSrcAccessOrder,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
B1BlockLdsExtraN,
|
||||
C1ShuffleMXdlPerWavePerShuffle,
|
||||
C1ShuffleGemm0NXdlPerWavePerShuffle,
|
||||
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using A0GridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(
|
||||
A0GridDesc_M_K{}))>;
|
||||
using B0GridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(
|
||||
B0GridDesc_N_K{}))>;
|
||||
using B1GridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(
|
||||
B1GridDesc_N_K{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const A0DataType* p_a0_grid,
|
||||
const B0DataType* p_b0_grid,
|
||||
std::array<const void*, NumD0Tensor> p_d0s_grid,
|
||||
const B1DataType* p_b1_grid,
|
||||
std::array<const void*, NumD1Tensor> p_d1s_grid,
|
||||
E1DataType* p_e1_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t Gemm1NRaw, // = ORaw
|
||||
index_t Batch,
|
||||
index_t StrideA0,
|
||||
index_t StrideB0,
|
||||
std::array<index_t, NumD0Tensor> StrideD0s,
|
||||
index_t StrideB1,
|
||||
std::array<index_t, NumD1Tensor> StrideD1s,
|
||||
index_t StrideE1,
|
||||
index_t BatchStrideA0,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1,
|
||||
A0ElementwiseOperation a0_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
CDE0ElementwiseOperation cde0_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CDE1ElementwiseOperation cde1_element_op)
|
||||
: p_a0_grid_{p_a0_grid},
|
||||
p_b0_grid_{p_b0_grid},
|
||||
p_d0s_grid_{},
|
||||
p_b1_grid_{p_b1_grid},
|
||||
p_d1s_grid_{},
|
||||
p_e1_grid_{p_e1_grid},
|
||||
a0_grid_desc_m_k_{DeviceOp::MakeA0GridDescriptor_M_K(MRaw, KRaw, StrideA0)},
|
||||
b0_grid_desc_n_k_{DeviceOp::MakeB0GridDescriptor_N_K(KRaw, NRaw, StrideB0)},
|
||||
d0s_grid_desc_m_n_{},
|
||||
b1_grid_desc_n_k_{DeviceOp::MakeB1GridDescriptor_N_K(NRaw, Gemm1NRaw, StrideB1)},
|
||||
d1s_grid_desc_m_n_{},
|
||||
e1_grid_desc_m_n_{
|
||||
DeviceOp::MakeE1GridDescriptor_M_N<E1Layout>(MRaw, Gemm1NRaw, StrideE1)},
|
||||
a0_grid_desc_ak0_m_ak1_{
|
||||
GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k_)},
|
||||
b0_grid_desc_bk0_n_bk1_{
|
||||
GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k_)},
|
||||
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{},
|
||||
b1_grid_desc_bk0_n_bk1_{
|
||||
GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k_)},
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_e1tile_map_{GridwiseGemm::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n_)},
|
||||
a0_element_op_{a0_element_op},
|
||||
b0_element_op_{b0_element_op},
|
||||
cde0_element_op_{cde0_element_op},
|
||||
b1_element_op_{b1_element_op},
|
||||
cde1_element_op_{cde1_element_op},
|
||||
batch_count_(Batch),
|
||||
compute_base_ptr_of_batch_{BatchStrideA0,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1}
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
|
||||
<< a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", "
|
||||
<< b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0)
|
||||
<< ", " << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", "
|
||||
<< b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{"
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}"
|
||||
<< std::endl;
|
||||
std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
using D0Layout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
|
||||
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
|
||||
|
||||
// D0 pointer
|
||||
p_d0s_grid_(i) = static_cast<const D0DataType*>(p_d0s_grid[i]);
|
||||
|
||||
// D0 desc
|
||||
d0s_grid_desc_m_n_(i) =
|
||||
DeviceOp::MakeD0GridDescriptor_M_N<D0Layout>(MRaw, NRaw, StrideD0s[i]);
|
||||
});
|
||||
|
||||
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
|
||||
using D1Layout = remove_cvref_t<tuple_element_t<i.value, D1sLayout>>;
|
||||
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
|
||||
|
||||
// D1 pointer
|
||||
p_d1s_grid_(i) = static_cast<const D1DataType*>(p_d1s_grid[i]);
|
||||
|
||||
// D1 desc
|
||||
d1s_grid_desc_m_n_(i) =
|
||||
DeviceOp::MakeE1GridDescriptor_M_N<D1Layout>(MRaw, Gemm1NRaw, StrideD1s[i]);
|
||||
});
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a0_grid_desc_m_k_,
|
||||
b0_grid_desc_n_k_,
|
||||
b1_grid_desc_n_k_,
|
||||
e1_grid_desc_m_n_,
|
||||
block_2_e1tile_map_))
|
||||
{
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e1_grid_desc_m_n_);
|
||||
|
||||
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
|
||||
GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
|
||||
d0s_grid_desc_m_n_);
|
||||
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
d1s_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
// pointers
|
||||
const A0DataType* p_a0_grid_;
|
||||
const B0DataType* p_b0_grid_;
|
||||
typename GridwiseGemm::D0sGridPointer p_d0s_grid_;
|
||||
const B1DataType* p_b1_grid_;
|
||||
typename GridwiseGemm::D1sGridPointer p_d1s_grid_;
|
||||
E1DataType* p_e1_grid_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
A0GridDesc_M_K a0_grid_desc_m_k_;
|
||||
B0GridDesc_N_K b0_grid_desc_n_k_;
|
||||
D0sGridDesc_M_N d0s_grid_desc_m_n_;
|
||||
B1GridDesc_N_K b1_grid_desc_n_k_;
|
||||
D1sGridDesc_M_N d1s_grid_desc_m_n_;
|
||||
E1GridDesc_M_N e1_grid_desc_m_n_;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1_;
|
||||
B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1_;
|
||||
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
|
||||
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
|
||||
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
|
||||
typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e1-tile map
|
||||
typename GridwiseGemm::DefaultBlock2E1TileMap block_2_e1tile_map_;
|
||||
|
||||
// element-wise op
|
||||
A0ElementwiseOperation a0_element_op_;
|
||||
B0ElementwiseOperation b0_element_op_;
|
||||
CDE0ElementwiseOperation cde0_element_op_;
|
||||
B1ElementwiseOperation b1_element_op_;
|
||||
CDE1ElementwiseOperation cde1_element_op_;
|
||||
|
||||
// batch
|
||||
index_t batch_count_;
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
|
||||
arg.b0_grid_desc_n_k_,
|
||||
arg.b1_grid_desc_n_k_,
|
||||
arg.e1_grid_desc_m_n_,
|
||||
arg.block_2_e1tile_map_))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_e1tile_map_.CalculateGridSize(arg.e1_grid_desc_m_n_) * arg.batch_count_;
|
||||
|
||||
// Gemm0_K
|
||||
const auto K = arg.a0_grid_desc_m_k_.GetLength(I1);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop_) {
|
||||
const auto kernel = kernel_batched_gemm_gemm_xdl_cshuffle_v1<
|
||||
GridwiseGemm,
|
||||
A0DataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::D0sGridPointer,
|
||||
typename GridwiseGemm::D1sGridPointer,
|
||||
E1DataType,
|
||||
A0ElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
CDE0ElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CDE1ElementwiseOperation,
|
||||
DeviceOp::A0GridDesc_AK0_M_AK1,
|
||||
DeviceOp::B0GridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
|
||||
DeviceOp::B1GridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2E1TileMap,
|
||||
ComputeBasePtrOfStridedBatch,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a0_grid_,
|
||||
arg.p_b0_grid_,
|
||||
arg.p_d0s_grid_,
|
||||
arg.p_b1_grid_,
|
||||
arg.p_d1s_grid_,
|
||||
arg.p_e1_grid_,
|
||||
arg.a0_element_op_,
|
||||
arg.b0_element_op_,
|
||||
arg.cde0_element_op_,
|
||||
arg.b1_element_op_,
|
||||
arg.cde1_element_op_,
|
||||
arg.a0_grid_desc_ak0_m_ak1_,
|
||||
arg.b0_grid_desc_bk0_n_bk1_,
|
||||
arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
|
||||
arg.b1_grid_desc_bk0_n_bk1_,
|
||||
arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e1_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_e1tile_map_,
|
||||
arg.batch_count_,
|
||||
arg.compute_base_ptr_of_batch_);
|
||||
};
|
||||
|
||||
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
|
||||
// to concern Gemm0's loop
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
// check if DsLayout is supported
|
||||
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
|
||||
static bool CheckDLayout()
|
||||
{
|
||||
static bool valid = true;
|
||||
// iterate over DLayout tuple
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
// if RefLayout and DLayout are same, keep valid true, otherwise false
|
||||
valid = valid && is_same_v<RefLayout, DLayout>;
|
||||
});
|
||||
return valid;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check supported layouts
|
||||
// A0 - Row
|
||||
// B0 - Col
|
||||
// D0s - Rows
|
||||
// B1 - Row or Col
|
||||
// D1s - Rows
|
||||
// E1 - Row
|
||||
if(!(is_same_v<tensor_layout::gemm::RowMajor, A0Layout> &&
|
||||
is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> &&
|
||||
CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() &&
|
||||
(is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ||
|
||||
is_same_v<tensor_layout::gemm::ColumnMajor,
|
||||
B1Layout>)&&CheckDLayout<tensor_layout::gemm::RowMajor,
|
||||
D1sLayout,
|
||||
NumD1Tensor>() &&
|
||||
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
|
||||
arg.b0_grid_desc_n_k_,
|
||||
arg.b1_grid_desc_n_k_,
|
||||
arg.e1_grid_desc_m_n_,
|
||||
arg.block_2_e1tile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const A0DataType* p_a0,
|
||||
const B0DataType* p_b0,
|
||||
std::array<const void*, NumD0Tensor> p_d0s,
|
||||
const B1DataType* p_b1,
|
||||
std::array<const void*, NumD1Tensor> p_d1s,
|
||||
E1DataType* p_e1,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t Gemm1NRaw,
|
||||
index_t Batch,
|
||||
index_t StrideA0,
|
||||
index_t StrideB0,
|
||||
std::array<index_t, NumD0Tensor> StrideD0s,
|
||||
index_t StrideB1,
|
||||
std::array<index_t, NumD1Tensor> StrideD1s,
|
||||
index_t StrideE1,
|
||||
index_t BatchStrideA0,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1,
|
||||
A0ElementwiseOperation a0_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
CDE0ElementwiseOperation cde0_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CDE1ElementwiseOperation cde1_element_op)
|
||||
{
|
||||
return Argument{p_a0, p_b0,
|
||||
p_d0s, p_b1,
|
||||
p_d1s, p_e1,
|
||||
MRaw, NRaw,
|
||||
KRaw, Gemm1NRaw,
|
||||
Batch, StrideA0,
|
||||
StrideB0, StrideD0s,
|
||||
StrideB1, StrideD1s,
|
||||
StrideE1, BatchStrideA0,
|
||||
BatchStrideB0, BatchStrideD0s,
|
||||
BatchStrideB1, BatchStrideD1s,
|
||||
BatchStrideE1, a0_element_op,
|
||||
b0_element_op, cde0_element_op,
|
||||
b1_element_op, cde1_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a0,
|
||||
const void* p_b0,
|
||||
std::array<const void*, NumD0Tensor> p_d0s,
|
||||
const void* p_b1,
|
||||
std::array<const void*, NumD1Tensor> p_d1s,
|
||||
void* p_e1,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t Gemm1NRaw,
|
||||
index_t Batch,
|
||||
index_t StrideA0,
|
||||
index_t StrideB0,
|
||||
std::array<ck::index_t, NumD0Tensor> StrideD0s,
|
||||
index_t StrideB1,
|
||||
std::array<ck::index_t, NumD1Tensor> StrideD1s,
|
||||
index_t StrideE1,
|
||||
index_t BatchStrideA0,
|
||||
index_t BatchStrideB0,
|
||||
std::array<ck::index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<ck::index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1,
|
||||
A0ElementwiseOperation a0_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
CDE0ElementwiseOperation cde0_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CDE1ElementwiseOperation cde1_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const A0DataType*>(p_a0),
|
||||
static_cast<const B0DataType*>(p_b0),
|
||||
p_d0s,
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
p_d1s,
|
||||
static_cast<E1DataType*>(p_e1),
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
Gemm1NRaw,
|
||||
Batch,
|
||||
StrideA0,
|
||||
StrideB0,
|
||||
StrideD0s,
|
||||
StrideB1,
|
||||
StrideD1s,
|
||||
StrideE1,
|
||||
BatchStrideA0,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1,
|
||||
a0_element_op,
|
||||
b0_element_op,
|
||||
cde0_element_op,
|
||||
b1_element_op,
|
||||
cde1_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< Gemm0MPerBlock << ", "
|
||||
<< Gemm0NPerBlock << ", "
|
||||
<< Gemm0KPerBlock << ", "
|
||||
<< A0K1 << ", "
|
||||
<< B0K1 << ", "
|
||||
<< B1K1 << ", "
|
||||
<< Gemm0MPerXdl << ", "
|
||||
<< Gemm0NPerXdl << ", "
|
||||
<< Gemm0MXdlPerWave << ", "
|
||||
<< Gemm0NXdlPerWave << ", "
|
||||
<< Gemm1NXdlPerWave << "> ";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,955 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename D0sPointer,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename C0DEElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename C1DEElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename B1GridDesc_BK0_N_BK1,
|
||||
typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
|
||||
typename Block2CTileMap,
|
||||
typename ComputeBasePtrOfStridedBatch,
|
||||
typename C0MatrixMask,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatAB* __restrict__ p_b1_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
D0sPointer p_d0s_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const C0DEElementwiseOperation c0de_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const C1DEElementwiseOperation c1de_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
|
||||
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
|
||||
|
||||
static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
|
||||
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
|
||||
p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
|
||||
});
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_b1_grid + b1_batch_offset,
|
||||
p_c_grid + c_batch_offset,
|
||||
p_d0s_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c0de_element_op,
|
||||
b1_element_op,
|
||||
c1de_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b1_grid_desc_bk0_n_bk1,
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
block_2_ctile_map,
|
||||
c0_matrix_mask);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_b1_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = p_d0s_grid;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c0de_element_op;
|
||||
ignore = b1_element_op;
|
||||
ignore = c1de_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = b1_grid_desc_bk0_n_bk1;
|
||||
ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
|
||||
ignore = block_2_ctile_map;
|
||||
ignore = batch_count;
|
||||
ignore = compute_base_ptr_of_batch;
|
||||
ignore = c0_matrix_mask;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
// Computes C = A * B0 * B1
|
||||
// ^^^^^^ (Acc0)
|
||||
// ^^^^^^^^^^^ (Acc1)
|
||||
template <index_t NumDimG,
|
||||
index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
index_t NumDimO, // NumDimGemm1N
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename B1DataType,
|
||||
typename CDataType,
|
||||
typename D0sDataType,
|
||||
typename D1sDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename C0DEElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename C1DEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
TensorSpecialization ASpec,
|
||||
TensorSpecialization BSpec,
|
||||
TensorSpecialization B1Spec,
|
||||
TensorSpecialization CSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock, // Gemm0NPerBlock
|
||||
index_t KPerBlock, // Gemm0KPerBlock
|
||||
index_t Gemm1NPerBlock,
|
||||
index_t Gemm1KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t B1K1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
index_t Gemm1NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename B1BlockTransferThreadClusterArrangeOrder,
|
||||
typename B1BlockTransferSrcAccessOrder,
|
||||
index_t B1BlockTransferSrcVectorDim,
|
||||
index_t B1BlockTransferSrcScalarPerVector,
|
||||
index_t B1BlockTransferDstScalarPerVector_BK1,
|
||||
bool B1BlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
MaskingSpecialization MaskingSpec,
|
||||
int D0sTransferSrcScalarPerVector = 4,
|
||||
LoopScheduler LoopSched = LoopScheduler::Default>
|
||||
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
: public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
NumDimO,
|
||||
ADataType,
|
||||
BDataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
D0sDataType,
|
||||
D1sDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
C0DEElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
C1DEElementwiseOperation,
|
||||
MaskingSpec>
|
||||
{
|
||||
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
|
||||
"Number of dimension must be greater than 0");
|
||||
|
||||
static constexpr index_t NumD0Tensor = D0sDataType::Size();
|
||||
static constexpr index_t NumD1Tensor = D1sDataType::Size();
|
||||
|
||||
// TODO ANT: implement bias combination
|
||||
static_assert(NumD1Tensor == 0, "Gemm1 Bias addition is unimplemented");
|
||||
|
||||
#if 0
|
||||
// TODO ANT: use alias
|
||||
static constexpr index_t NumDimGemm0M = NumDimM;
|
||||
static constexpr index_t NumDimGemm0N = NumDimN;
|
||||
static constexpr index_t NumDimGemm0K = NumDimK;
|
||||
static constexpr index_t NumDimGemm1M = NumDimM;
|
||||
static constexpr index_t NumDimGemm1N = NumDimO;
|
||||
static constexpr index_t NumDimGemm1K = NumDimN;
|
||||
#endif
|
||||
|
||||
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
|
||||
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
|
||||
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
|
||||
GemmSpec,
|
||||
ASpec,
|
||||
BSpec,
|
||||
B1Spec,
|
||||
CSpec>;
|
||||
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
|
||||
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
|
||||
{
|
||||
return Transform::MakeAGridDescriptor_AK0_M_AK1(
|
||||
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
|
||||
Number<AK1>{});
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
|
||||
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
|
||||
{
|
||||
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
|
||||
Number<BK1>{});
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
|
||||
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
|
||||
{
|
||||
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
|
||||
b1_gs_gemm1ns_gemm1ks_strides_vec),
|
||||
Number<B1K1>{});
|
||||
}
|
||||
|
||||
static auto MakeD0sGridDescriptor_M_N(
|
||||
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
|
||||
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i],
|
||||
acc0_biases_gs_ms_ns_strides[i]);
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
}
|
||||
|
||||
static auto MakeD0sGridDescriptor_G_M_N(
|
||||
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
|
||||
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i],
|
||||
acc0_biases_gs_ms_ns_strides[i]);
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
}
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
|
||||
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
|
||||
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
|
||||
using C1GridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
|
||||
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
|
||||
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
|
||||
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
|
||||
using C1GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
|
||||
using D0sGridDesc_M_N = decltype(MakeD0sGridDescriptor_M_N({}, {}));
|
||||
using D0sGridDesc_G_M_N = decltype(MakeD0sGridDescriptor_G_M_N({}, {}));
|
||||
|
||||
constexpr static auto make_MaskOutPredicate()
|
||||
{
|
||||
if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
|
||||
{
|
||||
return MaskDisabledPredicate{};
|
||||
}
|
||||
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
|
||||
{
|
||||
return MaskOutUpperTrianglePredicate{};
|
||||
}
|
||||
}
|
||||
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
|
||||
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
|
||||
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
|
||||
const C1GridDesc_G_M_N& c1_grid_desc_g_m_n,
|
||||
const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n)
|
||||
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
|
||||
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
|
||||
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
|
||||
c1_grid_desc_g_m_n_(c1_grid_desc_g_m_n),
|
||||
d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
|
||||
{
|
||||
return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
|
||||
{
|
||||
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
|
||||
{
|
||||
return c1_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
|
||||
Number<I> d0_idx) const
|
||||
{
|
||||
return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0));
|
||||
}
|
||||
|
||||
private:
|
||||
AGridDesc_G_M_K a_grid_desc_g_m_k_;
|
||||
BGridDesc_G_N_K b_grid_desc_g_n_k_;
|
||||
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
|
||||
C1GridDesc_G_M_N c1_grid_desc_g_m_n_;
|
||||
D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
|
||||
};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
D0sDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
C0DEElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
C1DEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
B1GridDesc_BK0_N_BK1,
|
||||
C1GridDesc_M_N,
|
||||
D0sGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
Gemm1NPerBlock,
|
||||
Gemm1KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
B1K1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
Gemm1NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
true,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
true,
|
||||
BBlockLdsExtraN,
|
||||
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
B1BlockTransferThreadClusterArrangeOrder,
|
||||
B1BlockTransferSrcAccessOrder,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
B1BlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
Transform::matrix_padder.PadN,
|
||||
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
|
||||
D0sTransferSrcScalarPerVector>;
|
||||
|
||||
// Argument
|
||||
// FIXME: constness
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(
|
||||
const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
const B1DataType* p_b1_grid,
|
||||
CDataType* p_c_grid,
|
||||
const std::array<void*, NumD0Tensor> p_acc0_biases,
|
||||
const std::array<void*, NumD1Tensor> p_acc1_biases,
|
||||
const std::vector<index_t>& a_gs_ms_ks_lengths,
|
||||
const std::vector<index_t>& a_gs_ms_ks_strides,
|
||||
const std::vector<index_t>& b_gs_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_gs_ns_ks_strides,
|
||||
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
|
||||
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
|
||||
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
|
||||
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
|
||||
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
|
||||
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides,
|
||||
const std::array<std::vector<ck::index_t>, NumD1Tensor>&
|
||||
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
|
||||
const std::array<std::vector<ck::index_t>, NumD1Tensor>&
|
||||
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
C0DEElementwiseOperation c0de_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
C1DEElementwiseOperation c1de_element_op)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_b1_grid_{p_b1_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
p_d0s_grid_{},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
|
||||
b_grid_desc_bk0_n_bk1_{
|
||||
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
|
||||
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
|
||||
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
|
||||
c1_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
|
||||
c_gs_ms_gemm1ns_strides)},
|
||||
a_grid_desc_g_m_k_{
|
||||
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
|
||||
b_grid_desc_g_n_k_{
|
||||
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
|
||||
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
|
||||
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
|
||||
c1_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
|
||||
c_gs_ms_gemm1ns_strides)},
|
||||
d0s_grid_desc_g_m_n_{DeviceOp::MakeD0sGridDescriptor_G_M_N(
|
||||
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)},
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{},
|
||||
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c1_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c0de_element_op_{c0de_element_op},
|
||||
b1_element_op_{b1_element_op},
|
||||
c1de_element_op_{c1de_element_op},
|
||||
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
|
||||
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
|
||||
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
|
||||
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
|
||||
b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
|
||||
a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
|
||||
a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
|
||||
b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
|
||||
b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
|
||||
b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
|
||||
b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
|
||||
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
|
||||
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
|
||||
batch_count_{c1_grid_desc_g_m_n_.GetLength(I0)},
|
||||
compute_base_ptr_of_batch_{a_grid_desc_g_m_k_,
|
||||
b_grid_desc_g_n_k_,
|
||||
b1_grid_desc_g_n_k_,
|
||||
c1_grid_desc_g_m_n_,
|
||||
d0s_grid_desc_g_m_n_}
|
||||
{
|
||||
// TODO ANT: implement bias addition
|
||||
ignore = p_acc1_biases;
|
||||
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
|
||||
ignore = acc1_biases_gs_ms_gemm1ns_strides;
|
||||
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
|
||||
// D0 pointer
|
||||
p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
|
||||
// for check
|
||||
d0s_nl_ns_lengths_strides_[i].push_back(
|
||||
acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]);
|
||||
d0s_nl_ns_lengths_strides_[i].push_back(
|
||||
acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]);
|
||||
});
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
b1_grid_desc_bk0_n_bk1_,
|
||||
c1_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c1_grid_desc_m_n_);
|
||||
|
||||
D0sGridDesc_M_N d0s_grid_desc_m_n{DeviceOp::MakeD0sGridDescriptor_M_N(
|
||||
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)};
|
||||
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
|
||||
GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
|
||||
d0s_grid_desc_m_n);
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
|
||||
<< a_grid_desc_g_m_k_.GetLength(I1) << ", "
|
||||
<< a_grid_desc_g_m_k_.GetLength(I2) << '\n';
|
||||
std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
|
||||
<< b_grid_desc_g_n_k_.GetLength(I1) << ", "
|
||||
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n';
|
||||
std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
|
||||
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", "
|
||||
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
|
||||
std::cout << "c1_grid_desc_g_m_n_: " << c1_grid_desc_g_m_n_.GetLength(I0) << ", "
|
||||
<< c1_grid_desc_g_m_n_.GetLength(I1) << ", "
|
||||
<< c1_grid_desc_g_m_n_.GetLength(I2) << '\n';
|
||||
}
|
||||
|
||||
// pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
const B1DataType* p_b1_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
typename GridwiseGemm::D0sGridPointer p_d0s_grid_;
|
||||
|
||||
// tensor descriptor
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
|
||||
C1GridDesc_M_N c1_grid_desc_m_n_;
|
||||
AGridDesc_G_M_K a_grid_desc_g_m_k_;
|
||||
BGridDesc_G_N_K b_grid_desc_g_n_k_;
|
||||
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
|
||||
C1GridDesc_G_M_N c1_grid_desc_g_m_n_;
|
||||
D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
|
||||
|
||||
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
|
||||
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
|
||||
|
||||
// block-to-c-tile map
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
C0DEElementwiseOperation c0de_element_op_;
|
||||
B1ElementwiseOperation b1_element_op_;
|
||||
C1DEElementwiseOperation c1de_element_op_;
|
||||
|
||||
// check C0 masking and padding
|
||||
C0MatrixMask c0_matrix_mask_;
|
||||
|
||||
// For robust IsSupportedArgument() check
|
||||
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
|
||||
std::vector<index_t> a_mz_kz_strides_;
|
||||
std::vector<index_t> b_nz_kz_strides_;
|
||||
std::vector<index_t> b1_nz_kz_strides_;
|
||||
std::vector<index_t> c_mz_gemm1nz_strides_;
|
||||
std::array<std::vector<ck::index_t>, NumD0Tensor> d0s_nl_ns_lengths_strides_;
|
||||
|
||||
index_t batch_count_;
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!DeviceOp::IsSupportedArgument(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! unsupported argument");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c1_grid_desc_m_n_) * arg.batch_count_;
|
||||
|
||||
// Gemm0_K
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop_) {
|
||||
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
typename GridwiseGemm::D0sGridPointer,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
C0DEElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
C1DEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::B1GridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
ComputeBasePtrOfStridedBatch,
|
||||
C0MatrixMask,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_b1_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0s_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c0de_element_op_,
|
||||
arg.b1_element_op_,
|
||||
arg.c1de_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.b1_grid_desc_bk0_n_bk1_,
|
||||
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.batch_count_,
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.c0_matrix_mask_);
|
||||
};
|
||||
|
||||
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
|
||||
// to concern Gemm0's loop
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
ave_time = launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO ANT: Check if tensor specialization & strides mismatch
|
||||
|
||||
// Check if C permute dimension matches GEMM + GEMM shape
|
||||
const index_t c_g = arg.c1_grid_desc_g_m_n_.GetLength(I0); // unpadded
|
||||
const index_t c_m = arg.c1_grid_desc_m_n_.GetLength(I0);
|
||||
const index_t c_gemm1n = arg.c1_grid_desc_m_n_.GetLength(I1);
|
||||
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
|
||||
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
|
||||
|
||||
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
|
||||
// vector is out of bounds
|
||||
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
|
||||
const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
|
||||
const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
|
||||
const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
|
||||
const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
|
||||
const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
|
||||
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
|
||||
const auto c_extent_lowest = Gemm1NzRaw;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check vector load/store requirement
|
||||
const auto a_stride_lowest =
|
||||
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
|
||||
const auto b_stride_lowest =
|
||||
BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
|
||||
const auto b1_stride_lowest =
|
||||
B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
|
||||
const auto c_stride_lowest =
|
||||
arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous
|
||||
|
||||
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
|
||||
c_stride_lowest == 1))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
for(int i = 0; i < NumD0Tensor; i++)
|
||||
{
|
||||
if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 &&
|
||||
arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.b1_grid_desc_bk0_n_bk1_,
|
||||
arg.c1_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(
|
||||
const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
const B1DataType* p_b1,
|
||||
CDataType* p_c,
|
||||
const std::array<void*, NumD0Tensor> p_acc0_biases,
|
||||
const std::array<void*, NumD1Tensor> p_acc1_biases,
|
||||
const std::vector<index_t>& a_gs_ms_ks_lengths,
|
||||
const std::vector<index_t>& a_gs_ms_ks_strides,
|
||||
const std::vector<index_t>& b_gs_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_gs_ns_ks_strides,
|
||||
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
|
||||
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
|
||||
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
|
||||
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
|
||||
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
|
||||
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
|
||||
const std::array<std::vector<ck::index_t>, NumD1Tensor>
|
||||
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
|
||||
const std::array<std::vector<ck::index_t>, NumD1Tensor>
|
||||
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
C0DEElementwiseOperation c0de_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
C1DEElementwiseOperation c1de_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_b1,
|
||||
p_c,
|
||||
p_acc0_biases,
|
||||
p_acc1_biases,
|
||||
a_gs_ms_ks_lengths,
|
||||
a_gs_ms_ks_strides,
|
||||
b_gs_ns_ks_lengths,
|
||||
b_gs_ns_ks_strides,
|
||||
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
|
||||
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
|
||||
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
|
||||
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
|
||||
acc0_biases_gs_ms_ns_lengths,
|
||||
acc0_biases_gs_ms_ns_strides,
|
||||
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
|
||||
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c0de_element_op,
|
||||
b1_element_op,
|
||||
c1de_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
// FIXME: constness
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_b1,
|
||||
void* p_c,
|
||||
const std::array<void*, NumD0Tensor> p_acc0_biases,
|
||||
const std::array<void*, NumD1Tensor> p_acc1_biases,
|
||||
const std::vector<index_t>& a_gs_ms_ks_lengths,
|
||||
const std::vector<index_t>& a_gs_ms_ks_strides,
|
||||
const std::vector<index_t>& b_gs_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_gs_ns_ks_strides,
|
||||
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
|
||||
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
|
||||
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
|
||||
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
|
||||
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
|
||||
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
|
||||
const std::array<std::vector<ck::index_t>, NumD1Tensor>
|
||||
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
|
||||
const std::array<std::vector<ck::index_t>, NumD1Tensor>
|
||||
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
C0DEElementwiseOperation c0de_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
C1DEElementwiseOperation c1de_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
static_cast<CDataType*>(p_c),
|
||||
p_acc0_biases, // cast in struct Argument
|
||||
p_acc1_biases, // cast in struct Argument
|
||||
a_gs_ms_ks_lengths,
|
||||
a_gs_ms_ks_strides,
|
||||
b_gs_ns_ks_lengths,
|
||||
b_gs_ns_ks_strides,
|
||||
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
|
||||
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
|
||||
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
|
||||
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
|
||||
acc0_biases_gs_ms_ns_lengths,
|
||||
acc0_biases_gs_ms_ns_strides,
|
||||
acc1_biases_gs_ms_gemm1ns_lengths,
|
||||
acc1_biases_gs_ms_gemm1ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c0de_element_op,
|
||||
b1_element_op,
|
||||
c1de_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< Gemm1NPerBlock << ", "
|
||||
<< Gemm1KPerBlock << ", "
|
||||
<< B1K1 << ", "
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
|
||||
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", "
|
||||
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
|
||||
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
|
||||
<< getMaskingSpecializationString(MaskingSpec) << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,434 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/*
|
||||
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
|
||||
*
|
||||
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
|
||||
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
|
||||
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
|
||||
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
|
||||
* limitations.
|
||||
*
|
||||
* \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and
|
||||
* returns the 2D index of the tile that it computes. \see
|
||||
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
|
||||
*
|
||||
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
|
||||
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
|
||||
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
|
||||
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
|
||||
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
|
||||
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
|
||||
*
|
||||
* \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
|
||||
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
|
||||
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
|
||||
*
|
||||
*/
|
||||
template <typename DeviceOp, typename GridwiseGemm, bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const auto a_grid_desc_k0_m_k1 =
|
||||
amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(
|
||||
karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
|
||||
const auto b_grid_desc_k0_n_k1 =
|
||||
amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(
|
||||
karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
|
||||
const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N(
|
||||
karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
c_grid_desc_m_n);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
ck::index_t NumGemmKPrefetchStage = 1,
|
||||
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
|
||||
struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC)
|
||||
: BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpecialization::MNKPadding,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
NumGemmKPrefetchStage,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
|
||||
using Problem = typename GridwiseGemm::Problem;
|
||||
|
||||
// Argument
|
||||
struct Argument : public Problem, public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t Batch_)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
Batch(Batch_),
|
||||
compute_ptr_offset_of_batch{BatchStrideA, BatchStrideB, BatchStrideC}
|
||||
{
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid;
|
||||
const BDataType* p_b_grid;
|
||||
CDataType* p_c_grid;
|
||||
index_t Batch;
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceBatchedGemmXdl::Argument;
|
||||
|
||||
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
karg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(karg))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting");
|
||||
}
|
||||
|
||||
auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
|
||||
gdx *= karg.Batch;
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_batched_gemm_xdlops_v2r3<DeviceBatchedGemmXdl, GridwiseGemm, true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_batched_gemm_xdlops_v2r3<DeviceBatchedGemmXdl, GridwiseGemm, false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Problem& problem)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(problem);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t Batch)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
Batch};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t Batch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
Batch);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<LoopScheduler, std::string> LoopSchedToString{
|
||||
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
|
||||
{PipelineVersion::v2, "v2"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemmXdl"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ">"
|
||||
<< " NumGemmKPrefetchStage: "
|
||||
<< NumGemmKPrefetchStage << ", "
|
||||
<< "LoopScheduler: "
|
||||
<< LoopSchedToString[LoopSched] << ", "
|
||||
<< "PipelineVersion: "
|
||||
<< PipelineVersionToString[PipelineVer];
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,876 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/reduction_operator.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename XDataType,
|
||||
typename DxDataType,
|
||||
typename DyDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim,
|
||||
bool UseMultiblockInK,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t XDyDxVectorDim,
|
||||
index_t XSrcVectorSize,
|
||||
index_t DySrcVectorSize,
|
||||
index_t DxDstVectorSize,
|
||||
index_t ScaleSrcVectorSize,
|
||||
index_t DscaleDbiasDstVectorSize,
|
||||
index_t MeanVarSrcVectorSize>
|
||||
struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
|
||||
DxDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"Invalid thread cluster size assignments!");
|
||||
|
||||
static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
|
||||
MThreadSliceSize % DySrcVectorSize == 0 &&
|
||||
MThreadSliceSize % DxDstVectorSize == 0) ||
|
||||
(XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
|
||||
KThreadSliceSize % DySrcVectorSize == 0 &&
|
||||
KThreadSliceSize % DxDstVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
|
||||
const std::array<index_t, Rank>& xyStrides,
|
||||
int blkGroupSize,
|
||||
int numBlockTileIteration)
|
||||
{
|
||||
const auto tupleXYLengths =
|
||||
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
|
||||
const auto tupleXYStrides =
|
||||
generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
|
||||
|
||||
const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);
|
||||
|
||||
const auto grid_desc_m_k = [&]() {
|
||||
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
|
||||
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
|
||||
|
||||
const auto reduceDimLengths =
|
||||
generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
|
||||
Number<NumBatchNormReduceDim>{});
|
||||
const auto invariantDimLengths =
|
||||
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
|
||||
|
||||
return transform_tensor_descriptor(raw_grid_desc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(reduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}();
|
||||
|
||||
const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;
|
||||
|
||||
auto grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad),
|
||||
make_right_pad_transform(reduceLength, kPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeMultiblockFirstReduceOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
|
||||
{
|
||||
const auto grid_desc_m_g =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto grid_desc_m_g_padded =
|
||||
transform_tensor_descriptor(grid_desc_m_g,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad),
|
||||
make_pass_through_transform(blkGroupSize)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (grid_desc_m_g_padded);
|
||||
};
|
||||
|
||||
static auto MakeMultiblockFinalReduceInputMK2dDescriptor(int invariantLength, int blkGroupSize)
|
||||
{
|
||||
const auto reduceLength = blkGroupSize;
|
||||
const auto grid_desc_m_k =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto kPad =
|
||||
math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;
|
||||
|
||||
auto grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad),
|
||||
make_right_pad_transform(reduceLength, kPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto
|
||||
MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
|
||||
const std::array<index_t, NumInvariantDim>& strides)
|
||||
{
|
||||
const auto tupleLengths =
|
||||
generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
|
||||
const auto tupleStrides =
|
||||
generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});
|
||||
|
||||
auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
|
||||
|
||||
auto grid_desc_m = transform_tensor_descriptor(
|
||||
raw_grid_desc,
|
||||
make_tuple(make_merge_transform(tupleLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto grid_desc_m_padded =
|
||||
transform_tensor_descriptor(grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (grid_desc_m_padded);
|
||||
};
|
||||
|
||||
using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
|
||||
using ScaleBiasGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));
|
||||
using MeanVarGridDesc_M = ScaleBiasGridDesc_M;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> dyStrides,
|
||||
const std::array<index_t, Rank> dxStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
|
||||
const XDataType* p_x,
|
||||
const DyDataType* p_dy,
|
||||
const ScaleDataType* p_scale,
|
||||
const MeanVarDataType* p_savedMean,
|
||||
const MeanVarDataType* p_savedInvVar,
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
double epsilon,
|
||||
DxDataType* p_dx,
|
||||
DscaleDbiasDataType* p_dscale,
|
||||
DscaleDbiasDataType* p_dbias)
|
||||
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
|
||||
bnScaleStrides_(bnScaleStrides),
|
||||
bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
|
||||
bnMeanVarStrides_(bnMeanVarStrides),
|
||||
p_x_(p_x),
|
||||
p_dy_(p_dy),
|
||||
p_scale_(p_scale),
|
||||
p_savedMean_(p_savedMean),
|
||||
p_savedInvVar_(p_savedInvVar),
|
||||
dy_elementwise_op_(dy_elementwise_op),
|
||||
p_dx_(p_dx),
|
||||
p_dscale_(p_dscale),
|
||||
p_dbias_(p_dbias)
|
||||
{
|
||||
xyLengths_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
|
||||
xStrides_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
|
||||
dyStrides_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(dyStrides, reduceDims);
|
||||
dxStrides_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(dxStrides, reduceDims);
|
||||
|
||||
std::tie(invariant_length, reduce_length) =
|
||||
get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);
|
||||
|
||||
epsilon_ = type_convert<AccDataType>(epsilon);
|
||||
|
||||
haveSavedMeanInvVar_ = (p_savedMean_ != nullptr && p_savedInvVar_ != nullptr);
|
||||
|
||||
if(UseMultiblockInK)
|
||||
{
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
{
|
||||
int testBlkGroupSize = (reduce_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
// we want the blkGroupSize be not more than 128
|
||||
if(testBlkGroupSize <= 128)
|
||||
break;
|
||||
|
||||
iterations++;
|
||||
};
|
||||
|
||||
blkGroupSize = (reduce_length + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
numBlockTileIteration = iterations;
|
||||
}
|
||||
else
|
||||
{
|
||||
blkGroupSize = 1;
|
||||
numBlockTileIteration = (reduce_length + K_BlockTileSize - 1) / K_BlockTileSize;
|
||||
};
|
||||
|
||||
gridSize = (invariant_length + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize;
|
||||
|
||||
x_grid_desc_m_k =
|
||||
MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize, numBlockTileIteration);
|
||||
dy_grid_desc_m_k =
|
||||
MakeXY2dDescriptor(xyLengths_, dyStrides_, blkGroupSize, numBlockTileIteration);
|
||||
dx_grid_desc_m_k =
|
||||
MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration);
|
||||
scale_grid_desc_m =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides);
|
||||
dscale_dbias_grid_desc_m =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides);
|
||||
mean_var_grid_desc_m =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides);
|
||||
}
|
||||
|
||||
AccDataType epsilon_;
|
||||
|
||||
bool haveSavedMeanInvVar_;
|
||||
|
||||
std::array<index_t, Rank> xyLengths_;
|
||||
std::array<index_t, Rank> xStrides_;
|
||||
std::array<index_t, Rank> dyStrides_;
|
||||
std::array<index_t, Rank> dxStrides_;
|
||||
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnDscaleDbiasStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
|
||||
|
||||
const XDataType* p_x_;
|
||||
const DyDataType* p_dy_;
|
||||
const ScaleDataType* p_scale_;
|
||||
const MeanVarDataType* p_savedMean_;
|
||||
const MeanVarDataType* p_savedInvVar_;
|
||||
const DyElementwiseOp dy_elementwise_op_;
|
||||
DxDataType* p_dx_;
|
||||
DscaleDbiasDataType* p_dscale_;
|
||||
DscaleDbiasDataType* p_dbias_;
|
||||
|
||||
long_index_t invariant_length;
|
||||
long_index_t reduce_length;
|
||||
|
||||
int blkGroupSize;
|
||||
int numBlockTileIteration;
|
||||
size_t gridSize;
|
||||
|
||||
XYGridDesc_M_K x_grid_desc_m_k;
|
||||
XYGridDesc_M_K dy_grid_desc_m_k;
|
||||
XYGridDesc_M_K dx_grid_desc_m_k;
|
||||
ScaleBiasGridDesc_M scale_grid_desc_m;
|
||||
ScaleBiasGridDesc_M dscale_dbias_grid_desc_m;
|
||||
MeanVarGridDesc_M mean_var_grid_desc_m;
|
||||
|
||||
void* workspace_mean;
|
||||
void* workspace_variance;
|
||||
void* workspace_count;
|
||||
|
||||
void* workspace_savedMean;
|
||||
void* workspace_savedInvVar;
|
||||
|
||||
void* workspace_reduce_dscale;
|
||||
void* workspace_reduce_dbias;
|
||||
};
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
|
||||
{
|
||||
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
|
||||
|
||||
size_t workspace_size = 0;
|
||||
|
||||
if(UseMultiblockInK && pArg_->blkGroupSize > 1)
|
||||
{
|
||||
// workspace for the partial reduced result for dscale
|
||||
workspace_size +=
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
|
||||
|
||||
// workspace for the partial reduced result for dbias
|
||||
workspace_size +=
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
|
||||
|
||||
if(!pArg_->haveSavedMeanInvVar_)
|
||||
{
|
||||
// workspace for welford intermediate mean
|
||||
workspace_size +=
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType) + 64;
|
||||
|
||||
// workspace for welford intermediate variance
|
||||
workspace_size +=
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType) + 64;
|
||||
|
||||
// workspace for welford intermediate count
|
||||
workspace_size +=
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(int32_t) + 64;
|
||||
|
||||
// workspace for welford result mean
|
||||
workspace_size += pArg_->invariant_length * sizeof(MeanVarDataType) + 64;
|
||||
|
||||
// workspace for welford result inv_variance
|
||||
workspace_size += pArg_->invariant_length * sizeof(MeanVarDataType) + 64;
|
||||
};
|
||||
}
|
||||
|
||||
return (workspace_size);
|
||||
};
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* pArg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
|
||||
|
||||
pArg_->p_workspace_ = p_workspace;
|
||||
|
||||
index_t space_sz;
|
||||
|
||||
// setup buffer for the partial reduced result for dscale
|
||||
pArg_->workspace_reduce_dscale = pArg_->p_workspace_;
|
||||
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
|
||||
space_sz = math::integer_least_multiple(space_sz, 64);
|
||||
|
||||
// setup buffer for the partial reduced result for dbias
|
||||
pArg_->workspace_reduce_dbias =
|
||||
reinterpret_cast<char*>(pArg_->workspace_reduce_dscale) + space_sz;
|
||||
|
||||
if(UseMultiblockInK && pArg_->blkGroupSize > 1)
|
||||
{
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
|
||||
space_sz = math::integer_least_multiple(space_sz, 64);
|
||||
|
||||
// setup buffer for welford intermediate mean
|
||||
pArg_->workspace_mean =
|
||||
reinterpret_cast<char*>(pArg_->workspace_reduce_dbias) + space_sz;
|
||||
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType);
|
||||
space_sz = math::integer_least_multiple(space_sz, 64);
|
||||
|
||||
// setup buffer for welford intermediate varirance
|
||||
pArg_->workspace_variance = reinterpret_cast<char*>(pArg_->workspace_mean) + space_sz;
|
||||
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType);
|
||||
space_sz = math::integer_least_multiple(space_sz, 64);
|
||||
|
||||
// setup buffer for welford intermediate count
|
||||
pArg_->workspace_count = reinterpret_cast<char*>(pArg_->workspace_variance) + space_sz;
|
||||
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(int32_t);
|
||||
space_sz = math::integer_least_multiple(space_sz, 64);
|
||||
|
||||
// setup buffer for welford result mean
|
||||
pArg_->workspace_savedMean = reinterpret_cast<char*>(pArg_->workspace_count) + space_sz;
|
||||
|
||||
space_sz = pArg_->invariant_length * sizeof(MeanVarDataType);
|
||||
space_sz = math::integer_least_multiple(space_sz, 64);
|
||||
|
||||
// setup buffer for welford result inv_variance
|
||||
pArg_->workspace_savedInvVar =
|
||||
reinterpret_cast<char*>(pArg_->workspace_savedMean) + space_sz;
|
||||
};
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0;
|
||||
|
||||
const auto mean_var_count_grid_desc_m_g =
|
||||
DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor(
|
||||
arg.invariant_length, arg.blkGroupSize);
|
||||
|
||||
const auto dscale_dbias_grid_desc_m_g =
|
||||
DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor(
|
||||
arg.invariant_length, arg.blkGroupSize);
|
||||
|
||||
const auto mean_var_count_grid_desc_m_k =
|
||||
DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor(
|
||||
arg.invariant_length, arg.blkGroupSize);
|
||||
|
||||
const auto dscale_dbias_grid_desc_m_k =
|
||||
DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor(
|
||||
arg.invariant_length, arg.blkGroupSize);
|
||||
|
||||
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
|
||||
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
|
||||
using DscaleDbiasGridDesc_M_G = decltype(dscale_dbias_grid_desc_m_g);
|
||||
using DscaleDbiasGridDesc_M_K = decltype(dscale_dbias_grid_desc_m_k);
|
||||
|
||||
using GridwiseWelfordSecondHalfReduceFirstHalf_ =
|
||||
GridwiseWelfordSecondHalfReduceFirstHalf<XDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarGridDesc_M,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
DscaleDbiasGridDesc_M_G,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XDyDxVectorDim,
|
||||
XSrcVectorSize,
|
||||
DySrcVectorSize,
|
||||
MeanVarSrcVectorSize>;
|
||||
|
||||
using GridwiseReduceSecondHalfBatchNormBwdFinal_ =
|
||||
GridwiseReduceSecondHalfBatchNormBackwardFinal<XDataType,
|
||||
DyDataType,
|
||||
DxDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
DscaleDbiasGridDesc_M_K,
|
||||
MeanVarGridDesc_M,
|
||||
ScaleBiasGridDesc_M,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XDyDxVectorDim,
|
||||
XSrcVectorSize,
|
||||
DySrcVectorSize,
|
||||
DxDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
DscaleDbiasDstVectorSize,
|
||||
MeanVarSrcVectorSize>;
|
||||
|
||||
if(UseMultiblockInK && arg.blkGroupSize > 1)
|
||||
{
|
||||
using GetReduceCountPerThreadFunctor =
|
||||
GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;
|
||||
|
||||
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
|
||||
arg.blkGroupSize, arg.numBlockTileIteration, arg.reduce_length);
|
||||
|
||||
if(!arg.haveSavedMeanInvVar_)
|
||||
{
|
||||
using GridwiseMultiblockWelfordFirstHalf_ =
|
||||
GridwiseMultiblockWelfordFirstHalf<XDataType,
|
||||
AccDataType,
|
||||
MeanVarDataType,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
GetReduceCountPerThreadFunctor,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XDyDxVectorDim,
|
||||
XSrcVectorSize>;
|
||||
|
||||
const auto kern_multiblock_welford_first_half =
|
||||
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
|
||||
XDataType,
|
||||
MeanVarDataType,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
GetReduceCountPerThreadFunctor>;
|
||||
|
||||
avg_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kern_multiblock_welford_first_half,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
get_reduce_count_per_thread,
|
||||
arg.numBlockTileIteration,
|
||||
arg.p_x_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance),
|
||||
static_cast<int32_t*>(arg.workspace_count));
|
||||
};
|
||||
|
||||
const auto kern_welford_second_half_reduce_first_half =
|
||||
kernel_welford_second_half_reduce_first_half<
|
||||
GridwiseWelfordSecondHalfReduceFirstHalf_,
|
||||
XDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarGridDesc_M,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
DscaleDbiasGridDesc_M_G>;
|
||||
|
||||
const auto kern_reduce_second_half_batchnorm_backward_final =
|
||||
kernel_reduce_second_half_batchnorm_backward_final<
|
||||
GridwiseReduceSecondHalfBatchNormBwdFinal_,
|
||||
XDataType,
|
||||
DyDataType,
|
||||
DxDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
DscaleDbiasGridDesc_M_K,
|
||||
MeanVarGridDesc_M,
|
||||
ScaleBiasGridDesc_M>;
|
||||
|
||||
index_t numDscaleDbiasBlockTileIteration =
|
||||
(arg.blkGroupSize + KThreadClusterSize - 1) / KThreadClusterSize;
|
||||
|
||||
avg_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kern_welford_second_half_reduce_first_half,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k,
|
||||
arg.dy_grid_desc_m_k,
|
||||
arg.mean_var_grid_desc_m,
|
||||
mean_var_count_grid_desc_m_k,
|
||||
dscale_dbias_grid_desc_m_g,
|
||||
arg.blkGroupSize,
|
||||
arg.numBlockTileIteration,
|
||||
numDscaleDbiasBlockTileIteration,
|
||||
arg.epsilon_,
|
||||
arg.haveSavedMeanInvVar_,
|
||||
arg.haveSavedMeanInvVar_ ? arg.p_savedMean_ : nullptr,
|
||||
arg.haveSavedMeanInvVar_ ? arg.p_savedInvVar_ : nullptr,
|
||||
arg.haveSavedMeanInvVar_
|
||||
? nullptr
|
||||
: static_cast<const MeanVarDataType*>(arg.workspace_mean),
|
||||
arg.haveSavedMeanInvVar_
|
||||
? nullptr
|
||||
: static_cast<const MeanVarDataType*>(arg.workspace_variance),
|
||||
arg.haveSavedMeanInvVar_ ? nullptr
|
||||
: static_cast<const int32_t*>(arg.workspace_count),
|
||||
arg.dy_elementwise_op_,
|
||||
arg.haveSavedMeanInvVar_
|
||||
? nullptr
|
||||
: static_cast<MeanVarDataType*>(arg.workspace_savedMean),
|
||||
arg.haveSavedMeanInvVar_
|
||||
? nullptr
|
||||
: static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
|
||||
arg.p_x_,
|
||||
arg.p_dy_,
|
||||
static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
|
||||
static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dbias));
|
||||
|
||||
avg_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kern_reduce_second_half_batchnorm_backward_final,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k,
|
||||
arg.dy_grid_desc_m_k,
|
||||
arg.dx_grid_desc_m_k,
|
||||
dscale_dbias_grid_desc_m_k,
|
||||
arg.mean_var_grid_desc_m,
|
||||
arg.scale_grid_desc_m,
|
||||
arg.dscale_dbias_grid_desc_m,
|
||||
arg.blkGroupSize,
|
||||
arg.reduce_length,
|
||||
arg.numBlockTileIteration,
|
||||
numDscaleDbiasBlockTileIteration,
|
||||
static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
|
||||
static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dbias),
|
||||
arg.haveSavedMeanInvVar_
|
||||
? arg.p_savedMean_
|
||||
: static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
|
||||
arg.haveSavedMeanInvVar_
|
||||
? arg.p_savedInvVar_
|
||||
: static_cast<const MeanVarDataType*>(arg.workspace_savedInvVar),
|
||||
arg.p_x_,
|
||||
arg.p_dy_,
|
||||
arg.p_scale_,
|
||||
arg.dy_elementwise_op_,
|
||||
arg.p_dx_,
|
||||
arg.p_dscale_,
|
||||
arg.p_dbias_);
|
||||
}
|
||||
else
|
||||
{
|
||||
using GetReduceCountPerThreadFunctor =
|
||||
GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;
|
||||
|
||||
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
|
||||
arg.numBlockTileIteration, arg.reduce_length);
|
||||
|
||||
using GridwiseBatchNormBackwardWithBlockwiseWelford_ =
|
||||
GridwiseBatchNormBackwardWithBlockwiseWelford<XDataType,
|
||||
DyDataType,
|
||||
DxDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
ScaleBiasGridDesc_M,
|
||||
MeanVarGridDesc_M,
|
||||
GetReduceCountPerThreadFunctor,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XDyDxVectorDim,
|
||||
XSrcVectorSize,
|
||||
DySrcVectorSize,
|
||||
DxDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
DscaleDbiasDstVectorSize,
|
||||
MeanVarSrcVectorSize>;
|
||||
|
||||
const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford<
|
||||
GridwiseBatchNormBackwardWithBlockwiseWelford_,
|
||||
XDataType,
|
||||
DyDataType,
|
||||
DxDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
ScaleBiasGridDesc_M,
|
||||
MeanVarGridDesc_M,
|
||||
GetReduceCountPerThreadFunctor>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kern_batchnorm_bwd,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k,
|
||||
arg.dy_grid_desc_m_k,
|
||||
arg.dx_grid_desc_m_k,
|
||||
arg.scale_grid_desc_m,
|
||||
arg.dscale_dbias_grid_desc_m,
|
||||
arg.mean_var_grid_desc_m,
|
||||
get_reduce_count_per_thread,
|
||||
arg.reduce_length,
|
||||
arg.numBlockTileIteration,
|
||||
arg.epsilon_,
|
||||
arg.p_x_,
|
||||
arg.p_dy_,
|
||||
arg.p_scale_,
|
||||
arg.haveSavedMeanInvVar_,
|
||||
arg.p_savedMean_,
|
||||
arg.p_savedInvVar_,
|
||||
arg.dy_elementwise_op_,
|
||||
arg.p_dx_,
|
||||
arg.p_dscale_,
|
||||
arg.p_dbias_);
|
||||
};
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* pArg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* pArg) override
|
||||
{
|
||||
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
|
||||
|
||||
if constexpr(XDyDxVectorDim == 0)
|
||||
{
|
||||
if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
|
||||
pArg_->dyStrides_[NumInvariantDim - 1] != 1 ||
|
||||
pArg_->dxStrides_[NumInvariantDim - 1] != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
|
||||
pArg_->xyLengths_[NumInvariantDim - 1] % DySrcVectorSize != 0 ||
|
||||
pArg_->xyLengths_[NumInvariantDim - 1] % DxDstVectorSize != 0)
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->dyStrides_[Rank - 1] != 1 ||
|
||||
pArg_->dxStrides_[Rank - 1] != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
|
||||
pArg_->xyLengths_[Rank - 1] % DySrcVectorSize != 0 ||
|
||||
pArg_->xyLengths_[Rank - 1] % DxDstVectorSize != 0)
|
||||
return false;
|
||||
};
|
||||
|
||||
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0)
|
||||
return false;
|
||||
|
||||
if(pArg_->haveSavedMeanInvVar_)
|
||||
{
|
||||
if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcVectorSize != 0)
|
||||
return false;
|
||||
};
|
||||
|
||||
bool is_valid = true;
|
||||
|
||||
static_for<0, NumInvariantDim, 1>{}([&](auto I) {
|
||||
if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
|
||||
is_valid = false;
|
||||
});
|
||||
|
||||
if(!is_valid)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> dyStrides,
|
||||
const std::array<index_t, Rank> dxStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* p_dy,
|
||||
const void* p_scale,
|
||||
const void* p_savedMean,
|
||||
const void* p_savedInvVar,
|
||||
double epsilon,
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
void* p_dx,
|
||||
void* p_dscale,
|
||||
void* p_dbias) override
|
||||
{
|
||||
return std::make_unique<Argument>(xyLengths,
|
||||
xStrides,
|
||||
dyStrides,
|
||||
dxStrides,
|
||||
reduceDims,
|
||||
bnScaleBiasMeanVarLengths,
|
||||
bnScaleStrides,
|
||||
bnDscaleDbiasStrides,
|
||||
bnMeanVarStrides,
|
||||
static_cast<const XDataType*>(p_x),
|
||||
static_cast<const DyDataType*>(p_dy),
|
||||
static_cast<const ScaleDataType*>(p_scale),
|
||||
static_cast<const MeanVarDataType*>(p_savedMean),
|
||||
static_cast<const MeanVarDataType*>(p_savedInvVar),
|
||||
dy_elementwise_op,
|
||||
epsilon,
|
||||
static_cast<DxDataType*>(p_dx),
|
||||
static_cast<DscaleDbiasDataType*>(p_dscale),
|
||||
static_cast<DscaleDbiasDataType*>(p_dbias));
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchNormBwdImpl<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "XDyDxVectorDim_" << XDyDxVectorDim << ",";
|
||||
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,824 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/reduction_operator.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim,
|
||||
bool UseMultiblockInK,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t XSrcYDstVectorDim,
|
||||
index_t XSrcVectorSize,
|
||||
index_t YDstVectorSize,
|
||||
index_t ScaleSrcVectorSize,
|
||||
index_t BiasSrcVectorSize,
|
||||
index_t MeanVarSrcDstVectorSize>
|
||||
struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"Invalid thread cluster size assignments!");
|
||||
|
||||
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
|
||||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
|
||||
const std::array<index_t, Rank>& xyStrides,
|
||||
int blkGroupSize,
|
||||
int numBlockTileIteration)
|
||||
{
|
||||
const auto tupleXYLengths =
|
||||
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
|
||||
const auto tupleXYStrides =
|
||||
generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
|
||||
|
||||
const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);
|
||||
|
||||
const auto grid_desc_m_k = [&]() {
|
||||
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
|
||||
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
|
||||
|
||||
const auto reduceDimLengths =
|
||||
generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
|
||||
Number<NumBatchNormReduceDim>{});
|
||||
const auto invariantDimLengths =
|
||||
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
|
||||
|
||||
return transform_tensor_descriptor(raw_grid_desc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(reduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}();
|
||||
|
||||
const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;
|
||||
|
||||
auto grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad),
|
||||
make_right_pad_transform(reduceLength, kPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
|
||||
{
|
||||
const auto grid_desc_m_g = make_naive_tensor_descriptor(
|
||||
make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto grid_desc_m_g_padded =
|
||||
transform_tensor_descriptor(grid_desc_m_g,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad),
|
||||
make_pass_through_transform(blkGroupSize)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (grid_desc_m_g_padded);
|
||||
};
|
||||
|
||||
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
|
||||
{
|
||||
const auto reduceLength = blkGroupSize;
|
||||
const auto grid_desc_m_k = make_naive_tensor_descriptor(
|
||||
make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto kPad =
|
||||
math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;
|
||||
|
||||
auto grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad),
|
||||
make_right_pad_transform(reduceLength, kPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto
|
||||
MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
|
||||
const std::array<index_t, NumInvariantDim>& strides)
|
||||
{
|
||||
const auto tupleLengths =
|
||||
generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
|
||||
const auto tupleStrides =
|
||||
generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});
|
||||
|
||||
auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
|
||||
|
||||
auto grid_desc_m = transform_tensor_descriptor(
|
||||
raw_grid_desc,
|
||||
make_tuple(make_merge_transform(tupleLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto grid_desc_m_padded =
|
||||
transform_tensor_descriptor(grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (grid_desc_m_padded);
|
||||
};
|
||||
|
||||
using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
|
||||
using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> yStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
|
||||
const XDataType* p_x,
|
||||
const ScaleDataType* p_scale,
|
||||
const BiasDataType* p_bias,
|
||||
const YElementwiseOp y_elementwise_op,
|
||||
double epsilon,
|
||||
YDataType* p_y,
|
||||
MeanVarDataType* resultSaveMean,
|
||||
MeanVarDataType* resultSaveInvVariance,
|
||||
double averageFactor,
|
||||
MeanVarDataType* resultRunningMean,
|
||||
MeanVarDataType* resultRunningVariance)
|
||||
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
|
||||
bnScaleStrides_(bnScaleStrides),
|
||||
bnBiasStrides_(bnBiasStrides),
|
||||
bnMeanVarStrides_(bnMeanVarStrides),
|
||||
p_x_(p_x),
|
||||
p_scale_(p_scale),
|
||||
p_bias_(p_bias),
|
||||
y_elementwise_op_(y_elementwise_op),
|
||||
p_y_(p_y),
|
||||
resultSaveMean_(resultSaveMean),
|
||||
resultSaveInvVariance_(resultSaveInvVariance),
|
||||
resultRunningMean_(resultRunningMean),
|
||||
resultRunningVariance_(resultRunningVariance)
|
||||
{
|
||||
xyLengths_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
|
||||
xStrides_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
|
||||
yStrides_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(yStrides, reduceDims);
|
||||
|
||||
std::tie(invariant_length_, reduce_length_) =
|
||||
get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);
|
||||
|
||||
epsilon_ = type_convert<AccDataType>(epsilon);
|
||||
averageFactor_ = type_convert<AccDataType>(averageFactor);
|
||||
|
||||
updateMovingAverage_ =
|
||||
(resultRunningMean != nullptr && resultRunningVariance != nullptr);
|
||||
saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr);
|
||||
|
||||
if(UseMultiblockInK)
|
||||
{
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
{
|
||||
int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
// we want the blkGroupSize be not more than 16
|
||||
if(testBlkGroupSize <= 16)
|
||||
break;
|
||||
|
||||
iterations++;
|
||||
};
|
||||
|
||||
blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
numBlockTileIteration_ = iterations;
|
||||
}
|
||||
else
|
||||
{
|
||||
blkGroupSize_ = 1;
|
||||
numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
|
||||
};
|
||||
|
||||
gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_;
|
||||
|
||||
x_grid_desc_m_k_ =
|
||||
MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
|
||||
y_grid_desc_m_k_ =
|
||||
MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
|
||||
scale_grid_desc_m_ =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_);
|
||||
bias_grid_desc_m_ =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_);
|
||||
mean_var_grid_desc_m_ =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_);
|
||||
}
|
||||
|
||||
AccDataType epsilon_;
|
||||
AccDataType averageFactor_;
|
||||
|
||||
bool updateMovingAverage_;
|
||||
bool saveMeanInvVariance_;
|
||||
|
||||
std::array<index_t, Rank> xyLengths_;
|
||||
std::array<index_t, Rank> xStrides_;
|
||||
std::array<index_t, Rank> yStrides_;
|
||||
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
|
||||
|
||||
const XDataType* p_x_;
|
||||
const ScaleDataType* p_scale_;
|
||||
const BiasDataType* p_bias_;
|
||||
const YElementwiseOp y_elementwise_op_;
|
||||
YDataType* p_y_;
|
||||
|
||||
MeanVarDataType* resultSaveMean_;
|
||||
MeanVarDataType* resultSaveInvVariance_;
|
||||
|
||||
MeanVarDataType* resultRunningMean_;
|
||||
MeanVarDataType* resultRunningVariance_;
|
||||
|
||||
long_index_t invariant_length_;
|
||||
long_index_t reduce_length_;
|
||||
|
||||
int blkGroupSize_;
|
||||
int numBlockTileIteration_;
|
||||
size_t gridSize_;
|
||||
|
||||
XYGridDesc_M_K x_grid_desc_m_k_;
|
||||
XYGridDesc_M_K y_grid_desc_m_k_;
|
||||
ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_;
|
||||
ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_;
|
||||
ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_;
|
||||
|
||||
void* workspace_mean_;
|
||||
void* workspace_variance_;
|
||||
void* workspace_count_;
|
||||
|
||||
void* control_;
|
||||
};
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
|
||||
{
|
||||
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
|
||||
|
||||
size_t workspace_size = 0;
|
||||
|
||||
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
|
||||
{
|
||||
// workspace for welford intermediate mean
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
|
||||
|
||||
// workspace for welford intermediate variance
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
|
||||
|
||||
// workspace for welford intermediate count
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
|
||||
|
||||
// workspace for barrier objects, each barrier object consists of two integers
|
||||
// TODO: allocate barrier object memory globally to reuse it by other operators
|
||||
workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize *
|
||||
sizeof(int) * 2;
|
||||
}
|
||||
|
||||
return (workspace_size);
|
||||
};
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* pArg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
|
||||
|
||||
pArg_->p_workspace_ = p_workspace;
|
||||
|
||||
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
|
||||
{
|
||||
// setup buffer used for intermediate welford mean
|
||||
pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
|
||||
|
||||
index_t mean_space_sz =
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
|
||||
|
||||
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
|
||||
|
||||
// setup buffer used for intermediate welford varirance
|
||||
pArg_->workspace_variance_ =
|
||||
reinterpret_cast<char*>(pArg_->workspace_mean_) + mean_space_sz;
|
||||
|
||||
index_t variance_space_sz =
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
|
||||
|
||||
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
|
||||
|
||||
// setup buffer used for intermediate welfor count
|
||||
pArg_->workspace_count_ =
|
||||
reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
|
||||
|
||||
index_t count_space_sz =
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t);
|
||||
|
||||
count_space_sz = math::integer_least_multiple(count_space_sz, 64);
|
||||
|
||||
pArg_->control_ = reinterpret_cast<char*>(pArg_->workspace_count_) + count_space_sz;
|
||||
|
||||
index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) /
|
||||
M_BlockTileSize * sizeof(int) * 2;
|
||||
|
||||
hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz));
|
||||
};
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0;
|
||||
|
||||
if(UseMultiblockInK && arg.blkGroupSize_ > 1)
|
||||
{
|
||||
using GetReduceCountPerThreadFunctor =
|
||||
GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;
|
||||
|
||||
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
|
||||
arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_);
|
||||
|
||||
const auto mean_var_count_grid_desc_m_g =
|
||||
DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor(
|
||||
arg.invariant_length_, arg.blkGroupSize_);
|
||||
|
||||
const auto mean_var_count_grid_desc_m_k =
|
||||
DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor(
|
||||
arg.invariant_length_, arg.blkGroupSize_);
|
||||
|
||||
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
|
||||
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
|
||||
|
||||
using GridwiseMultiblockBatchNormForward_ =
|
||||
GridwiseMultiblockBatchNormForward<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
GetReduceCountPerThreadFunctor,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XSrcYDstVectorDim,
|
||||
XSrcVectorSize,
|
||||
YDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
BiasSrcVectorSize,
|
||||
MeanVarSrcDstVectorSize>;
|
||||
|
||||
using GridwiseMultiblockWelfordFirstHalf_ =
|
||||
GridwiseMultiblockWelfordFirstHalf<XDataType,
|
||||
AccDataType,
|
||||
MeanVarDataType,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
GetReduceCountPerThreadFunctor,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XSrcYDstVectorDim,
|
||||
XSrcVectorSize>;
|
||||
|
||||
using GridwiseWelfordSecondHalfBatchNormForwardFinal_ =
|
||||
GridwiseWelfordSecondHalfBatchNormForwardFinal<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XSrcYDstVectorDim,
|
||||
XSrcVectorSize,
|
||||
YDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
BiasSrcVectorSize,
|
||||
MeanVarSrcDstVectorSize>;
|
||||
|
||||
// It is found that:
|
||||
// 1) gfx1030 does not support the GLC enabled vector load/store, so using the
|
||||
// two-kernel method for gfx1030
|
||||
// 2) Profiler on gfx908 could hang even though it works when running examples
|
||||
// 3) Single-kernel method works on gfx1100, but the performance it not better
|
||||
// than two-kernel method (due to more warps participating the barrier)
|
||||
if(ck::get_device_name() == "gfx90a")
|
||||
{
|
||||
const auto kern_multiblock_batchnorm_fwd_ =
|
||||
kernel_multiblock_batchnorm_forward<GridwiseMultiblockBatchNormForward_,
|
||||
XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
GetReduceCountPerThreadFunctor>;
|
||||
|
||||
avg_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kern_multiblock_batchnorm_fwd_,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
arg.y_grid_desc_m_k_,
|
||||
mean_var_count_grid_desc_m_g, // for writing to mean/variance/count
|
||||
// workspace by multiple workgroups
|
||||
mean_var_count_grid_desc_m_k, // for reading from mean/variance/count
|
||||
// workspace by each workgroup
|
||||
arg.scale_grid_desc_m_,
|
||||
arg.bias_grid_desc_m_,
|
||||
arg.mean_var_grid_desc_m_,
|
||||
get_reduce_count_per_thread,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.epsilon_,
|
||||
arg.p_x_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean_),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance_),
|
||||
static_cast<int32_t*>(arg.workspace_count_),
|
||||
static_cast<int*>(arg.control_),
|
||||
arg.p_scale_,
|
||||
arg.p_bias_,
|
||||
arg.y_elementwise_op_,
|
||||
arg.p_y_,
|
||||
arg.updateMovingAverage_, // true or false
|
||||
arg.averageFactor_,
|
||||
arg.resultRunningMean_,
|
||||
arg.resultRunningVariance_,
|
||||
arg.saveMeanInvVariance_, // true or false
|
||||
arg.resultSaveMean_,
|
||||
arg.resultSaveInvVariance_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kern_multiblock_welford_first_half =
|
||||
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
|
||||
XDataType,
|
||||
MeanVarDataType,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
GetReduceCountPerThreadFunctor>;
|
||||
|
||||
const auto kern_welford_second_half_batchnorm_forward_final =
|
||||
kernel_welford_second_half_batchnorm_forward_final<
|
||||
GridwiseWelfordSecondHalfBatchNormForwardFinal_,
|
||||
XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M>;
|
||||
|
||||
avg_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kern_multiblock_welford_first_half,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
get_reduce_count_per_thread,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.p_x_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean_),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance_),
|
||||
static_cast<int32_t*>(arg.workspace_count_));
|
||||
|
||||
avg_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kern_welford_second_half_batchnorm_forward_final,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
arg.y_grid_desc_m_k_,
|
||||
mean_var_count_grid_desc_m_k,
|
||||
arg.scale_grid_desc_m_,
|
||||
arg.bias_grid_desc_m_,
|
||||
arg.mean_var_grid_desc_m_,
|
||||
arg.blkGroupSize_,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.epsilon_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean_),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance_),
|
||||
static_cast<int32_t*>(arg.workspace_count_),
|
||||
arg.p_x_,
|
||||
arg.p_scale_,
|
||||
arg.p_bias_,
|
||||
arg.y_elementwise_op_,
|
||||
arg.p_y_,
|
||||
arg.updateMovingAverage_,
|
||||
arg.averageFactor_,
|
||||
arg.resultRunningMean_,
|
||||
arg.resultRunningVariance_,
|
||||
arg.saveMeanInvVariance_,
|
||||
arg.resultSaveMean_,
|
||||
arg.resultSaveInvVariance_);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
using GetReduceCountPerThreadFunctor =
|
||||
GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;
|
||||
|
||||
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
|
||||
arg.numBlockTileIteration_, arg.reduce_length_);
|
||||
|
||||
using GridwiseBatchNormForwardWithBlockwiseWelford_ =
|
||||
GridwiseBatchNormForwardWithBlockwiseWelford<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
GetReduceCountPerThreadFunctor,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XSrcYDstVectorDim,
|
||||
XSrcVectorSize,
|
||||
YDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
BiasSrcVectorSize,
|
||||
MeanVarSrcDstVectorSize>;
|
||||
|
||||
const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford<
|
||||
GridwiseBatchNormForwardWithBlockwiseWelford_,
|
||||
XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
GetReduceCountPerThreadFunctor>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kern_batchnorm_fwd,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
arg.y_grid_desc_m_k_,
|
||||
arg.scale_grid_desc_m_,
|
||||
arg.bias_grid_desc_m_,
|
||||
arg.mean_var_grid_desc_m_,
|
||||
get_reduce_count_per_thread,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.epsilon_,
|
||||
arg.p_x_,
|
||||
arg.p_scale_,
|
||||
arg.p_bias_,
|
||||
arg.y_elementwise_op_,
|
||||
arg.p_y_,
|
||||
arg.updateMovingAverage_, // true or false
|
||||
arg.averageFactor_,
|
||||
arg.resultRunningMean_,
|
||||
arg.resultRunningVariance_,
|
||||
arg.saveMeanInvVariance_, // true or false
|
||||
arg.resultSaveMean_,
|
||||
arg.resultSaveInvVariance_);
|
||||
};
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* pArg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* pArg) override
|
||||
{
|
||||
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
|
||||
|
||||
if constexpr(XSrcYDstVectorDim == 0)
|
||||
{
|
||||
if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
|
||||
pArg_->yStrides_[NumInvariantDim - 1] != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
|
||||
pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0)
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
|
||||
pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0)
|
||||
return false;
|
||||
};
|
||||
|
||||
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
|
||||
return false;
|
||||
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
|
||||
return false;
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0)
|
||||
return false;
|
||||
|
||||
bool is_valid = true;
|
||||
|
||||
static_for<0, NumInvariantDim, 1>{}([&](auto I) {
|
||||
if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
|
||||
is_valid = false;
|
||||
});
|
||||
|
||||
if(!is_valid)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> yStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* p_scale,
|
||||
const void* p_bias,
|
||||
double epsilon,
|
||||
const YElementwiseOp y_elementwise_op,
|
||||
void* p_y,
|
||||
void* resultSaveMean,
|
||||
void* resultSaveInvVariance,
|
||||
double averageFactor,
|
||||
void* resultRunningMean,
|
||||
void* resultRunningVariance) override
|
||||
{
|
||||
return std::make_unique<Argument>(xyLengths,
|
||||
xStrides,
|
||||
yStrides,
|
||||
reduceDims,
|
||||
bnScaleBiasMeanVarLengths,
|
||||
bnScaleStrides,
|
||||
bnBiasStrides,
|
||||
bnMeanVarStrides,
|
||||
static_cast<const XDataType*>(p_x),
|
||||
static_cast<const ScaleDataType*>(p_scale),
|
||||
static_cast<const BiasDataType*>(p_bias),
|
||||
y_elementwise_op,
|
||||
epsilon,
|
||||
static_cast<YDataType*>(p_y),
|
||||
static_cast<MeanVarDataType*>(resultSaveMean),
|
||||
static_cast<MeanVarDataType*>(resultSaveInvVariance),
|
||||
averageFactor,
|
||||
static_cast<MeanVarDataType*>(resultRunningMean),
|
||||
static_cast<MeanVarDataType*>(resultRunningVariance));
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchNormFwdImpl<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ",";
|
||||
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,716 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/reduction_operator.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim,
|
||||
bool UseMultiblockInK,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t XSrcYDstVectorDim,
|
||||
index_t XSrcVectorSize,
|
||||
index_t YDstVectorSize,
|
||||
index_t ScaleSrcVectorSize,
|
||||
index_t BiasSrcVectorSize,
|
||||
index_t MeanVarSrcDstVectorSize>
|
||||
struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
"Invalid thread cluster size assignments!");
|
||||
|
||||
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
|
||||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
|
||||
const std::array<index_t, Rank>& xyStrides,
|
||||
int blkGroupSize,
|
||||
int numBlockTileIteration)
|
||||
{
|
||||
const auto tupleXYLengths =
|
||||
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
|
||||
const auto tupleXYStrides =
|
||||
generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
|
||||
|
||||
const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);
|
||||
|
||||
const auto grid_desc_m_k = [&]() {
|
||||
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
|
||||
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
|
||||
|
||||
const auto reduceDimLengths =
|
||||
generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
|
||||
Number<NumBatchNormReduceDim>{});
|
||||
const auto invariantDimLengths =
|
||||
generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
|
||||
|
||||
return transform_tensor_descriptor(raw_grid_desc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(reduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}();
|
||||
|
||||
const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;
|
||||
|
||||
auto grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad),
|
||||
make_right_pad_transform(reduceLength, kPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
|
||||
{
|
||||
const auto grid_desc_m_g = make_naive_tensor_descriptor(
|
||||
make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto grid_desc_m_g_padded =
|
||||
transform_tensor_descriptor(grid_desc_m_g,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad),
|
||||
make_pass_through_transform(blkGroupSize)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (grid_desc_m_g_padded);
|
||||
};
|
||||
|
||||
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
|
||||
{
|
||||
const auto reduceLength = blkGroupSize;
|
||||
const auto grid_desc_m_k = make_naive_tensor_descriptor(
|
||||
make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto kPad =
|
||||
math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;
|
||||
|
||||
auto grid_desc_m_k_padded =
|
||||
transform_tensor_descriptor(grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad),
|
||||
make_right_pad_transform(reduceLength, kPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto
|
||||
MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
|
||||
const std::array<index_t, NumInvariantDim>& strides)
|
||||
{
|
||||
const auto tupleLengths =
|
||||
generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
|
||||
const auto tupleStrides =
|
||||
generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});
|
||||
|
||||
auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
|
||||
|
||||
auto grid_desc_m = transform_tensor_descriptor(
|
||||
raw_grid_desc,
|
||||
make_tuple(make_merge_transform(tupleLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto mPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto grid_desc_m_padded =
|
||||
transform_tensor_descriptor(grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(invariantLength, mPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (grid_desc_m_padded);
|
||||
};
|
||||
|
||||
using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
|
||||
using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> yStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
|
||||
const XDataType* p_x,
|
||||
const ScaleDataType* p_scale,
|
||||
const BiasDataType* p_bias,
|
||||
const YElementwiseOp y_elementwise_op,
|
||||
double epsilon,
|
||||
YDataType* p_y,
|
||||
MeanVarDataType* resultSaveMean,
|
||||
MeanVarDataType* resultSaveInvVariance,
|
||||
double averageFactor,
|
||||
MeanVarDataType* resultRunningMean,
|
||||
MeanVarDataType* resultRunningVariance)
|
||||
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
|
||||
bnScaleStrides_(bnScaleStrides),
|
||||
bnBiasStrides_(bnBiasStrides),
|
||||
bnMeanVarStrides_(bnMeanVarStrides),
|
||||
p_x_(p_x),
|
||||
p_scale_(p_scale),
|
||||
p_bias_(p_bias),
|
||||
y_elementwise_op_(y_elementwise_op),
|
||||
p_y_(p_y),
|
||||
resultSaveMean_(resultSaveMean),
|
||||
resultSaveInvVariance_(resultSaveInvVariance),
|
||||
resultRunningMean_(resultRunningMean),
|
||||
resultRunningVariance_(resultRunningVariance)
|
||||
{
|
||||
xyLengths_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
|
||||
xStrides_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
|
||||
yStrides_ =
|
||||
shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(yStrides, reduceDims);
|
||||
|
||||
std::tie(invariant_length_, reduce_length_) =
|
||||
get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);
|
||||
|
||||
epsilon_ = type_convert<AccDataType>(epsilon);
|
||||
averageFactor_ = type_convert<AccDataType>(averageFactor);
|
||||
|
||||
updateMovingAverage_ =
|
||||
(resultRunningMean != nullptr && resultRunningVariance != nullptr);
|
||||
saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr);
|
||||
|
||||
if(UseMultiblockInK)
|
||||
{
|
||||
int iterations = 1;
|
||||
while(true)
|
||||
{
|
||||
int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
// we want the blkGroupSize be not more than 16
|
||||
if(testBlkGroupSize <= 16)
|
||||
break;
|
||||
|
||||
iterations++;
|
||||
};
|
||||
|
||||
blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
|
||||
(K_BlockTileSize * iterations);
|
||||
|
||||
numBlockTileIteration_ = iterations;
|
||||
}
|
||||
else
|
||||
{
|
||||
blkGroupSize_ = 1;
|
||||
numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
|
||||
};
|
||||
|
||||
gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_;
|
||||
|
||||
x_grid_desc_m_k_ =
|
||||
MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
|
||||
y_grid_desc_m_k_ =
|
||||
MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
|
||||
scale_grid_desc_m_ =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_);
|
||||
bias_grid_desc_m_ =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_);
|
||||
mean_var_grid_desc_m_ =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_);
|
||||
}
|
||||
|
||||
AccDataType epsilon_;
|
||||
AccDataType averageFactor_;
|
||||
|
||||
bool updateMovingAverage_;
|
||||
bool saveMeanInvVariance_;
|
||||
|
||||
std::array<index_t, Rank> xyLengths_;
|
||||
std::array<index_t, Rank> xStrides_;
|
||||
std::array<index_t, Rank> yStrides_;
|
||||
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
|
||||
|
||||
const XDataType* p_x_;
|
||||
const ScaleDataType* p_scale_;
|
||||
const BiasDataType* p_bias_;
|
||||
const YElementwiseOp y_elementwise_op_;
|
||||
YDataType* p_y_;
|
||||
|
||||
MeanVarDataType* resultSaveMean_;
|
||||
MeanVarDataType* resultSaveInvVariance_;
|
||||
|
||||
MeanVarDataType* resultRunningMean_;
|
||||
MeanVarDataType* resultRunningVariance_;
|
||||
|
||||
long_index_t invariant_length_;
|
||||
long_index_t reduce_length_;
|
||||
|
||||
int blkGroupSize_;
|
||||
int numBlockTileIteration_;
|
||||
size_t gridSize_;
|
||||
|
||||
XYGridDesc_M_K x_grid_desc_m_k_;
|
||||
XYGridDesc_M_K y_grid_desc_m_k_;
|
||||
ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_;
|
||||
ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_;
|
||||
ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_;
|
||||
|
||||
void* workspace_mean_;
|
||||
void* workspace_variance_;
|
||||
void* workspace_count_;
|
||||
};
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
|
||||
{
|
||||
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
|
||||
|
||||
size_t workspace_size = 0;
|
||||
|
||||
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
|
||||
{
|
||||
// workspace for welford intermediate mean
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
|
||||
|
||||
// workspace for welford intermediate variance
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
|
||||
|
||||
// workspace for welford intermediate count
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
|
||||
}
|
||||
|
||||
return (workspace_size);
|
||||
};
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* pArg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
|
||||
|
||||
pArg_->p_workspace_ = p_workspace;
|
||||
|
||||
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
|
||||
{
|
||||
|
||||
// setup buffer used for intermediate welford mean
|
||||
pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
|
||||
|
||||
index_t mean_space_sz =
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
|
||||
|
||||
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
|
||||
|
||||
// setup buffer used for intermediate welford varirance
|
||||
pArg_->workspace_variance_ =
|
||||
reinterpret_cast<char*>(pArg_->workspace_mean_) + mean_space_sz;
|
||||
|
||||
index_t variance_space_sz =
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
|
||||
|
||||
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
|
||||
|
||||
// setup buffer used for intermediate welfor count
|
||||
pArg_->workspace_count_ =
|
||||
reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
|
||||
};
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0;
|
||||
|
||||
if(UseMultiblockInK && arg.blkGroupSize_ > 1)
|
||||
{
|
||||
using GetReduceCountPerThreadFunctor =
|
||||
GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;
|
||||
|
||||
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
|
||||
arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_);
|
||||
|
||||
const auto mean_var_count_grid_desc_m_g =
|
||||
DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor(
|
||||
arg.invariant_length_, arg.blkGroupSize_);
|
||||
|
||||
const auto mean_var_count_grid_desc_m_k =
|
||||
DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor(
|
||||
arg.invariant_length_, arg.blkGroupSize_);
|
||||
|
||||
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
|
||||
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
|
||||
|
||||
using GridwiseMultiblockWelfordFirstHalf_ =
|
||||
GridwiseMultiblockWelfordFirstHalf<XDataType,
|
||||
AccDataType,
|
||||
MeanVarDataType,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
GetReduceCountPerThreadFunctor,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XSrcYDstVectorDim,
|
||||
XSrcVectorSize>;
|
||||
|
||||
using GridwiseWelfordSecondHalfBatchNormForwardFinal_ =
|
||||
GridwiseWelfordSecondHalfBatchNormForwardFinal<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XSrcYDstVectorDim,
|
||||
XSrcVectorSize,
|
||||
YDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
BiasSrcVectorSize,
|
||||
MeanVarSrcDstVectorSize>;
|
||||
|
||||
const auto kern_multiblock_welford_first_half =
|
||||
kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
|
||||
XDataType,
|
||||
MeanVarDataType,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_G,
|
||||
GetReduceCountPerThreadFunctor>;
|
||||
|
||||
const auto kern_welford_second_half_batchnorm_forward_final =
|
||||
kernel_welford_second_half_batchnorm_forward_final<
|
||||
GridwiseWelfordSecondHalfBatchNormForwardFinal_,
|
||||
XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
MeanVarCountGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M>;
|
||||
|
||||
avg_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kern_multiblock_welford_first_half,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
get_reduce_count_per_thread,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.p_x_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean_),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance_),
|
||||
static_cast<int32_t*>(arg.workspace_count_));
|
||||
|
||||
avg_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kern_welford_second_half_batchnorm_forward_final,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
arg.y_grid_desc_m_k_,
|
||||
mean_var_count_grid_desc_m_k,
|
||||
arg.scale_grid_desc_m_,
|
||||
arg.bias_grid_desc_m_,
|
||||
arg.mean_var_grid_desc_m_,
|
||||
arg.blkGroupSize_,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.epsilon_,
|
||||
static_cast<MeanVarDataType*>(arg.workspace_mean_),
|
||||
static_cast<MeanVarDataType*>(arg.workspace_variance_),
|
||||
static_cast<int32_t*>(arg.workspace_count_),
|
||||
arg.p_x_,
|
||||
arg.p_scale_,
|
||||
arg.p_bias_,
|
||||
arg.y_elementwise_op_,
|
||||
arg.p_y_,
|
||||
arg.updateMovingAverage_,
|
||||
arg.averageFactor_,
|
||||
arg.resultRunningMean_,
|
||||
arg.resultRunningVariance_,
|
||||
arg.saveMeanInvVariance_,
|
||||
arg.resultSaveMean_,
|
||||
arg.resultSaveInvVariance_);
|
||||
}
|
||||
else
|
||||
{
|
||||
using GetReduceCountPerThreadFunctor =
|
||||
GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;
|
||||
|
||||
GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
|
||||
arg.numBlockTileIteration_, arg.reduce_length_);
|
||||
|
||||
using GridwiseBatchNormForwardWithBlockwiseWelford_ =
|
||||
GridwiseBatchNormForwardWithBlockwiseWelford<XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
GetReduceCountPerThreadFunctor,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XSrcYDstVectorDim,
|
||||
XSrcVectorSize,
|
||||
YDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
BiasSrcVectorSize,
|
||||
MeanVarSrcDstVectorSize>;
|
||||
|
||||
const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford<
|
||||
GridwiseBatchNormForwardWithBlockwiseWelford_,
|
||||
XDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
YElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
ScaleBiasMeanVarGridDesc_M,
|
||||
GetReduceCountPerThreadFunctor>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kern_batchnorm_fwd,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.x_grid_desc_m_k_,
|
||||
arg.y_grid_desc_m_k_,
|
||||
arg.scale_grid_desc_m_,
|
||||
arg.bias_grid_desc_m_,
|
||||
arg.mean_var_grid_desc_m_,
|
||||
get_reduce_count_per_thread,
|
||||
arg.numBlockTileIteration_,
|
||||
arg.epsilon_,
|
||||
arg.p_x_,
|
||||
arg.p_scale_,
|
||||
arg.p_bias_,
|
||||
arg.y_elementwise_op_,
|
||||
arg.p_y_,
|
||||
arg.updateMovingAverage_, // true or false
|
||||
arg.averageFactor_,
|
||||
arg.resultRunningMean_,
|
||||
arg.resultRunningVariance_,
|
||||
arg.saveMeanInvVariance_, // true or false
|
||||
arg.resultSaveMean_,
|
||||
arg.resultSaveInvVariance_);
|
||||
};
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* pArg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* pArg) override
|
||||
{
|
||||
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
|
||||
|
||||
if constexpr(XSrcYDstVectorDim == 0)
|
||||
{
|
||||
if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
|
||||
pArg_->yStrides_[NumInvariantDim - 1] != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
|
||||
pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0)
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
|
||||
pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0)
|
||||
return false;
|
||||
};
|
||||
|
||||
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
|
||||
return false;
|
||||
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
|
||||
return false;
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0)
|
||||
return false;
|
||||
|
||||
bool is_valid = true;
|
||||
|
||||
static_for<0, NumInvariantDim, 1>{}([&](auto I) {
|
||||
if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
|
||||
is_valid = false;
|
||||
});
|
||||
|
||||
if(!is_valid)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> yStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
|
||||
const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* p_scale,
|
||||
const void* p_bias,
|
||||
double epsilon,
|
||||
const YElementwiseOp y_elementwise_op,
|
||||
void* p_y,
|
||||
void* resultSaveMean,
|
||||
void* resultSaveInvVariance,
|
||||
double averageFactor,
|
||||
void* resultRunningMean,
|
||||
void* resultRunningVariance) override
|
||||
{
|
||||
return std::make_unique<Argument>(xyLengths,
|
||||
xStrides,
|
||||
yStrides,
|
||||
reduceDims,
|
||||
bnScaleBiasMeanVarLengths,
|
||||
bnScaleStrides,
|
||||
bnBiasStrides,
|
||||
bnMeanVarStrides,
|
||||
static_cast<const XDataType*>(p_x),
|
||||
static_cast<const ScaleDataType*>(p_scale),
|
||||
static_cast<const BiasDataType*>(p_bias),
|
||||
y_elementwise_op,
|
||||
epsilon,
|
||||
static_cast<YDataType*>(p_y),
|
||||
static_cast<MeanVarDataType*>(resultSaveMean),
|
||||
static_cast<MeanVarDataType*>(resultSaveInvVariance),
|
||||
averageFactor,
|
||||
static_cast<MeanVarDataType*>(resultRunningMean),
|
||||
static_cast<MeanVarDataType*>(resultRunningVariance));
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchNormFwdImpl<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ",";
|
||||
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,625 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_cgemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
enable_if_t<
|
||||
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
|
||||
bool> = false>
|
||||
struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
: public DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceCGemm_4Gemm_Xdl_CShuffle;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr index_t MPerThread =
|
||||
MPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
static constexpr index_t NPerThread =
|
||||
NPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
static constexpr auto AScalarPerVector = Number<4>{};
|
||||
static constexpr auto BScalarPerVector = Number<4>{};
|
||||
static constexpr auto CScalarPerVector = Number<4>{};
|
||||
|
||||
template <typename Desc_M_N>
|
||||
static auto PadDescriptor_M_N(Desc_M_N desc)
|
||||
{
|
||||
const auto M = desc.GetLength(I0);
|
||||
const auto N = desc.GetLength(I1);
|
||||
const auto pad_M = math::integer_divide_ceil(M, MPerThread) * MPerThread - M;
|
||||
const auto pad_N = math::integer_divide_ceil(N, NPerThread) * NPerThread - N;
|
||||
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_right_pad_transform(M, pad_M), make_right_pad_transform(N, pad_N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return padded_desc;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M_N(const std::vector<index_t>& lengths,
|
||||
const std::vector<index_t>& strides)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<2>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
return PadDescriptor_M_N(desc);
|
||||
}
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using CGridDesc_M_N = decltype(MakeDescriptor_M_N({1, 1}, {1, 1}));
|
||||
|
||||
// Argument
|
||||
struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem
|
||||
{
|
||||
using Problem = typename GridwiseGemm::Problem;
|
||||
|
||||
Argument(const ADataType* p_a_grid_real_,
|
||||
const ADataType* p_a_grid_imag_,
|
||||
const BDataType* p_b_grid_real_,
|
||||
const BDataType* p_b_grid_imag_,
|
||||
CDataType* p_c_grid_real_,
|
||||
CDataType* p_c_grid_imag_,
|
||||
CDataType* p_workspace,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
|
||||
p_a_grid_real{p_a_grid_real_},
|
||||
p_a_grid_imag{p_a_grid_imag_},
|
||||
p_b_grid_real{p_b_grid_real_},
|
||||
p_b_grid_imag{p_b_grid_imag_},
|
||||
p_c_grid_real{p_c_grid_real_},
|
||||
p_c_grid_imag{p_c_grid_imag_},
|
||||
p_aux_grid{p_workspace}
|
||||
{
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {StrideC_, I1});
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {I1, StrideC_});
|
||||
}
|
||||
|
||||
p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_);
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_real;
|
||||
const ADataType* p_a_grid_imag;
|
||||
const BDataType* p_b_grid_real;
|
||||
const BDataType* p_b_grid_imag;
|
||||
CDataType* p_c_grid_real;
|
||||
CDataType* p_c_grid_imag;
|
||||
CDataType* p_aux_grid;
|
||||
CDataType* p_aux_2_grid;
|
||||
CGridDesc_M_N c_grid_desc_m_n;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
|
||||
|
||||
const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
using Subtract = ck::tensor_operation::element_wise::Subtract;
|
||||
|
||||
using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
|
||||
|
||||
using GridwiseBinAdd = GridwiseElementwise<Tuple<CGridDesc_M_N, CGridDesc_M_N>,
|
||||
Tuple<CGridDesc_M_N>,
|
||||
Tuple<const CDataType*, const CDataType*>,
|
||||
Tuple<CDataType*>,
|
||||
Block2TileMap,
|
||||
Add,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerThread,
|
||||
NPerThread,
|
||||
Sequence<0, 1>,
|
||||
Sequence<AScalarPerVector, BScalarPerVector>,
|
||||
Sequence<CScalarPerVector>,
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
using GridwiseBinSubtract =
|
||||
GridwiseElementwise<Tuple<CGridDesc_M_N, CGridDesc_M_N>,
|
||||
Tuple<CGridDesc_M_N>,
|
||||
Tuple<const CDataType*, const CDataType*>,
|
||||
Tuple<CDataType*>,
|
||||
Block2TileMap,
|
||||
Subtract,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerThread,
|
||||
NPerThread,
|
||||
Sequence<0, 1>,
|
||||
Sequence<AScalarPerVector, BScalarPerVector>,
|
||||
Sequence<CScalarPerVector>,
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
const index_t M = arg.c_grid_desc_m_n.GetLength(I0);
|
||||
const index_t N = arg.c_grid_desc_m_n.GetLength(I1);
|
||||
const auto block_2_tile_map = Block2TileMap(M, N);
|
||||
|
||||
const auto add_kernel = kernel_elementwise<GridwiseBinAdd,
|
||||
Tuple<CGridDesc_M_N, CGridDesc_M_N>,
|
||||
Tuple<CGridDesc_M_N>,
|
||||
Tuple<const CDataType*, const CDataType*>,
|
||||
Tuple<CDataType*>,
|
||||
Block2TileMap,
|
||||
Add>;
|
||||
|
||||
const auto subtract_kernel =
|
||||
kernel_elementwise<GridwiseBinSubtract,
|
||||
Tuple<CGridDesc_M_N, CGridDesc_M_N>,
|
||||
Tuple<CGridDesc_M_N>,
|
||||
Tuple<const CDataType*, const CDataType*>,
|
||||
Tuple<CDataType*>,
|
||||
Block2TileMap,
|
||||
Subtract>;
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
true>;
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real,
|
||||
arg.p_b_grid_real,
|
||||
arg.p_aux_grid,
|
||||
arg);
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag,
|
||||
arg.p_b_grid_imag,
|
||||
arg.p_aux_2_grid,
|
||||
arg);
|
||||
|
||||
// c_real = aux - aux_2
|
||||
ave_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
subtract_kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
|
||||
make_tuple(arg.c_grid_desc_m_n),
|
||||
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
|
||||
const_cast<const CDataType*>(arg.p_aux_2_grid)),
|
||||
make_tuple(arg.p_c_grid_real),
|
||||
block_2_tile_map,
|
||||
Subtract{});
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real,
|
||||
arg.p_b_grid_imag,
|
||||
arg.p_aux_grid,
|
||||
arg);
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag,
|
||||
arg.p_b_grid_real,
|
||||
arg.p_aux_2_grid,
|
||||
arg);
|
||||
|
||||
// c_imag = aux + aux_2
|
||||
ave_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
add_kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
|
||||
make_tuple(arg.c_grid_desc_m_n),
|
||||
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
|
||||
const_cast<const CDataType*>(arg.p_aux_2_grid)),
|
||||
make_tuple(arg.p_c_grid_imag),
|
||||
block_2_tile_map,
|
||||
Add{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
false>;
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real,
|
||||
arg.p_b_grid_real,
|
||||
arg.p_aux_grid,
|
||||
arg);
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag,
|
||||
arg.p_b_grid_imag,
|
||||
arg.p_aux_2_grid,
|
||||
arg);
|
||||
|
||||
// c_real = aux - aux_2
|
||||
ave_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
subtract_kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
|
||||
make_tuple(arg.c_grid_desc_m_n),
|
||||
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
|
||||
const_cast<const CDataType*>(arg.p_aux_2_grid)),
|
||||
make_tuple(arg.p_c_grid_real),
|
||||
block_2_tile_map,
|
||||
Subtract{});
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real,
|
||||
arg.p_b_grid_imag,
|
||||
arg.p_aux_grid,
|
||||
arg);
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag,
|
||||
arg.p_b_grid_real,
|
||||
arg.p_aux_2_grid,
|
||||
arg);
|
||||
|
||||
// c_imag = aux + aux_2
|
||||
ave_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
add_kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n),
|
||||
make_tuple(arg.c_grid_desc_m_n),
|
||||
make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
|
||||
const_cast<const CDataType*>(arg.p_aux_2_grid)),
|
||||
make_tuple(arg.p_c_grid_imag),
|
||||
block_2_tile_map,
|
||||
Add{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a_real,
|
||||
const ADataType* p_a_imag,
|
||||
const BDataType* p_b_real,
|
||||
const BDataType* p_b_imag,
|
||||
CDataType* p_c_real,
|
||||
CDataType* p_c_imag,
|
||||
CDataType* p_workspace,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
{
|
||||
return Argument{p_a_real,
|
||||
p_a_imag,
|
||||
p_b_real,
|
||||
p_b_imag,
|
||||
p_c_real,
|
||||
p_c_imag,
|
||||
p_workspace,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
|
||||
const void* p_a_imag,
|
||||
const void* p_b_real,
|
||||
const void* p_b_imag,
|
||||
void* p_c_real,
|
||||
void* p_c_imag,
|
||||
void* p_workspace,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
index_t /* KBatch */ = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a_real),
|
||||
static_cast<const ADataType*>(p_a_imag),
|
||||
static_cast<const BDataType*>(p_b_real),
|
||||
static_cast<const BDataType*>(p_b_imag),
|
||||
static_cast<CDataType*>(p_c_real),
|
||||
static_cast<CDataType*>(p_c_imag),
|
||||
static_cast<CDataType*>(p_workspace),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceCGemm_4Gemm_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC)
|
||||
{
|
||||
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
|
||||
M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC);
|
||||
|
||||
return c_grid_desc_m_n.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSize(index_t M,
|
||||
index_t N,
|
||||
[[maybe_unused]] index_t K,
|
||||
[[maybe_unused]] index_t StrideA,
|
||||
[[maybe_unused]] index_t StrideB,
|
||||
index_t StrideC) const override
|
||||
{
|
||||
return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
|
||||
}
|
||||
|
||||
std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
|
||||
{
|
||||
const auto* parg = dynamic_cast<const Argument*>(base_arg);
|
||||
|
||||
if(!parg)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Provided argument pointer is not of an Argument class!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
return GetWorkspaceSize(
|
||||
parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,647 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Column to Image:
|
||||
// input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C]
|
||||
// output : input image [G, N, Di, Hi, Wi, C]
|
||||
// input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C]
|
||||
// output : input image [N, Di, Hi, Wi, G, C]
|
||||
template <index_t NDimSpatial,
|
||||
typename ImageLayout,
|
||||
typename InputDataType,
|
||||
typename OutputDataType,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t KPerBlock,
|
||||
typename ThreadClusterLengths,
|
||||
index_t ScalarPerVector,
|
||||
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
|
||||
struct DeviceColumnToImageImpl
|
||||
: public DeviceConvTensorRearrange<NDimSpatial,
|
||||
ImageLayout,
|
||||
InputDataType,
|
||||
OutputDataType,
|
||||
conv_tensor_rearrange_op::ColumnToImage>
|
||||
{
|
||||
static constexpr bool is_NSpatialGC =
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::NWGC> ||
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::NHWGC> ||
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::NDHWGC>;
|
||||
static constexpr bool is_GNSpatialC =
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNWC> ||
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNHWC> ||
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNDHWC>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr auto ZIdx = Number<I0>{};
|
||||
static constexpr auto YIdx = NDimSpatial == 1 ? I0 : Number<NDimSpatial - I2>{};
|
||||
static constexpr auto XIdx = Number<NDimSpatial - I1>{};
|
||||
|
||||
static constexpr auto spatial_offset = Number<3>{};
|
||||
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
|
||||
MPerBlock, 0 /* NPerBlock*/, KPerBlock};
|
||||
|
||||
// Calculate number of independent filters for given conv params
|
||||
static index_t GetNumberOfIndependentFilters(const index_t input_spatial_len,
|
||||
const index_t left_pad,
|
||||
const index_t right_pad,
|
||||
const index_t filter_len,
|
||||
const index_t filter_stride,
|
||||
const index_t filter_dilation,
|
||||
const index_t image_offset)
|
||||
{
|
||||
const index_t x_eff = (filter_len - 1) * filter_dilation + 1;
|
||||
const index_t next_filter_padded =
|
||||
math::integer_divide_ceil(x_eff, filter_stride) * filter_stride;
|
||||
// If filter_stride >= x_eff then each filter is independent
|
||||
const index_t independent_filter_stride =
|
||||
filter_stride >= x_eff ? filter_stride : next_filter_padded;
|
||||
const index_t w_eff = input_spatial_len - image_offset + left_pad + right_pad - x_eff;
|
||||
// There are no independent filters
|
||||
if(w_eff < 0)
|
||||
return 0;
|
||||
const index_t independent_kernels_num = w_eff / independent_filter_stride + 1;
|
||||
return independent_kernels_num;
|
||||
}
|
||||
|
||||
// Make column form descriptor
|
||||
static auto
|
||||
MakeInputDescriptor_M_K(const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& independent_filters,
|
||||
const std::array<index_t, NDimSpatial>& effs)
|
||||
{
|
||||
const index_t DoHoWo = ck::accumulate_n<index_t>(
|
||||
output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t CZYX =
|
||||
C * ck::accumulate_n<index_t>(
|
||||
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
const index_t NStride = DoHoWo * gemm_g_m_k_strides[I1] * gemm_g_m_k_strides[I2];
|
||||
// Calculate the appropriate stride for each set of independent filters
|
||||
// in each dimension
|
||||
const index_t WStride = math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) *
|
||||
gemm_g_m_k_strides[I1];
|
||||
const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) *
|
||||
output_spatial_lengths[XIdx] * gemm_g_m_k_strides[I1];
|
||||
const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) *
|
||||
output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] *
|
||||
gemm_g_m_k_strides[I1];
|
||||
// Create descriptor for independent filters in each dimension and
|
||||
// then merge them into column form
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
const auto desc_gemm_form =
|
||||
make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX),
|
||||
make_tuple(NStride, WStride, gemm_g_m_k_strides[I2]));
|
||||
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
|
||||
desc_gemm_form,
|
||||
make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])),
|
||||
make_pass_through_transform(CZYX)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters);
|
||||
return desc_m_k;
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const auto desc_gemm_form = make_naive_tensor_descriptor(
|
||||
make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX),
|
||||
make_tuple(NStride, HStride, WStride, gemm_g_m_k_strides[I2]));
|
||||
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
|
||||
desc_gemm_form,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(N, independent_filters[YIdx], independent_filters[XIdx])),
|
||||
make_pass_through_transform(CZYX)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters);
|
||||
return desc_m_k;
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
const auto desc_gemm_form = make_naive_tensor_descriptor(
|
||||
make_tuple(N,
|
||||
independent_filters[ZIdx],
|
||||
independent_filters[YIdx],
|
||||
independent_filters[XIdx],
|
||||
CZYX),
|
||||
make_tuple(NStride, DStride, HStride, WStride, gemm_g_m_k_strides[I2]));
|
||||
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
|
||||
desc_gemm_form,
|
||||
make_tuple(make_merge_transform(make_tuple(N,
|
||||
independent_filters[ZIdx],
|
||||
independent_filters[YIdx],
|
||||
independent_filters[XIdx])),
|
||||
make_pass_through_transform(CZYX)),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters);
|
||||
return desc_m_k;
|
||||
}
|
||||
}
|
||||
|
||||
// Use MakeADescriptor_M_K from grouped convolution forward
|
||||
static auto
|
||||
MakeOutDescriptor_M_K(const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const std::array<index_t, NDimSpatial>& image_offsets,
|
||||
const std::array<index_t, NDimSpatial>& independent_filters,
|
||||
const std::array<index_t, NDimSpatial>& effs)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{1};
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{1};
|
||||
std::array<index_t, NDimSpatial + 3> c_g_n_k_wos_lengths{1};
|
||||
|
||||
auto copy = [](const auto& x, auto& y, index_t dst_offset) {
|
||||
std::copy(x.begin(), x.end(), y.begin() + dst_offset);
|
||||
};
|
||||
|
||||
copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset);
|
||||
copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset);
|
||||
// Calculate descriptor only for independent filters
|
||||
copy(independent_filters, c_g_n_k_wos_lengths, spatial_offset);
|
||||
|
||||
// fill only significant values (C and N)
|
||||
a_g_n_c_wis_lengths[I1] = N;
|
||||
a_g_n_c_wis_lengths[I2] = C;
|
||||
b_g_k_c_xs_lengths[I2] = C;
|
||||
c_g_n_k_wos_lengths[I1] = N;
|
||||
|
||||
// Modify pads to apply offsets
|
||||
std::array<index_t, NDimSpatial> input_left_pads_with_offset;
|
||||
for(index_t i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
input_left_pads_with_offset[i] = math::max(0, input_left_pads[i] - image_offsets[i]);
|
||||
}
|
||||
// Modify input spatial lengths to apply offsets
|
||||
for(index_t i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
a_g_n_c_wis_lengths[i + spatial_offset] -=
|
||||
math::max(0, image_offsets[i] - input_left_pads[i]);
|
||||
}
|
||||
|
||||
// Strides to next independent filters
|
||||
std::array<index_t, NDimSpatial> independent_filter_strides;
|
||||
for(index_t i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
index_t independent_filter_stride =
|
||||
math::integer_divide_ceil(effs[i], conv_filter_strides[i]) * conv_filter_strides[i];
|
||||
// If conv stride is greater than whole filter size, use conv stride
|
||||
independent_filter_strides[i] = conv_filter_strides[i] >= effs[i]
|
||||
? conv_filter_strides[i]
|
||||
: independent_filter_stride;
|
||||
}
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
{}, // not needed for A Descriptor
|
||||
c_g_n_k_wos_lengths,
|
||||
{}, // not needed for A Descriptor
|
||||
// conv_filter_strides,
|
||||
independent_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads_with_offset,
|
||||
input_right_pads};
|
||||
|
||||
// Calculate image form descriptor for the modified convolution problem
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>();
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
|
||||
using InputGridDesc =
|
||||
remove_cvref_t<decltype(MakeInputDescriptor_M_K(1, 1, {}, {}, {}, {}, {}, {}))>;
|
||||
using OutputGridDesc = remove_cvref_t<decltype(MakeOutDescriptor_M_K(
|
||||
1, 1, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
|
||||
using Block2ETileMap = remove_cvref_t<
|
||||
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
|
||||
InputGridDesc{}))>;
|
||||
|
||||
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
ThreadClusterLengths,
|
||||
ScalarPerVector,
|
||||
InMemoryDataOperationEnum::Add,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<>>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_in, // input image
|
||||
void* p_out, // output image
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
: G_(G),
|
||||
C_(C),
|
||||
X_(filter_spatial_lengths[NDimSpatial - I1]),
|
||||
p_in_{static_cast<const InputDataType*>(p_in)},
|
||||
p_out_{static_cast<OutputDataType*>(p_out)},
|
||||
image_g_n_c_wis_strides_{image_g_n_c_wis_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = gemm_g_m_k_strides[I0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ = image_g_n_c_wis_strides[I0];
|
||||
|
||||
const index_t x_eff =
|
||||
(filter_spatial_lengths[XIdx] - 1) * conv_filter_dilations[XIdx] + 1;
|
||||
const index_t y_eff =
|
||||
NDimSpatial < 2
|
||||
? I1
|
||||
: (filter_spatial_lengths[YIdx] - 1) * conv_filter_dilations[YIdx] + 1;
|
||||
const index_t z_eff =
|
||||
NDimSpatial < 3
|
||||
? I1
|
||||
: (filter_spatial_lengths[ZIdx] - 1) * conv_filter_dilations[ZIdx] + 1;
|
||||
|
||||
// Iterate over sets of independent filters
|
||||
for(int z_img_offset = 0; z_img_offset < z_eff;
|
||||
z_img_offset += conv_filter_strides[ZIdx])
|
||||
{
|
||||
for(int y_img_offset = 0; y_img_offset < y_eff;
|
||||
y_img_offset += conv_filter_strides[YIdx])
|
||||
{
|
||||
for(int x_img_offset = 0; x_img_offset < x_eff;
|
||||
x_img_offset += conv_filter_strides[XIdx])
|
||||
{
|
||||
|
||||
std::array<index_t, NDimSpatial> image_offsets;
|
||||
std::array<index_t, NDimSpatial> effs;
|
||||
// Calculate the starting offset for a given set of
|
||||
// independent filters
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
image_offsets = {x_img_offset};
|
||||
effs = {x_eff};
|
||||
}
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
image_offsets = {y_img_offset, x_img_offset};
|
||||
effs = {y_eff, x_eff};
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
image_offsets = {z_img_offset, y_img_offset, x_img_offset};
|
||||
effs = {z_eff, y_eff, x_eff};
|
||||
}
|
||||
|
||||
std::array<index_t, NDimSpatial> independent_filters;
|
||||
for(index_t i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
independent_filters[i] =
|
||||
GetNumberOfIndependentFilters(input_spatial_lengths[i],
|
||||
input_left_pads[i],
|
||||
input_right_pads[i],
|
||||
filter_spatial_lengths[i],
|
||||
conv_filter_strides[i],
|
||||
conv_filter_dilations[i],
|
||||
image_offsets[i]);
|
||||
}
|
||||
const index_t independent_filters_acum = ck::accumulate_n<index_t>(
|
||||
independent_filters.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
if(independent_filters_acum <= 0)
|
||||
continue;
|
||||
|
||||
const auto in_grid_desc_m_k =
|
||||
MakeInputDescriptor_M_K(N,
|
||||
C,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
gemm_g_m_k_strides,
|
||||
independent_filters,
|
||||
effs);
|
||||
const auto out_grid_desc_m_k =
|
||||
MakeOutDescriptor_M_K(N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
image_offsets,
|
||||
independent_filters,
|
||||
effs);
|
||||
in_grid_desc_m_k_container_.push_back(in_grid_desc_m_k);
|
||||
out_grid_desc_m_k_container_.push_back(out_grid_desc_m_k);
|
||||
|
||||
const index_t x_idx = x_img_offset / conv_filter_strides[XIdx];
|
||||
const index_t y_idx = y_img_offset / conv_filter_strides[YIdx];
|
||||
const index_t z_idx = z_img_offset / conv_filter_strides[ZIdx];
|
||||
|
||||
const index_t x_offset_with_pad =
|
||||
math::max(0, x_img_offset - input_left_pads[XIdx]);
|
||||
const index_t y_offset_with_pad =
|
||||
math::max(0, y_img_offset - input_left_pads[YIdx]);
|
||||
const index_t z_offset_with_pad =
|
||||
math::max(0, z_img_offset - input_left_pads[ZIdx]);
|
||||
|
||||
// Memory offsets to next set of independent filters,
|
||||
// move to independent filters in each dimension
|
||||
const index_t in_offset =
|
||||
(x_idx + y_idx * output_spatial_lengths[XIdx] +
|
||||
z_idx * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx]) *
|
||||
gemm_g_m_k_strides[I1];
|
||||
// Move to independent filters in appropriate dimensions
|
||||
const index_t out_offset =
|
||||
x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] +
|
||||
y_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + YIdx] +
|
||||
z_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + ZIdx];
|
||||
|
||||
const InputDataType* p_in_with_offset =
|
||||
static_cast<const InputDataType*>(p_in) + in_offset;
|
||||
OutputDataType* p_out_with_offset =
|
||||
static_cast<OutputDataType*>(p_out) + out_offset;
|
||||
p_in_container_.push_back(p_in_with_offset);
|
||||
p_out_container_.push_back(p_out_with_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
for(std::size_t i = 0; i < in_grid_desc_m_k_container_.size(); i++)
|
||||
{
|
||||
std::cout << in_grid_desc_m_k_container_[i] << std::endl;
|
||||
std::cout << out_grid_desc_m_k_container_[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
const ck::index_t G_;
|
||||
const ck::index_t C_;
|
||||
const ck::index_t X_;
|
||||
|
||||
const InputDataType* p_in_;
|
||||
OutputDataType* p_out_;
|
||||
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides_;
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations_;
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads_;
|
||||
|
||||
std::vector<InputGridDesc> in_grid_desc_m_k_container_;
|
||||
std::vector<OutputGridDesc> out_grid_desc_m_k_container_;
|
||||
|
||||
std::vector<const InputDataType*> p_in_container_;
|
||||
std::vector<OutputDataType*> p_out_container_;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
float elapsed_time = 0.f;
|
||||
const auto kernel = kernel_tensor_rearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
GridwiseTensorRearrangeKernel>;
|
||||
|
||||
// Execute each set of independent filters
|
||||
for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++)
|
||||
{
|
||||
const auto block_2_tile_map =
|
||||
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
|
||||
arg.out_grid_desc_m_k_container_[i]);
|
||||
const index_t grid_size =
|
||||
block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]) * arg.G_;
|
||||
elapsed_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.in_grid_desc_m_k_container_[i],
|
||||
arg.p_in_container_[i],
|
||||
arg.out_grid_desc_m_k_container_[i],
|
||||
arg.p_out_container_[i],
|
||||
arg.G_,
|
||||
block_2_tile_map,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
}
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
using namespace tensor_layout::convolution;
|
||||
if constexpr(!(is_NSpatialGC || is_GNSpatialC))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1];
|
||||
const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1];
|
||||
const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1];
|
||||
const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1];
|
||||
bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
|
||||
bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1;
|
||||
|
||||
// check vector acces with c not packed
|
||||
if(!is_c_packed && ScalarPerVector != 1)
|
||||
return false;
|
||||
// check vector access of filter window row (only C if C is not packed)
|
||||
if(!is_w_packed && arg.C_ % ScalarPerVector != 0)
|
||||
return false;
|
||||
// check vector access of filter window row (X * C)
|
||||
if(arg.X_ * arg.C_ % ScalarPerVector != 0)
|
||||
return false;
|
||||
// check vector access of pads (w_pad_left/w_pad_right * C)
|
||||
if(w_pad_left * arg.C_ % ScalarPerVector != 0 ||
|
||||
w_pad_right * arg.C_ % ScalarPerVector != 0)
|
||||
return false;
|
||||
// check vector access of with stride and pad
|
||||
if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0)
|
||||
return false;
|
||||
// check vector access of with dilation
|
||||
if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0)
|
||||
return false;
|
||||
|
||||
bool valid = true;
|
||||
for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++)
|
||||
{
|
||||
valid &= GridwiseTensorRearrangeKernel::CheckValidity(
|
||||
arg.in_grid_desc_m_k_container_[i], arg.out_grid_desc_m_k_container_[i]);
|
||||
}
|
||||
return valid;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_in, // input image
|
||||
void* p_out, // output image
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
{
|
||||
return Argument{static_cast<const InputDataType*>(p_in),
|
||||
static_cast<OutputDataType*>(p_out),
|
||||
G,
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
gemm_g_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in, // input image
|
||||
void* p_out, // output image
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
|
||||
static_cast<OutputDataType*>(p_out),
|
||||
G,
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
gemm_g_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceColumnToImage"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< ScalarPerVector
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,847 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename AsPointer,
|
||||
typename BsPointer,
|
||||
typename DsPointer,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AsGridDesc_AK0_M_AK1,
|
||||
typename BsGridDesc_BK0_N_BK1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_contraction_multiple_abd_xdl_cshuffle(
|
||||
AsPointer p_as_grid,
|
||||
BsPointer p_bs_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
|
||||
const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
#else
|
||||
ignore = p_as_grid;
|
||||
ignore = p_bs_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = as_grid_desc_ak0_m_ak1;
|
||||
ignore = bs_grid_desc_bk0_n_bk1;
|
||||
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = block_2_etile_map;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// GEMM:
|
||||
// input : A[M, K]
|
||||
// input : B[N, K]
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
: public DeviceContractionMultipleABD<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle;
|
||||
|
||||
static constexpr index_t NumATensor = AsDataType::Size();
|
||||
static constexpr index_t NumBTensor = BsDataType::Size();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using ComputeDataType = EDataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle<
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_ms_ks_lengths_,
|
||||
const std::vector<index_t>& a_ms_ks_strides_)
|
||||
{
|
||||
assert(a_ms_ks_lengths_.size() == NumDimM + NumDimK &&
|
||||
a_ms_ks_strides_.size() == NumDimM + NumDimK);
|
||||
|
||||
const auto to_tuple = [&](auto& vec, auto num) {
|
||||
return generate_tuple([&](auto i) { return vec[i]; }, num);
|
||||
};
|
||||
|
||||
const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_, Number<NumDimM + NumDimK>{});
|
||||
const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_, Number<NumDimM + NumDimK>{});
|
||||
|
||||
// dimension Ids for M0, M1, ...
|
||||
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
|
||||
|
||||
// dimension Ids for K0, K1, ...
|
||||
constexpr auto kDimIds =
|
||||
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
|
||||
|
||||
// lengths for M0, M1, ...
|
||||
const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
|
||||
|
||||
// lengths for K0, K1, ...
|
||||
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
|
||||
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
|
||||
const auto a_grid_desc_ms_ks =
|
||||
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
|
||||
|
||||
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
|
||||
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
|
||||
a_grid_desc_ms_ks,
|
||||
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
|
||||
make_tuple(mDimIds, kDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeAsGridDescriptor_M_K(const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_lengths,
|
||||
const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_strides)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeAGridDescriptor_M_K(as_ms_ks_lengths[i], as_ms_ks_strides[i]);
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
}
|
||||
|
||||
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_ns_ks_lengths_,
|
||||
const std::vector<index_t>& b_ns_ks_strides_)
|
||||
{
|
||||
assert(b_ns_ks_lengths_.size() == NumDimN + NumDimK &&
|
||||
b_ns_ks_strides_.size() == NumDimN + NumDimK);
|
||||
|
||||
const auto to_tuple = [&](auto& vec, auto num) {
|
||||
return generate_tuple([&](auto i) { return vec[i]; }, num);
|
||||
};
|
||||
|
||||
const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_, Number<NumDimN + NumDimK>{});
|
||||
const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_, Number<NumDimN + NumDimK>{});
|
||||
|
||||
// dimension Ids for N0, N1, ...
|
||||
constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
|
||||
|
||||
// dimension Ids for K0, K1, ...
|
||||
constexpr auto kDimIds =
|
||||
typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};
|
||||
|
||||
// lengths for K0, K1, ...
|
||||
const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
|
||||
|
||||
// lengths for N0, N1, ...
|
||||
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
|
||||
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
const auto b_grid_desc_ns_ks =
|
||||
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
|
||||
|
||||
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
|
||||
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
|
||||
b_grid_desc_ns_ks,
|
||||
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
|
||||
make_tuple(nDimIds, kDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeBsGridDescriptor_N_K(const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_lengths,
|
||||
const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_strides)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeBGridDescriptor_N_K(bs_ns_ks_lengths[i], bs_ns_ks_strides[i]);
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
}
|
||||
|
||||
// assume E[M0, M1, M2, ..., N0, N1, N2...]
|
||||
static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_ms_ns_lengths_,
|
||||
const std::vector<index_t>& e_ms_ns_strides_)
|
||||
{
|
||||
assert(e_ms_ns_lengths_.size() == NumDimM + NumDimN &&
|
||||
e_ms_ns_strides_.size() == NumDimM + NumDimN);
|
||||
|
||||
const auto to_tuple = [&](auto& vec, auto num) {
|
||||
return generate_tuple([&](auto i) { return vec[i]; }, num);
|
||||
};
|
||||
|
||||
const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_, Number<NumDimM + NumDimN>{});
|
||||
const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_, Number<NumDimM + NumDimN>{});
|
||||
|
||||
// dimension Ids for M0, M1, ...
|
||||
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
|
||||
|
||||
// dimension Ids for N0, N1, ...
|
||||
constexpr auto nDimIds =
|
||||
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
|
||||
|
||||
// lengths for M0, M1, ...
|
||||
const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
|
||||
|
||||
// lengths for K0, K1, ...
|
||||
const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
|
||||
|
||||
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
|
||||
const auto e_grid_desc_ms_ns =
|
||||
make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
|
||||
const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
|
||||
e_grid_desc_ms_ns,
|
||||
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
|
||||
make_tuple(mDimIds, nDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeDsGridDescriptor_M_N(const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
using AsGridDesc_M_K = remove_cvref_t<decltype(MakeAsGridDescriptor_M_K({}, {}))>;
|
||||
using BsGridDesc_N_K = remove_cvref_t<decltype(MakeBsGridDescriptor_N_K({}, {}))>;
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N({}, {}))>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AsGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(
|
||||
AsGridDesc_M_K{}))>;
|
||||
using BsGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(
|
||||
BsGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(std::array<const void*, NumATensor> p_as_grid,
|
||||
std::array<const void*, NumBTensor> p_bs_grid,
|
||||
std::array<const void*, NumDTensor> p_ds_grid,
|
||||
void* p_e_grid,
|
||||
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
|
||||
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
|
||||
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
|
||||
const std::vector<index_t>& e_ms_ns_length,
|
||||
const std::vector<index_t>& e_ms_ns_stride,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: p_as_grid_{},
|
||||
p_bs_grid_{},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
|
||||
as_grid_desc_m_k_{},
|
||||
bs_grid_desc_n_k_{},
|
||||
ds_grid_desc_m_n_{},
|
||||
e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)},
|
||||
as_grid_desc_ak0_m_ak1_{},
|
||||
bs_grid_desc_bk0_n_bk1_{},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
// populate pointer, desc for As
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
// using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
|
||||
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
|
||||
// A pointer
|
||||
p_as_grid_(i) = static_cast<const ADataType*>(p_as_grid[i]);
|
||||
|
||||
// A desc
|
||||
as_grid_desc_m_k_(i) =
|
||||
MakeAGridDescriptor_M_K(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
|
||||
});
|
||||
|
||||
// populate pointer, desc for Bs
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
// using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
|
||||
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
|
||||
// B pointer
|
||||
p_bs_grid_(i) = static_cast<const BDataType*>(p_bs_grid[i]);
|
||||
|
||||
// B desc
|
||||
bs_grid_desc_n_k_(i) =
|
||||
MakeBGridDescriptor_N_K(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
|
||||
});
|
||||
|
||||
// populate pointer, desc for Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
// using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
// D pointer
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
|
||||
|
||||
// D desc
|
||||
ds_grid_desc_m_n_(i) =
|
||||
MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
|
||||
});
|
||||
|
||||
// populate desc for Ds/E
|
||||
if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_,
|
||||
bs_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
{
|
||||
as_grid_desc_ak0_m_ak1_ =
|
||||
GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
|
||||
|
||||
bs_grid_desc_bk0_n_bk1_ =
|
||||
GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_);
|
||||
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
}
|
||||
|
||||
// for sanity check of vector memory access
|
||||
for(index_t i = 0; i < NumATensor; ++i)
|
||||
{
|
||||
tie(as_continous_dim_[i], as_max_read_elems_[i]) =
|
||||
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
|
||||
}
|
||||
|
||||
for(index_t i = 0; i < NumBTensor; ++i)
|
||||
{
|
||||
tie(bs_continous_dim_[i], bs_max_read_elems_[i]) =
|
||||
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
|
||||
}
|
||||
|
||||
for(index_t i = 0; i < NumDTensor; ++i)
|
||||
{
|
||||
tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
|
||||
CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
|
||||
}
|
||||
|
||||
tie(e_continous_dim_, e_max_write_elems_) =
|
||||
CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
|
||||
}
|
||||
|
||||
// pointers
|
||||
typename GridwiseGemm::AsGridPointer p_as_grid_;
|
||||
typename GridwiseGemm::BsGridPointer p_bs_grid_;
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
AsGridDesc_M_K as_grid_desc_m_k_;
|
||||
BsGridDesc_N_K bs_grid_desc_n_k_;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_;
|
||||
BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_;
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// Describe whether the last part of a given dimension of A/B/D/E is continues dim.
|
||||
std::array<index_t, NumATensor> as_continous_dim_;
|
||||
std::array<index_t, NumATensor> bs_continous_dim_;
|
||||
std::array<index_t, NumBTensor> ds_continous_dim_;
|
||||
index_t e_continous_dim_;
|
||||
|
||||
std::array<index_t, NumATensor> as_max_read_elems_;
|
||||
std::array<index_t, NumBTensor> bs_max_read_elems_;
|
||||
std::array<index_t, NumDTensor> ds_max_read_elems_;
|
||||
index_t e_max_write_elems_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
|
||||
arg.bs_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel = kernel_contraction_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
typename GridwiseGemm::AsGridPointer,
|
||||
typename GridwiseGemm::BsGridPointer,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AsGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BsGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::Block2ETileMap,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_as_grid_,
|
||||
arg.p_bs_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.as_grid_desc_ak0_m_ak1_,
|
||||
arg.bs_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_);
|
||||
};
|
||||
|
||||
const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1);
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load/store
|
||||
{
|
||||
bool valid_as_access = true;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
const bool valid_a_vector_size =
|
||||
arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_a_access_dim_m =
|
||||
ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0;
|
||||
const bool valid_a_access_dim_k =
|
||||
ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1;
|
||||
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
|
||||
if(!((valid_a_vector_size && valid_a_access_dim) ||
|
||||
ABlockTransferSrcScalarPerVector == 1))
|
||||
{
|
||||
valid_as_access = false;
|
||||
}
|
||||
});
|
||||
if(!valid_as_access)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool valid_bs_access = true;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
const bool valid_b_vector_size =
|
||||
arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_b_access_dim_n =
|
||||
BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0;
|
||||
const bool valid_b_access_dim_k =
|
||||
BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1;
|
||||
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
|
||||
if(!((valid_b_vector_size && valid_b_access_dim) ||
|
||||
BBlockTransferSrcScalarPerVector == 1))
|
||||
{
|
||||
valid_bs_access = false;
|
||||
}
|
||||
});
|
||||
if(!valid_bs_access)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool valid_ds_access = true;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
const bool valid_d_vector_size =
|
||||
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector read of Ds is always on N dimension.
|
||||
const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1;
|
||||
if(!((valid_d_vector_size && valid_d_access_dim) ||
|
||||
CDEBlockTransferScalarPerVector_NPerBlock == 1))
|
||||
{
|
||||
valid_ds_access = false;
|
||||
}
|
||||
});
|
||||
if(!valid_ds_access)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool valid_e_vector_size =
|
||||
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector write of E is always on N dimension.
|
||||
const bool valid_e_access_dim = arg.e_continous_dim_ == 1;
|
||||
if(!((valid_e_vector_size && valid_e_access_dim) ||
|
||||
CDEBlockTransferScalarPerVector_NPerBlock == 1))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
|
||||
arg.bs_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(std::array<const void*, NumATensor> p_as,
|
||||
std::array<const void*, NumBTensor> p_bs,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
|
||||
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
|
||||
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
|
||||
const std::vector<index_t>& e_ms_ns_length,
|
||||
const std::vector<index_t>& e_ms_ns_stride,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
d_ms_ns_lengths,
|
||||
d_ms_ns_strides,
|
||||
e_ms_ns_length,
|
||||
e_ms_ns_stride,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
|
||||
std::array<const void*, NumBTensor> p_bs,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_lengths,
|
||||
const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_lengths,
|
||||
const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
|
||||
const std::vector<index_t>& e_ms_ns_length,
|
||||
const std::vector<index_t>& e_ms_ns_stride,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
as_ms_ks_lengths,
|
||||
as_ms_ks_strides,
|
||||
bs_ns_ks_lengths,
|
||||
bs_ns_ks_strides,
|
||||
ds_ms_ns_lengths,
|
||||
ds_ms_ns_strides,
|
||||
e_ms_ns_length,
|
||||
e_ms_ns_stride,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<LoopScheduler, std::string> LoopSchedToString{
|
||||
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
|
||||
{PipelineVersion::v2, "v2"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceContractionMultipleABD_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< getGemmSpecializationString(GemmSpec)
|
||||
<< ">"
|
||||
<< " LoopScheduler: "
|
||||
<< LoopSchedToString[LoopSched] << ", "
|
||||
<< "PipelineVersion: "
|
||||
<< PipelineVersionToString[PipelineVer];
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,780 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatDsPointer,
|
||||
typename FloatE,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_contraction_multiple_d_xdl_cshuffle(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatDsPointer p_ds_grid,
|
||||
FloatE* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = block_2_etile_map;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Tensor Contraction:
|
||||
// input : A
|
||||
// input : B
|
||||
// input : D0, D1, ...
|
||||
// output : E
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// A[M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
// B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
// D[M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
// E[M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
template <index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename ComputeDataType = ADataType,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
: public DeviceContractionMultipleD<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
ComputeDataType>
|
||||
{
|
||||
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
// Assume: A[M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_ms_ks_lengths_vec,
|
||||
const std::vector<index_t>& a_ms_ks_strides_vec)
|
||||
{
|
||||
assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK &&
|
||||
a_ms_ks_strides_vec.size() == NumDimM + NumDimK);
|
||||
|
||||
const auto to_tuple = [&](auto& vec, auto num) {
|
||||
return generate_tuple([&](auto i) { return vec[i]; }, num);
|
||||
};
|
||||
|
||||
const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
|
||||
const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});
|
||||
|
||||
// dimension Ids for M0, M1, ...
|
||||
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
|
||||
|
||||
// dimension Ids for K0, K1, ...
|
||||
constexpr auto kDimIds =
|
||||
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
|
||||
|
||||
// lengths for M0, M1, ...
|
||||
const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
|
||||
|
||||
// lengths for K0, K1, ...
|
||||
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
|
||||
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
|
||||
const auto a_grid_desc_ms_ks =
|
||||
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
|
||||
|
||||
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
|
||||
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
|
||||
a_grid_desc_ms_ks,
|
||||
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
|
||||
make_tuple(mDimIds, kDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
|
||||
// Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_ns_ks_lengths_vec,
|
||||
const std::vector<index_t>& b_ns_ks_strides_vec)
|
||||
{
|
||||
assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK &&
|
||||
b_ns_ks_strides_vec.size() == NumDimN + NumDimK);
|
||||
|
||||
const auto to_tuple = [&](auto& vec, auto num) {
|
||||
return generate_tuple([&](auto i) { return vec[i]; }, num);
|
||||
};
|
||||
|
||||
const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number<NumDimN + NumDimK>{});
|
||||
const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number<NumDimN + NumDimK>{});
|
||||
|
||||
// dimension Ids for N0, N1, ...
|
||||
constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
|
||||
|
||||
// dimension Ids for K0, K1, ...
|
||||
constexpr auto kDimIds =
|
||||
typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};
|
||||
|
||||
// lengths for K0, K1, ...
|
||||
const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
|
||||
|
||||
// lengths for N0, N1, ...
|
||||
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
|
||||
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
const auto b_grid_desc_ns_ks =
|
||||
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
|
||||
|
||||
// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
|
||||
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
|
||||
b_grid_desc_ns_ks,
|
||||
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
|
||||
make_tuple(nDimIds, kDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
|
||||
// assume E[M0, M1, M2, ..., N0, N1, N2...]
|
||||
static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_ms_ns_lengths_vec,
|
||||
const std::vector<index_t>& e_ms_ns_strides_vec)
|
||||
{
|
||||
assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN &&
|
||||
e_ms_ns_strides_vec.size() == NumDimM + NumDimN);
|
||||
|
||||
const auto to_tuple = [&](auto& vec, auto num) {
|
||||
return generate_tuple([&](auto i) { return vec[i]; }, num);
|
||||
};
|
||||
|
||||
const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number<NumDimM + NumDimN>{});
|
||||
const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number<NumDimM + NumDimN>{});
|
||||
|
||||
// dimension Ids for M0, M1, ...
|
||||
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
|
||||
|
||||
// dimension Ids for N0, N1, ...
|
||||
constexpr auto nDimIds =
|
||||
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
|
||||
|
||||
// lengths for M0, M1, ...
|
||||
const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
|
||||
|
||||
// lengths for K0, K1, ...
|
||||
const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
|
||||
|
||||
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
|
||||
const auto e_grid_desc_ms_ns =
|
||||
make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
// transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
|
||||
const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
|
||||
e_grid_desc_ms_ns,
|
||||
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
|
||||
make_tuple(mDimIds, nDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
static auto MakeDsGridDescriptor_M_N(
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths_vec,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides_vec)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths_vec[i],
|
||||
ds_ms_ns_strides_vec[i]);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
|
||||
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_a_grid,
|
||||
const void* p_b_grid,
|
||||
std::array<const void*, NumDTensor> p_ds_grid,
|
||||
void* p_e_grid,
|
||||
const std::vector<index_t>& a_ms_ks_lengths,
|
||||
const std::vector<index_t>& a_ms_ks_strides,
|
||||
const std::vector<index_t>& b_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_ns_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
|
||||
const std::vector<index_t>& e_ms_ns_lengths,
|
||||
const std::vector<index_t>& e_ms_ns_strides,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
|
||||
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ks_lengths, a_ms_ks_strides)},
|
||||
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)},
|
||||
ds_grid_desc_m_n_{},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
|
||||
b_grid_desc_bk0_n_bk1_{
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
// populate pointer, batch stride, desc for Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
// D pointer
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
|
||||
|
||||
// D desc
|
||||
ds_grid_desc_m_n_(i) =
|
||||
DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
|
||||
});
|
||||
|
||||
// populate desc for Ds/E
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_);
|
||||
}
|
||||
|
||||
// for sanity check of vector memory access
|
||||
tie(a_continous_dim_, a_max_read_elems_) =
|
||||
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
|
||||
|
||||
tie(b_continous_dim_, b_max_read_elems_) =
|
||||
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
|
||||
|
||||
for(index_t i = 0; i < NumDTensor; ++i)
|
||||
{
|
||||
tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
|
||||
CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
|
||||
}
|
||||
|
||||
tie(e_continous_dim_, e_max_write_elems_) =
|
||||
CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
|
||||
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
|
||||
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
|
||||
}
|
||||
|
||||
// private:
|
||||
// pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// Describe whether the last part of a given dimension of A/B/D/E is continues dim.
|
||||
index_t a_continous_dim_;
|
||||
index_t b_continous_dim_;
|
||||
std::array<index_t, NumDTensor> ds_continous_dim_;
|
||||
index_t e_continous_dim_;
|
||||
|
||||
index_t a_max_read_elems_;
|
||||
index_t b_max_read_elems_;
|
||||
std::array<index_t, NumDTensor> ds_max_read_elems_;
|
||||
index_t e_max_write_elems_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::Block2ETileMap,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_);
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access
|
||||
static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
|
||||
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
|
||||
"wrong!");
|
||||
|
||||
const bool valid_a_vector_size =
|
||||
arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_a_access_dim_m =
|
||||
ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0;
|
||||
const bool valid_a_access_dim_k =
|
||||
ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1;
|
||||
const bool valid_a_access_dim =
|
||||
valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
|
||||
if(!(valid_a_vector_size && valid_a_access_dim))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool valid_b_vector_size =
|
||||
arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
|
||||
const bool valid_b_access_dim_n =
|
||||
BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0;
|
||||
const bool valid_b_access_dim_k =
|
||||
BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1;
|
||||
const bool valid_b_access_dim =
|
||||
valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
|
||||
if(!(valid_b_vector_size && valid_b_access_dim))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool valid_ds_access = true;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
const bool valid_d_vector_size =
|
||||
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector read of Ds is always on N dimension.
|
||||
const bool valid_d_access_dim =
|
||||
arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
|
||||
if(!(valid_d_vector_size && valid_d_access_dim))
|
||||
{
|
||||
valid_ds_access = false;
|
||||
}
|
||||
});
|
||||
if(!valid_ds_access)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool valid_e_vector_size =
|
||||
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
|
||||
// Vector write of E is always on N dimension.
|
||||
const bool valid_e_access_dim =
|
||||
arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
|
||||
if(!(valid_e_vector_size && valid_e_access_dim))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const std::vector<index_t>& a_ms_ks_lengths,
|
||||
const std::vector<index_t>& a_ms_ks_strides,
|
||||
const std::vector<index_t>& b_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_ns_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
|
||||
const std::vector<index_t>& e_ms_ns_lengths,
|
||||
const std::vector<index_t>& e_ms_ns_strides,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
ds_ms_ns_lengths,
|
||||
ds_ms_ns_strides,
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const std::vector<index_t>& a_ms_ks_lengths,
|
||||
const std::vector<index_t>& a_ms_ks_strides,
|
||||
const std::vector<index_t>& b_ns_ks_lengths,
|
||||
const std::vector<index_t>& b_ns_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
|
||||
const std::vector<index_t>& e_ms_ns_lengths,
|
||||
const std::vector<index_t>& e_ms_ns_strides,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
ds_ms_ns_lengths,
|
||||
ds_ms_ns_strides,
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceContractionMultipleD_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< NumDimM << ", "
|
||||
<< NumDimN << ", "
|
||||
<< NumDimK << ", "
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< ABlockTransferSrcVectorDim << ", "
|
||||
<< BBlockTransferSrcVectorDim
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,117 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/**
|
||||
* Calculates the maximum number of subsequent elements of the fast changing dimension
|
||||
* that are consecutive in memory.
|
||||
*
|
||||
* Example:
|
||||
* NumDimM = 2, NumDimK = 3
|
||||
* A shape = [ 2, 3, 4, 5, 6]
|
||||
* A strides = [360, 120, 30, 6, 1]
|
||||
* | M | | K |
|
||||
* It follows from strides that K is FCD and all the subsequent elements of K are consecutive
|
||||
* in memory.
|
||||
* But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be
|
||||
* consecutive in memory.
|
||||
*
|
||||
* Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions.
|
||||
*/
|
||||
template <index_t NumDim1, index_t NumDim2>
|
||||
auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<index_t>& strides)
|
||||
{
|
||||
if(lengths.size() != NumDim1 + NumDim2)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Incorrect number of lengths in "
|
||||
<< "device_contraction_utils.hpp"
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
if(strides.size() != NumDim1 + NumDim2)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Incorrect number of strides in "
|
||||
<< "device_contraction_utils.hpp"
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
// Determine the beginning and end idx of the group representing the FCD.
|
||||
index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
|
||||
if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
|
||||
{
|
||||
// MZ or KZ are ones
|
||||
bool dims1_are_ones = true;
|
||||
for(index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
|
||||
{
|
||||
if(lengths[dim_idx] != 1)
|
||||
{
|
||||
dims1_are_ones = false;
|
||||
}
|
||||
}
|
||||
|
||||
if(dims1_are_ones)
|
||||
{
|
||||
begin_idx = NumDim1;
|
||||
end_idx = NumDim1 + NumDim2 - 1;
|
||||
continous_dim = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
begin_idx = 0;
|
||||
end_idx = NumDim1 - 1;
|
||||
continous_dim = 0;
|
||||
}
|
||||
}
|
||||
else if(strides[NumDim1 - 1] == 1)
|
||||
{
|
||||
begin_idx = 0;
|
||||
end_idx = NumDim1 - 1;
|
||||
continous_dim = 0;
|
||||
}
|
||||
else if(strides[NumDim1 + NumDim2 - 1] == 1)
|
||||
{
|
||||
begin_idx = NumDim1;
|
||||
end_idx = NumDim1 + NumDim2 - 1;
|
||||
continous_dim = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// The dimension consecutive in memory is not the last dimension of any group, so only
|
||||
// one element can be read/written at once.
|
||||
consecutive_stride = 1;
|
||||
continous_dim = 0;
|
||||
return make_tuple(continous_dim, consecutive_stride);
|
||||
}
|
||||
|
||||
for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
|
||||
{
|
||||
if(strides[dim_idx] == consecutive_stride)
|
||||
{
|
||||
consecutive_stride *= lengths[dim_idx];
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
const index_t max_subsequent_elems = consecutive_stride;
|
||||
return make_tuple(continous_dim, max_subsequent_elems);
|
||||
}
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,808 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
|
||||
template <typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXdl,
|
||||
ck::index_t NPerXdl,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
|
||||
struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
: public DeviceConvBwdWeight<2,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
static constexpr ck::index_t NDimSpatial = 2;
|
||||
|
||||
using DeviceOp =
|
||||
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
|
||||
|
||||
using ADataType = OutDataType;
|
||||
using BDataType = InDataType;
|
||||
using CDataType = WeiDataType;
|
||||
|
||||
using AElementwiseOperation = OutElementwiseOperation;
|
||||
using BElementwiseOperation = InElementwiseOperation;
|
||||
using CElementwiseOperation = WeiElementwiseOperation;
|
||||
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
static constexpr auto N1Number = K1Number;
|
||||
|
||||
// Bytes per 32 lds bank: 32 * 4 bytes
|
||||
static constexpr auto BankLength = 128;
|
||||
static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
|
||||
|
||||
// M1 & M0
|
||||
static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
|
||||
static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
|
||||
static constexpr auto ABlockLdsM1Padding = 4;
|
||||
|
||||
// N1 & N0
|
||||
static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
|
||||
static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
|
||||
static constexpr auto BBlockLdsN1Padding = 4;
|
||||
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t GemmKTotal = N * Ho * Wo;
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X * Y;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
// A: output tensor
|
||||
const index_t N0 = N / N1Number;
|
||||
const index_t GemmK0Total = N0 * Ho * Wo;
|
||||
|
||||
const index_t GemmK0S =
|
||||
math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) * K0PerBlock;
|
||||
const index_t GemmK0Pad = GemmKBatch * GemmK0S;
|
||||
const auto out_n_ho_wo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K));
|
||||
|
||||
const auto out_n0_ho_wo_k_n1_grid_desc =
|
||||
transform_tensor_descriptor(out_n_ho_wo_k_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)),
|
||||
make_pass_through_transform(Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto out_gemmk0total_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(out_n0_ho_wo_k_n1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N0, Ho * Wo)),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(N1Number)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk0total_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
|
||||
make_pass_through_transform(GemmM),
|
||||
make_pass_through_transform(N1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk0pad_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
|
||||
make_pass_through_transform(GemmM),
|
||||
make_pass_through_transform(N1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_n0_y_ho_x_wo_c_n1_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)),
|
||||
make_pass_through_transform(Y),
|
||||
make_pass_through_transform(Ho),
|
||||
make_pass_through_transform(X),
|
||||
make_pass_through_transform(Wo),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0, 6>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk0total_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n0_y_ho_x_wo_c_n1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N0, Ho, Wo)),
|
||||
make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_pass_through_transform(N1Number)),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0total_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
|
||||
make_pass_through_transform(GemmN),
|
||||
make_pass_through_transform(N1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0pad_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
|
||||
make_pass_through_transform(GemmN),
|
||||
make_pass_through_transform(N1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1));
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
ABlockLdsM1PerBlock,
|
||||
ABlockLdsM0PerBlock,
|
||||
ABlockLdsM1Padding,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
BBlockLdsN1PerBlock,
|
||||
BBlockLdsN0PerBlock,
|
||||
BBlockLdsN1Padding,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
true>;
|
||||
|
||||
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
ABlockLdsM1PerBlock,
|
||||
ABlockLdsM0PerBlock,
|
||||
ABlockLdsM1Padding,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
BBlockLdsN1PerBlock,
|
||||
BBlockLdsN0PerBlock,
|
||||
BBlockLdsN1Padding,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
true>;
|
||||
// Argument
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
|
||||
|
||||
using Block2CTileMap =
|
||||
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
ck::index_t split_k)
|
||||
: p_a_grid_{p_out_grid},
|
||||
p_b_grid_{p_in_grid},
|
||||
p_c_grid_{p_wei_grid},
|
||||
a_grid_desc_kbatch_k0_m_k1_{},
|
||||
b_grid_desc_kbatch_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
a_element_op_{out_element_op},
|
||||
b_element_op_{in_element_op},
|
||||
c_element_op_{wei_element_op},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
output_spatial_lengths_{output_spatial_lengths},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_);
|
||||
|
||||
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
|
||||
b_grid_desc_kbatch_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
InElementwiseOperation a_element_op_;
|
||||
OutElementwiseOperation b_element_op_;
|
||||
WeiElementwiseOperation c_element_op_;
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::array<index_t, NDimSpatial> output_spatial_lengths_;
|
||||
std::array<index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
void Print(const Argument& arg)
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
Print(arg);
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight has invalid setting");
|
||||
}
|
||||
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemmAtomicAdd,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemmAtomicAdd,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector load A/B matrix from global memory
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
|
||||
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// unmerge N to N0 and N1, where N1 equals to K1
|
||||
if(!(arg.Conv_N_ % K1 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
ck::index_t split_k)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
split_k};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
ck::index_t split_k) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<WeiDataType*>(p_wei_grid),
|
||||
static_cast<const OutDataType*>(p_out_grid),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
split_k);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< ABlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,767 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
|
||||
template <typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXdl,
|
||||
ck::index_t NPerXdl,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector>
|
||||
struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
: public DeviceConvBwdData<2,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
|
||||
|
||||
using ADataType = OutDataType;
|
||||
using BDataType = WeiDataType;
|
||||
using CDataType = InDataType;
|
||||
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
static constexpr index_t NDimSpatial = 2;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) %
|
||||
ABlockTransferSrcScalarPerVector ==
|
||||
0);
|
||||
static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) %
|
||||
BBlockTransferSrcScalarPerVector ==
|
||||
0);
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
index_t i_ytilde,
|
||||
index_t i_xtilde)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const auto K0 = K / K1;
|
||||
|
||||
const auto out_n_ho_wo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
|
||||
const auto wei_k_y_x_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
if constexpr(ConvBackwardDataSpecialization ==
|
||||
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: input tensor
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_freeze_transform(I0),
|
||||
make_freeze_transform(I0),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilde);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilde);
|
||||
|
||||
const auto HTilde =
|
||||
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilde =
|
||||
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
|
||||
const auto IHTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
|
||||
const auto IWTildeSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
|
||||
|
||||
const auto IHTildeSliceEnd = math::min(
|
||||
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildeSliceEnd = math::min(
|
||||
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
|
||||
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
// A: output tensor
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ho_wo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, I0),
|
||||
make_pad_transform(Wo, I0, I0),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_hop_wop_k_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YDot, HTilde),
|
||||
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, WTilde),
|
||||
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5, 6>{}));
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B weight tensor
|
||||
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_y_x_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_embed_transform(make_tuple(YDot, YTilde),
|
||||
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, XTilde),
|
||||
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
|
||||
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0, 1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<>{},
|
||||
Sequence<>{},
|
||||
Sequence<4>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_pass_through_transform(C),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// C: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YTilde, HTilde),
|
||||
make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(XTilde, WTilde),
|
||||
make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(i_ytilde),
|
||||
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
|
||||
make_freeze_transform(i_xtilde),
|
||||
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<>{},
|
||||
Sequence<1>{},
|
||||
Sequence<>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_htildeslice_wtildeslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // function end
|
||||
|
||||
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 0, 0));
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
|
||||
BlockSize,
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
|
||||
7, // CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
: p_a_grid_{p_out_grid},
|
||||
p_b_grid_{p_wei_grid},
|
||||
p_c_grid_{p_in_grid},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
input_spatial_lengths_{input_spatial_lengths},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
output_spatial_lengths_{output_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
|
||||
{
|
||||
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
|
||||
{
|
||||
// check slice is valid
|
||||
const index_t Y = filter_spatial_lengths_[0];
|
||||
const index_t X = filter_spatial_lengths_[1];
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
if(YDotSlice * XDotSlice <= 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
i_ytilde,
|
||||
i_xtilde);
|
||||
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
|
||||
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
|
||||
c_grid_desc_m_n_container_.push_back(descs[I2]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
|
||||
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
|
||||
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
|
||||
std::vector<ck::index_t> input_spatial_lengths_;
|
||||
std::vector<ck::index_t> filter_spatial_lengths_;
|
||||
std::vector<ck::index_t> output_spatial_lengths_;
|
||||
std::vector<ck::index_t> conv_filter_strides_;
|
||||
std::vector<ck::index_t> conv_filter_dilations_;
|
||||
std::vector<ck::index_t> input_left_pads_;
|
||||
std::vector<ck::index_t> input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
|
||||
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_container_{ "
|
||||
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
|
||||
}
|
||||
|
||||
const auto [gdx, gdy, gdz] =
|
||||
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
|
||||
|
||||
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
|
||||
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
DeviceOp::AGridDesc_K0_M_K1,
|
||||
DeviceOp::BGridDesc_K0_N_K1,
|
||||
DeviceOp::CGridDesc_M_N,
|
||||
true>;
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
DeviceOp::AGridDesc_K0_M_K1,
|
||||
DeviceOp::BGridDesc_K0_N_K1,
|
||||
DeviceOp::CGridDesc_M_N,
|
||||
false>;
|
||||
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]);
|
||||
}
|
||||
}
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(ConvBackwardDataSpecialization ==
|
||||
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 pad = 0 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// vector load A/B matrix from global memory
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 &&
|
||||
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
arg.c_grid_desc_m_n_container_[i]))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(void* p_in_grid,
|
||||
const void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
|
||||
static_cast<const WeiDataType*>(p_wei_grid),
|
||||
static_cast<const OutDataType*>(p_out_grid),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< ABlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferDstScalarPerVector_K1
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,983 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// out[N, Ho, Wo, K] =
|
||||
// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K]
|
||||
template <
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
|
||||
struct
|
||||
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp =
|
||||
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
|
||||
|
||||
using ADataType = InDataType;
|
||||
using BDataType = WeiDataType;
|
||||
using CDataType = OutDataType;
|
||||
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
// TODO make it support any # of spatial dimensions
|
||||
static constexpr index_t NDimSpatial = 2;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t GemmMRaw = N * Ho * Wo;
|
||||
const index_t GemmN = K;
|
||||
|
||||
const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
|
||||
const auto GemmMPad = GemmM - GemmMRaw;
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{ // 1x1, stride=1, pad=0
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmmraw_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
// C1: residual tensor: assume same layout as output tensor
|
||||
const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn,
|
||||
resi_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{ // 1x1, pad=0
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ho_wo_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
// C1: residual tensor: assume same layout as output tensor
|
||||
const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn,
|
||||
resi_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
|
||||
{ // C = odd value
|
||||
const index_t GemmKRaw = Y * X * C;
|
||||
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
|
||||
const index_t GemmKPad = GemmK - GemmKRaw;
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmkraw_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkraw_gemmmraw_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_right_pad_transform(GemmKRaw, GemmKPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
// C1: residual tensor: assume same layout as output tensor
|
||||
const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn,
|
||||
resi_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmMRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
// C1: residual tensor: assume same layout as output tensor
|
||||
const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn,
|
||||
resi_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
}
|
||||
|
||||
using GridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(GridDescs{}[I0])>;
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(GridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I2])>;
|
||||
using C0GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I3])>;
|
||||
using C1GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I4])>;
|
||||
|
||||
using Block2CTileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, CGridDesc_M_N>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3<
|
||||
BlockSize,
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
C0GridDesc_M_N,
|
||||
C1GridDesc_M_N,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
|
||||
2, // ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
|
||||
2, // BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
const OutDataType* p_bias_grid,
|
||||
const OutDataType* p_resi_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
: p_a_grid_{p_in_grid},
|
||||
p_b_grid_{p_wei_grid},
|
||||
p_c_grid_{p_out_grid},
|
||||
p_c0_grid_{p_bias_grid},
|
||||
p_c1_grid_{p_resi_grid},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c0_grid_desc_m_n_{},
|
||||
c1_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
block_2_ctile_map_{},
|
||||
in_element_op_{in_element_op},
|
||||
wei_element_op_{wei_element_op},
|
||||
out_element_op_{out_element_op},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
input_spatial_lengths_{input_spatial_lengths},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
output_spatial_lengths_{output_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
c0_grid_desc_m_n_ = descs[I3];
|
||||
c1_grid_desc_m_n_ = descs[I4];
|
||||
|
||||
block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_};
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c_grid_desc_m_n_);
|
||||
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c0_grid_desc_m_n_);
|
||||
|
||||
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c1_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
const CDataType* p_c0_grid_;
|
||||
const CDataType* p_c1_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
C0GridDesc_M_N c0_grid_desc_m_n_;
|
||||
C1GridDesc_M_N c1_grid_desc_m_n_;
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::
|
||||
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::
|
||||
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
OutElementwiseOperation out_element_op_;
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> input_spatial_lengths_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> output_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> conv_filter_dilations_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
<< "K " << arg.Conv_K_ << ", "
|
||||
<< "C " << arg.Conv_C_ << ", " << std::endl;
|
||||
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
|
||||
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
|
||||
<< arg.input_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
|
||||
<< arg.output_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
|
||||
<< arg.conv_filter_strides_[1] << ", " << std::endl;
|
||||
std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
|
||||
<< arg.conv_filter_dilations_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
|
||||
<< arg.input_left_pads_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
|
||||
<< arg.input_right_pads_[1] << ", " << std::endl;
|
||||
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r3 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r3<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
Block2CTileMap,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_c0_grid_,
|
||||
arg.p_c1_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r3<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
Block2CTileMap,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_c0_grid_,
|
||||
arg.p_c1_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// vector load A/B matrix from global memory
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
|
||||
arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
const OutDataType* p_bias_grid,
|
||||
const OutDataType* p_resi_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
p_bias_grid,
|
||||
p_resi_grid,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
const void* p_wei_grid,
|
||||
void* p_out_grid,
|
||||
const void* p_bias_grid,
|
||||
const void* p_resi_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<const WeiDataType*>(p_wei_grid),
|
||||
static_cast<OutDataType*>(p_out_grid),
|
||||
static_cast<const OutDataType*>(p_bias_grid),
|
||||
static_cast<const OutDataType*>(p_resi_grid),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< ABlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,940 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// out[N, Ho, Wo, K] =
|
||||
// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K])
|
||||
template <
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
InMemoryDataOperationEnum OutGlobalMemoryDataOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
|
||||
struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
: public DeviceConvFwdBiasActivation<InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp =
|
||||
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
|
||||
|
||||
using ADataType = InDataType;
|
||||
using BDataType = WeiDataType;
|
||||
using CDataType = OutDataType;
|
||||
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
// TODO make it support any # of spatial dimensions
|
||||
static constexpr index_t NDimSpatial = 2;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t GemmMRaw = N * Ho * Wo;
|
||||
const index_t GemmN = K;
|
||||
|
||||
const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
|
||||
const auto GemmMPad = GemmM - GemmMRaw;
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{ // 1x1, stride=1, pad=0
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmmraw_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{ // 1x1, pad=0
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ho_wo_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
|
||||
{ // C = odd value
|
||||
const index_t GemmKRaw = Y * X * C;
|
||||
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
|
||||
const index_t GemmKPad = GemmK - GemmKRaw;
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmkraw_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkraw_gemmmraw_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_right_pad_transform(GemmKRaw, GemmKPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmMRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
}
|
||||
|
||||
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
using C0GridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I3])>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2<
|
||||
BlockSize,
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
OutGlobalMemoryDataOperation,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
C0GridDesc_M_N,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
|
||||
2, // ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
|
||||
2, // BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
const OutDataType* p_bias_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
: p_a_grid_{p_in_grid},
|
||||
p_b_grid_{p_wei_grid},
|
||||
p_c_grid_{p_out_grid},
|
||||
p_c0_grid_{p_bias_grid},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c0_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
in_element_op_{in_element_op},
|
||||
wei_element_op_{wei_element_op},
|
||||
out_element_op_{out_element_op},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
input_spatial_lengths_{input_spatial_lengths},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
output_spatial_lengths_{output_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
c0_grid_desc_m_n_ = descs[I3];
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c_grid_desc_m_n_);
|
||||
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c0_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
const CDataType* p_c0_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
C0GridDesc_M_N c0_grid_desc_m_n_;
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::
|
||||
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
OutElementwiseOperation out_element_op_;
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> input_spatial_lengths_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> output_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> conv_filter_dilations_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
<< "K " << arg.Conv_K_ << ", "
|
||||
<< "C " << arg.Conv_C_ << ", " << std::endl;
|
||||
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
|
||||
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
|
||||
<< arg.input_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
|
||||
<< arg.output_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
|
||||
<< arg.conv_filter_strides_[1] << ", " << std::endl;
|
||||
std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
|
||||
<< arg.conv_filter_dilations_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
|
||||
<< arg.input_left_pads_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
|
||||
<< arg.input_right_pads_[1] << ", " << std::endl;
|
||||
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r2<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_c0_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r2<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_c0_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// vector load A/B matrix from global memory
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
|
||||
arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
const OutDataType* p_bias_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
p_bias_grid,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
const void* p_wei_grid,
|
||||
void* p_out_grid,
|
||||
const void* p_bias_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<const WeiDataType*>(p_wei_grid),
|
||||
static_cast<OutDataType*>(p_out_grid),
|
||||
static_cast<const OutDataType*>(p_bias_grid),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< ABlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,906 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
|
||||
template <
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXdl,
|
||||
ck::index_t NPerXdl,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
|
||||
struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
: public DeviceConvFwd<2,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
|
||||
|
||||
using ADataType = InDataType;
|
||||
using BDataType = WeiDataType;
|
||||
using CDataType = OutDataType;
|
||||
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
static constexpr index_t NDimSpatial = 2;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t GemmMRaw = N * Ho * Wo;
|
||||
const index_t GemmN = K;
|
||||
|
||||
const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
|
||||
const auto GemmMPad = GemmM - GemmMRaw;
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{ // 1x1, stride=1, pad=0
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmmraw_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{ // 1x1, pad=0
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ho_wo_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
|
||||
{ // C = odd value
|
||||
const index_t GemmKRaw = Y * X * C;
|
||||
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
|
||||
const index_t GemmKPad = GemmK - GemmKRaw;
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmkraw_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkraw_gemmmraw_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_right_pad_transform(GemmKRaw, GemmKPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmMRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
using Block2CTileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, CGridDesc_M_N>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1<
|
||||
BlockSize,
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType, // TODO: Add ShuffleType for DeviceConv2d
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock * K1,
|
||||
K1, // AK1
|
||||
K1, // BK1
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
|
||||
2, // ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
|
||||
2, // BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
: p_a_grid_{p_in_grid},
|
||||
p_b_grid_{p_wei_grid},
|
||||
p_c_grid_{p_out_grid},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
block_2_ctile_map_{},
|
||||
in_element_op_{in_element_op},
|
||||
wei_element_op_{wei_element_op},
|
||||
out_element_op_{out_element_op},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
input_spatial_lengths_{input_spatial_lengths},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
output_spatial_lengths_{output_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_};
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
OutElementwiseOperation out_element_op_;
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> input_spatial_lengths_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> output_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> conv_filter_dilations_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
<< "K " << arg.Conv_K_ << ", "
|
||||
<< "C " << arg.Conv_C_ << ", " << std::endl;
|
||||
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
|
||||
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
|
||||
<< arg.input_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
|
||||
<< arg.output_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
|
||||
<< arg.conv_filter_strides_[1] << ", " << std::endl;
|
||||
std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
|
||||
<< arg.conv_filter_dilations_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
|
||||
<< arg.input_left_pads_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
|
||||
<< arg.input_right_pads_[1] << ", " << std::endl;
|
||||
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
std::cout
|
||||
<< "arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_"
|
||||
"nwavenperxdl_{ "
|
||||
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
|
||||
.GetLength(I0)
|
||||
<< ", "
|
||||
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
|
||||
.GetLength(I1)
|
||||
<< ", "
|
||||
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
|
||||
.GetLength(I2)
|
||||
<< ", "
|
||||
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
|
||||
.GetLength(I3)
|
||||
<< ", "
|
||||
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
|
||||
.GetLength(I4)
|
||||
<< ", "
|
||||
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
|
||||
.GetLength(I5)
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
Block2CTileMap,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
Block2CTileMap,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// vector load A/B matrix from global memory
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
|
||||
arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
const void* p_wei_grid,
|
||||
void* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<const WeiDataType*>(p_wei_grid),
|
||||
static_cast<OutDataType*>(p_out_grid),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
|
||||
<< K1 << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< ABlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user