mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Grouped Convolution Backward Weight Explicit GEMM (#2282)
* Grouped conv bwd weight explicit gemm * 3d * cmake fixes * fix test * fix
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -242,6 +242,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch() = default;
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
std::array<ck::index_t, NumDTensor> BatchStrideDs,
|
||||
@@ -282,7 +283,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
const std::array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
std::array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
|
||||
@@ -291,6 +292,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
index_t Batch;
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
|
||||
|
||||
Argument() = default;
|
||||
Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
@@ -413,19 +415,39 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
const auto clear_workspace = [&]() {
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.Batch * arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
ave_time = launch_and_time_kernel_with_preprocess(stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
constexpr index_t minimum_occupancy = []() {
|
||||
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
|
||||
@@ -0,0 +1,284 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
typename DeviceGemmV3Op>
|
||||
struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
: public DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
static_assert(is_same_v<InElementwiseOperation, element_wise::PassThrough>);
|
||||
static_assert(is_same_v<WeiElementwiseOperation, element_wise::PassThrough>);
|
||||
static_assert(is_same_v<OutElementwiseOperation, element_wise::PassThrough>);
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
using GemmArgument = typename DeviceGemmV3Op::Argument;
|
||||
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>&, // input
|
||||
const std::array<index_t, NDimSpatial + 3>&,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>&,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>&,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>&,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
ck::index_t split_k)
|
||||
: filter_spatial_lengths_{},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
constexpr index_t spatial_offset = 3;
|
||||
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
|
||||
end(a_g_n_k_wos_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
const index_t M = e_g_k_c_xs_lengths[I1];
|
||||
const index_t N = e_g_k_c_xs_lengths[I2];
|
||||
const index_t K = a_g_n_k_wos_lengths[I1] * DoHoWo;
|
||||
const index_t BatchSize = a_g_n_k_wos_lengths[I0];
|
||||
|
||||
explicit_gemm_args = GemmArgument{p_out_grid,
|
||||
p_in_grid,
|
||||
{},
|
||||
p_wei_grid,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
BatchSize * M,
|
||||
BatchSize * N,
|
||||
{},
|
||||
N,
|
||||
M,
|
||||
N,
|
||||
{},
|
||||
M * N,
|
||||
BatchSize,
|
||||
out_element_op,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
split_k};
|
||||
|
||||
std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset,
|
||||
end(e_g_k_c_xs_lengths),
|
||||
begin(filter_spatial_lengths_));
|
||||
}
|
||||
|
||||
GemmArgument explicit_gemm_args;
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
return explicit_gemm_op.Run(arg.explicit_gemm_args, stream_config);
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
|
||||
typename DeviceGemmV3Op::Invoker explicit_gemm_op;
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
if constexpr(!is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check if it's 1x1, stride=1 pad = 0 conv
|
||||
for(int i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
|
||||
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Gridwise GEMM size
|
||||
return DeviceGemmV3Op::IsSupportedArgument(arg.explicit_gemm_args);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
const ck::index_t split_k)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
b_g_n_c_wis_lengths, // input
|
||||
b_g_n_c_wis_strides,
|
||||
e_g_k_c_xs_lengths, // weight
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_lengths, // output
|
||||
a_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
split_k};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
const ck::index_t split_k) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<WeiDataType*>(p_wei_grid),
|
||||
static_cast<const OutDataType*>(p_out_grid),
|
||||
b_g_n_c_wis_lengths, // input
|
||||
b_g_n_c_wis_strides,
|
||||
e_g_k_c_xs_lengths, // weight
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_lengths, // output
|
||||
a_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
split_k);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvBwdWeight_Explicit_Xdl"
|
||||
<< "<" << DeviceGemmV3Op{}.GetTypeString() << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -391,53 +391,53 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
using CElementwiseGridDesc_M_N =
|
||||
remove_cvref_t<decltype(GetElementwiseCGridDesc<NDimSpatial>())>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_xdl_cshuffle_v3<tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
K1,
|
||||
K1,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3<
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
K1,
|
||||
K1,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
|
||||
|
||||
|
||||
@@ -328,53 +328,53 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_xdl_cshuffle_v3<tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
K1,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3<
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
K1,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
// Argument
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
|
||||
Reference in New Issue
Block a user