mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '87dd073887933fc2c75c234871e3885cee970a98' into develop
This commit is contained in:
@@ -11,8 +11,11 @@ add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bw
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_v3_wmma_fp16 grouped_conv_bwd_weight_v3_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_v3_wmma_fp16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_v3_wmma_bf16 grouped_conv_bwd_weight_v3_wmma_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_v3_wmma_bf16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16)
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = BF16;
|
||||
// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
|
||||
using WeiDataType = F32;
|
||||
using OutDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvBwdWeightInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::NHWGC,
|
||||
ck::tensor_layout::convolution::NDHWGC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::NHWGK,
|
||||
ck::tensor_layout::convolution::NDHWGK>>,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
AccDataType, // AccDataType
|
||||
InElementOp, // InElementwiseOperation
|
||||
WeiElementOp, // WeiElementwiseOperation
|
||||
OutElementOp, // OutElementwiseOperation
|
||||
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // K1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
1, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
2, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
2, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 32, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
#include "run_grouped_conv_bwd_weight_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExecutionConfig config;
|
||||
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
|
||||
|
||||
if(!parse_cmd_args(argc, argv, config, conv_param))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
switch(conv_param.num_dim_spatial_)
|
||||
{
|
||||
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
|
||||
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
|
||||
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
|
||||
default: break;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = F16;
|
||||
using WeiDataType = F16;
|
||||
@@ -16,11 +16,20 @@ using OutElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvBwdWeightInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3<
|
||||
NDimSpatial,
|
||||
ck::tensor_layout::convolution::GNDHWC,
|
||||
ck::tensor_layout::convolution::GKZYXC,
|
||||
ck::tensor_layout::convolution::GNDHWK,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::NHWGC,
|
||||
ck::tensor_layout::convolution::NDHWGC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::NHWGK,
|
||||
ck::tensor_layout::convolution::NDHWGK>>,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
@@ -32,30 +41,30 @@ using DeviceConvBwdWeightInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
32, // KPerBlock
|
||||
8, // K1
|
||||
16, // MPerWMMA
|
||||
16, // NPerWMMA
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1>, // ABlockTransferSrcAccessOrder
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
1, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferDstScalarPerVector_K1
|
||||
false, // ABlockLdsAddExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
true, // BBlockLdsExtraN
|
||||
4,
|
||||
2,
|
||||
S<1, 32, 1, 8>,
|
||||
1>;
|
||||
2, // BBlockTransferDstScalarPerVector_K1
|
||||
false, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 32, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
@@ -80,6 +89,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
switch(conv_param.num_dim_spatial_)
|
||||
{
|
||||
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
|
||||
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
|
||||
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
|
||||
default: break;
|
||||
}
|
||||
@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
|
||||
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
const ck::utils::conv::ConvParam& conv_param)
|
||||
{
|
||||
// Dl and WMMA ops don't support split_k > 1
|
||||
// Dl ops don't support split_k > 1
|
||||
constexpr ck::index_t split_k = 1;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
@@ -131,7 +131,21 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
|
||||
wei_device_buf.FromDevice(wei_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData);
|
||||
float max_accumulated_value =
|
||||
*std::max_element(wei_host_result.mData.begin(), wei_host_result.mData.end());
|
||||
|
||||
const ck::index_t num_accums = out.GetElementSize() / conv_param.K_;
|
||||
const ck::index_t num_accums_split_k = split_k;
|
||||
double rtol = ck::utils::get_relative_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
num_accums / num_accums_split_k);
|
||||
double atol = ck::utils::get_absolute_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
max_accumulated_value / num_accums_split_k, num_accums / num_accums_split_k);
|
||||
|
||||
return ck::utils::check_err(wei_device_result.mData,
|
||||
wei_host_result.mData,
|
||||
"Error: Incorrect results!",
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
else if(config.do_verification == 2)
|
||||
{
|
||||
|
||||
@@ -0,0 +1,764 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename ComputePtrOffsetOfStridedBatch,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_batched_gemm_multi_d_wmma_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
|
||||
// DeviceBatchedGemm_Wmma_CShuffleV3::Argument
|
||||
// argument through implicit conversion to base class!
|
||||
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using EDataType = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<EDataType, ck::half_t> ||
|
||||
std::is_same_v<EDataType, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
// The normal approach to batching would be to increase the grid size by just stretching out
|
||||
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
|
||||
// functions not directly using the Z dimension for other calculations. As it turns out, k
|
||||
// batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
|
||||
// we will use the grid Y dimension for batching. This may be a bit fragile.
|
||||
const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
const long_index_t c_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
|
||||
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
static_for<0, GridwiseGemm::NumATensor, 1>{}(
|
||||
[&](auto i) { splitk_batch_offset.a_k_split_offset[i] += a_batch_offset; });
|
||||
|
||||
static_for<0, GridwiseGemm::NumBTensor, 1>{}(
|
||||
[&](auto i) { splitk_batch_offset.b_k_split_offset[i] += b_batch_offset; });
|
||||
|
||||
splitk_batch_offset.c_reduce_offset += c_batch_offset;
|
||||
|
||||
// populate pointer, desc for Ds
|
||||
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
|
||||
// D pointer
|
||||
karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
|
||||
});
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args);
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = ADataType,
|
||||
typename ComputeTypeB = BDataType>
|
||||
struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3
|
||||
: public DeviceBatchedGemmV2MultiD<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors;
|
||||
using CDataType_ = EDataType;
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false,
|
||||
false>;
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch() = default;
|
||||
ComputePtrOffsetOfStridedBatch(
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
std::array<ck::index_t, GridwiseGemm::NumDTensor> BatchStrideDs,
|
||||
index_t BatchStrideC)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideDs_(BatchStrideDs),
|
||||
BatchStrideC_(BatchStrideC)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return static_cast<long_index_t>(BatchStrideA_) * g_idx;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return static_cast<long_index_t>(BatchStrideB_) * g_idx;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
|
||||
{
|
||||
std::array<long_index_t, GridwiseGemm::NumDTensor> ds_offset_;
|
||||
|
||||
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
|
||||
ds_offset_[i] = static_cast<long_index_t>(BatchStrideDs_[i]) * g_idx;
|
||||
});
|
||||
|
||||
return ds_offset_;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return static_cast<long_index_t>(BatchStrideC_) * g_idx;
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
std::array<ck::index_t, GridwiseGemm::NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
|
||||
struct Argument : public GridwiseGemm::Argument
|
||||
{
|
||||
index_t Batch;
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
|
||||
|
||||
Argument() = default;
|
||||
Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
std::array<const void*, GridwiseGemm::NumDTensor> p_ds_grid_,
|
||||
EDataType* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
std::array<index_t, GridwiseGemm::NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t BatchStrideA_,
|
||||
index_t BatchStrideB_,
|
||||
const std::array<ck::index_t, GridwiseGemm::NumDTensor>& BatchStrideDs_,
|
||||
index_t BatchStrideE_,
|
||||
index_t Batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_,
|
||||
index_t KBatch_)
|
||||
: GridwiseGemm::Argument{std::array<const void*, 1>{p_a_grid_},
|
||||
std::array<const void*, 1>{p_b_grid_},
|
||||
p_ds_grid_,
|
||||
p_e_grid_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
std::array<index_t, 1>{StrideA_},
|
||||
std::array<index_t, 1>{StrideB_},
|
||||
StrideDs_,
|
||||
StrideE_,
|
||||
KBatch_,
|
||||
a_element_op_,
|
||||
b_element_op_,
|
||||
cde_element_op_,
|
||||
false},
|
||||
Batch{Batch_},
|
||||
compute_ptr_offset_of_batch{
|
||||
BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_}
|
||||
{
|
||||
}
|
||||
template <typename EType>
|
||||
void SetEPointer(void* ptr)
|
||||
{
|
||||
this->p_e_grid = static_cast<EType*>(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
struct ActiveWorkgroupsPerCU
|
||||
{
|
||||
ActiveWorkgroupsPerCU()
|
||||
{
|
||||
constexpr int dynamic_smem_size = 0;
|
||||
int max_occupancy = 0;
|
||||
|
||||
constexpr index_t minimum_occupancy = []() {
|
||||
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_batched_gemm_multi_d_wmma_cshuffle_v3<GridwiseGemm,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
|
||||
max_occupancy_ = std::max(1, max_occupancy);
|
||||
}
|
||||
int max_occupancy_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
gdy *= arg.Batch;
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
|
||||
|
||||
// Packed sizes are 1 for all implemented data types but we include it anyway
|
||||
// for future compatibility.
|
||||
std::array<std::size_t, 1> size_as_buffers;
|
||||
size_as_buffers[0] = arg_.Batch *
|
||||
a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
|
||||
sizeof(ADataType) / GridwiseGemm::APackedSize;
|
||||
|
||||
std::array<std::size_t, 1> size_bs_buffers;
|
||||
size_bs_buffers[0] = arg_.Batch *
|
||||
b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
|
||||
sizeof(BDataType) / GridwiseGemm::BPackedSize;
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
std::array<std::size_t, GridwiseGemm::NumDTensor> size_ds_buffers;
|
||||
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
size_ds_buffers[i] =
|
||||
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
ck::utility::RotatingMemWrapperMultiABD<Argument,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
DsDataType>
|
||||
rotating_mem(arg_,
|
||||
stream_config.rotating_count,
|
||||
size_as_buffers,
|
||||
size_bs_buffers,
|
||||
size_ds_buffers);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemsetAsync(arg_.p_e_grid,
|
||||
0,
|
||||
arg.Batch * arg_.M * arg_.N * sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_,
|
||||
arg_.compute_ptr_offset_of_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto clear_workspace = [&]() {
|
||||
if(arg.KBatch > 1)
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemsetAsync(arg.p_e_grid,
|
||||
0,
|
||||
arg.Batch * arg.M * arg.N * sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time =
|
||||
launch_and_time_kernel_with_preprocess(stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg,
|
||||
arg.compute_ptr_offset_of_batch);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy = []() {
|
||||
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported: Architecture must be gfx11/gfx12." << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<EDataType, ck::half_t> ||
|
||||
std::is_same_v<EDataType, ck::bhalf_t>)
|
||||
{
|
||||
if(arg.KBatch > 1 && ck::is_gfx11_supported())
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported splitK on gfx11." << std::endl;
|
||||
}
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
|
||||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported f8 / bf8 on gfx11." << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported K dimension without padding." << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, GridwiseGemm::NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<index_t, GridwiseGemm::NumDTensor> StrideDs,
|
||||
index_t StrideE,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
const std::array<ck::index_t, GridwiseGemm::NumDTensor>& BatchStrideDs,
|
||||
index_t BatchStrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
index_t KBatch = 1)
|
||||
{
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<EDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideDs,
|
||||
BatchStrideE,
|
||||
Batch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
KBatch};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const std::array<const void*, GridwiseGemm::NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<ck::index_t, GridwiseGemm::NumDTensor>& StrideDs,
|
||||
index_t StrideE,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
const std::array<ck::index_t, GridwiseGemm::NumDTensor>& BatchStrideDs,
|
||||
index_t BatchStrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
index_t KBatch = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<EDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideDs,
|
||||
BatchStrideE,
|
||||
Batch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
KBatch);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemmMultipleD_Wmma_CShuffleV3"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(ELayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerWmma<<"x"<<NPerWmma << ", "
|
||||
<< "WaveMap: "
|
||||
<< MRepeat<<"x" << NRepeat<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
static ck::index_t GetMaxOccupancy()
|
||||
{
|
||||
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
|
||||
return active_workgroups_per_cu.max_occupancy_;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -350,6 +350,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_}
|
||||
{
|
||||
}
|
||||
template <typename EType>
|
||||
void SetEPointer(void* ptr)
|
||||
{
|
||||
this->p_c_grid = static_cast<EType*>(ptr);
|
||||
}
|
||||
};
|
||||
using Argument = ArgumentBase<GridwiseGemm64>;
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -807,7 +808,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
using Block2CTileMap =
|
||||
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
struct Argument : public BaseArgument, public ArgumentSplitK
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
@@ -844,9 +845,10 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
a_g_n_c_wis_lengths, // input
|
||||
@@ -915,7 +917,6 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
|
||||
@@ -32,7 +32,7 @@ template <ck::index_t NDimSpatial,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
typename DeviceGemmV3Op>
|
||||
struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
struct DeviceGroupedConvBwdWeight_Explicit
|
||||
: public DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
@@ -56,7 +56,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
sizeof(WeiDataType) % 4 != 0 &&
|
||||
DeviceGemmV3Op::CDEShuffleBlockTransferScalarPerVectors_::At(I0) % 2 != 0;
|
||||
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl;
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight_Explicit;
|
||||
using TwoStageIntermediateType = typename DeviceGemmV3Op::CDataType_;
|
||||
|
||||
static constexpr index_t ElementwiseBlockSize = 256;
|
||||
@@ -95,7 +95,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
struct Argument : public BaseArgument, public ArgumentSplitK
|
||||
{
|
||||
using GemmArgument = typename DeviceGemmV3Op::Argument;
|
||||
|
||||
@@ -153,11 +153,11 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
|
||||
const index_t grid_size = gdx * gdy * gdz;
|
||||
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
split_k_ = split_k;
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -170,12 +170,12 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
|
||||
const index_t grid_size = gdx * gdy * gdz;
|
||||
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
split_k_ = split_k;
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,7 +213,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
out_element_op,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
split_k_};
|
||||
k_batch_};
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -236,7 +236,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
out_element_op,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
split_k_};
|
||||
k_batch_};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,7 +273,6 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
bool is_filter_data_packed;
|
||||
CElementwiseGridDesc elementwise_desc_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_;
|
||||
ck::index_t split_k_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -288,8 +287,8 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
{
|
||||
// Modify to use workspace as output
|
||||
GemmArgument explicit_gemm_args_with_workspace = arg.explicit_gemm_args;
|
||||
explicit_gemm_args_with_workspace.p_c_grid =
|
||||
static_cast<TwoStageIntermediateType*>(arg.p_workspace_);
|
||||
explicit_gemm_args_with_workspace.template SetEPointer<TwoStageIntermediateType>(
|
||||
arg.p_workspace_);
|
||||
float avg_time =
|
||||
explicit_gemm_op.Run(explicit_gemm_args_with_workspace, stream_config);
|
||||
const index_t grid_size =
|
||||
@@ -342,7 +341,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if constexpr(!IsTwoStageNeeded)
|
||||
{
|
||||
if(arg.split_k_ < 0)
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -353,6 +352,10 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
{
|
||||
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported layout." << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -360,11 +363,19 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
{
|
||||
if constexpr(!is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported layout." << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported layout." << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -374,6 +385,10 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
|
||||
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported stride / pad." << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -381,6 +396,10 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
{
|
||||
if(!arg.is_filter_data_packed)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Unsupported: Filter data must be packed." << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// Check this here, it allows to use other instances from factory even
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1745,6 +1745,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// TODO: this is needed because there is a bug
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check this here, it allows to use other instances from factory even
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -450,7 +451,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(
|
||||
CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */));
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
struct Argument : public BaseArgument, public ArgumentSplitK
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
@@ -490,8 +491,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
output_spatial_lengths_{},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
|
||||
@@ -504,6 +504,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
end(e_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
k_batch_ = split_k;
|
||||
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
@@ -576,7 +578,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -295,7 +295,7 @@ struct ABTransferThreadTiles
|
||||
BlockDescriptor& block_descriptor,
|
||||
ABElementwiseOperation& ab_element_op,
|
||||
const index_t block_mn_id,
|
||||
const index_t)
|
||||
const index_t k_id)
|
||||
{
|
||||
constexpr index_t NumABTensor = ABsDataType::Size();
|
||||
const index_t mn_block_data_idx_on_grid =
|
||||
@@ -304,7 +304,7 @@ struct ABTransferThreadTiles
|
||||
if constexpr(NumABTensor > 1)
|
||||
{
|
||||
const auto idx_as_block_begin = generate_tuple(
|
||||
[&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); },
|
||||
[&](auto) { return make_multi_index(k_id, mn_block_data_idx_on_grid, 0); },
|
||||
Number<NumABTensor>{});
|
||||
|
||||
return ThreadGroupTensorSliceTransfer_v7r2<
|
||||
@@ -357,7 +357,7 @@ struct ABTransferThreadTiles
|
||||
ABThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
GlobalBufferNum>(grid_descriptor[I0],
|
||||
make_multi_index(0, mn_block_data_idx_on_grid, 0),
|
||||
make_multi_index(k_id, mn_block_data_idx_on_grid, 0),
|
||||
ab_element_op,
|
||||
block_descriptor,
|
||||
make_multi_index(0, 0, 0),
|
||||
|
||||
@@ -333,6 +333,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem() = default;
|
||||
__host__ Problem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
@@ -409,6 +410,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
// Argument
|
||||
struct Argument : public tensor_operation::device::BaseArgument, public Problem
|
||||
{
|
||||
__host__ Argument() = default;
|
||||
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
|
||||
std::array<const void*, NumBTensor> p_bs_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
@@ -583,7 +585,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0)
|
||||
{
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
|
||||
@@ -651,7 +654,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
k_id);
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
@@ -700,7 +704,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
Argument& karg,
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
AsGridPointer p_as_grid_splitk;
|
||||
@@ -735,7 +740,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args,
|
||||
k_id);
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
|
||||
// Wrapper function to have __global__ function in common
|
||||
@@ -748,20 +754,146 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0)
|
||||
{
|
||||
Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
Block2CTileMap,
|
||||
EpilogueArgument>(
|
||||
p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args, k_id);
|
||||
EpilogueArgument>(p_shared,
|
||||
splitk_batch_offset,
|
||||
karg,
|
||||
DefaultBlock2CTileMap(karg),
|
||||
epilogue_args,
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
|
||||
__device__ static auto DefaultBlock2CTileMap(const Problem& problem)
|
||||
{
|
||||
return Block2CTileMap{problem.M, problem.N, 4};
|
||||
}
|
||||
|
||||
// Run method for convolution (grid descriptors are passed as arguments,
|
||||
// not generated internally)
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
index_t NumGroupsToMerge,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t num_k_per_block,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
AsGridPointer p_as_grid_;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
p_as_grid_(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset;
|
||||
});
|
||||
|
||||
BsGridPointer p_bs_grid_;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
|
||||
});
|
||||
|
||||
const auto ds_grid_desc_m_n =
|
||||
MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);
|
||||
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, karg.MBlock, karg.NBlock);
|
||||
|
||||
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
|
||||
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
|
||||
// Scale structs (Empty)
|
||||
using Scale = typename BlockwiseGemmPipe::Empty;
|
||||
auto b_scale_struct = Scale{};
|
||||
auto a_scale_struct = Scale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid_,
|
||||
p_bs_grid_,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
k_idx,
|
||||
k_idx,
|
||||
karg.KBatch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -723,7 +723,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0)
|
||||
{
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
|
||||
@@ -793,7 +794,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
k_id);
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
|
||||
// NOTE: Wrapper function to have __global__ function in common
|
||||
@@ -806,7 +808,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
AsGridPointer p_as_grid_splitk;
|
||||
@@ -857,7 +860,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args,
|
||||
k_id);
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -101,7 +101,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
|
||||
p_shared,
|
||||
splitk_batch_offset,
|
||||
karg,
|
||||
epilogue_args,
|
||||
0, /* A_k_id == 0 (we shift the pointer for splitk) */
|
||||
k_id);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
@@ -344,11 +349,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
|
||||
// Calculate grid size taking into account splitk (KBatch)
|
||||
// 2D grid (x,z)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
// Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch)
|
||||
// 3D grid (x,y,z)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateMPadded(index_t M)
|
||||
{
|
||||
return math::integer_least_multiple(M, MPerBlock);
|
||||
@@ -706,8 +720,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
ReduceTrait>;
|
||||
|
||||
template <typename DEGridDesc>
|
||||
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc& de_grid_desc_m_n,
|
||||
index_t MBlock,
|
||||
index_t NBlock)
|
||||
{
|
||||
const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
de_grid_desc_m_n,
|
||||
@@ -1004,6 +1020,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
}
|
||||
|
||||
// Note: arguments k_batch and k_id should be set if splitk is used
|
||||
// with implicit gemm (no pointer shift but shift using tensor descriptors)
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -1034,7 +1052,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
AScaleStruct& a_scale_struct,
|
||||
BScaleStruct& b_scale_struct,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0,
|
||||
const index_t k_batch = 1)
|
||||
{
|
||||
const auto as_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -1066,7 +1086,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
AsDataType,
|
||||
AElementwiseOperation,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id, k_id);
|
||||
as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id, A_k_id);
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
@@ -1075,7 +1095,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
BsDataType,
|
||||
BElementwiseOperation,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id, k_id);
|
||||
bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id, B_k_id);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
@@ -1100,7 +1120,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / KPerBlock);
|
||||
ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / (KPerBlock * k_batch));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),
|
||||
|
||||
@@ -71,6 +71,29 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
|
||||
return vy.template AsType<float2_t>()[I0];
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float4_t atomic_add<float4_t>(float4_t* p_dst, const float4_t& x)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const vector_type<float, 4> vx{x};
|
||||
vector_type<float, 4> vy{0};
|
||||
|
||||
vy.template AsType<float>()(I0) =
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), vx.template AsType<float>()[I0]);
|
||||
vy.template AsType<float>()(I1) =
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, vx.template AsType<float>()[I1]);
|
||||
vy.template AsType<float>()(I2) =
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 2, vx.template AsType<float>()[I2]);
|
||||
vy.template AsType<float>()(I3) =
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 3, vx.template AsType<float>()[I3]);
|
||||
|
||||
return vy.template AsType<float4_t>()[I0];
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
|
||||
{
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/utility/functional2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -32,17 +32,17 @@ void add_explicit_gemm_device_operation_instances(
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemmV3Ops>, 1>{}([&](auto i) {
|
||||
using DeviceGemmOp = std::tuple_element_t<i, DeviceGemmV3Ops>;
|
||||
|
||||
using NewOpInstance = DeviceGroupedConvBwdWeight_Explicit_Xdl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
DeviceGemmOp>;
|
||||
using NewOpInstance = DeviceGroupedConvBwdWeight_Explicit<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
DeviceGemmOp>;
|
||||
|
||||
static_assert(std::is_base_of_v<BaseOp, NewOpInstance>,
|
||||
"wrong! NewOpInstance should be derived from BaseOp");
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMPadding = GemmSpecialization::MPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMKPadding = GemmSpecialization::MKPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <typename InOutDataType>
|
||||
using device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
|
||||
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
|
||||
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 48, 96, 64, 8, 8, 16, 16, 3, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 128, 8, 8, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 192, 48, 96, 192, 8, 8, 16, 16, 3, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 12>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 48, 64, 64, 8, 8, 16, 16, 3, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 96, 128, 64, 8, 8, 16, 16, 6, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 192, 32, 96, 192, 8, 8, 16, 16, 2, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 12>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 192, 32, 96, 192, 8, 8, 16, 16, 2, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 12>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
|
||||
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
|
||||
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 48, 64, 128, 8, 8, 16, 16, 3, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 64, 64, 8, 8, 16, 16, 3, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 96, 64, 32, 8, 8, 16, 16, 6, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 32, 128, 8, 8, 16, 16, 3, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
// DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 96, 64, 96, 48, 8, 8, 16, 16, 4, 2, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Incorrect results for f16
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
|
||||
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
|
||||
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 48, 64, 128, 8, 8, 16, 16, 3, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 64, 64, 8, 8, 16, 16, 3, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 96, 64, 32, 8, 8, 16, 16, 6, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 32, 128, 8, 8, 16, 16, 3, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 96, 64, 96, 48, 8, 8, 16, 16, 4, 2, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Incorrect results for f16
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename InOutDataType, BlockGemmPipelineScheduler BlkGemmPipeSched>
|
||||
using device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
|
||||
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
|
||||
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// Latency friendly
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 32, 64, 8, 8, 16, 16, 1, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 32, 64, 8, 8, 16, 16, 1, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 32, 128, 8, 8, 16, 16, 1, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 48, 128, 8, 8, 16, 16, 1, 3, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 64, 8, 8, 16, 16, 1, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 128, 8, 8, 16, 16, 1, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 1, 6, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 96, 128, 8, 8, 16, 16, 1, 6, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 192, 32, 8, 8, 16, 16, 1, 12, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 256, 96, 64, 8, 8, 16, 16, 2, 6, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
|
||||
// Memory friendly
|
||||
// TODO: add once v2 is implemented
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,91 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups|
|
||||
//################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
|
||||
//################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| |
|
||||
//################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
|
||||
DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1>
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1>, // Incorrect results for at least GemmDefault
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1> // Incorrect results for at least GemmDefault
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups|
|
||||
//################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
|
||||
//################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| |
|
||||
//################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
|
||||
DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1>
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF8
|
||||
using BF8 = ck::bf8_t;
|
||||
#endif
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion PipelineVersion = BlockGemmPipelineVersion::v1>
|
||||
using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>
|
||||
// DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion>, // Incorrect results for at least GemmDefault
|
||||
// DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion> // Incorrect results for at least GemmDefault
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion PipelineVersion = BlockGemmPipelineVersion::v1>
|
||||
using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion>
|
||||
//clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,97 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_bilinear_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline |
|
||||
//#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version |
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
// for fp16 conv.K and conv.C must be divisible by 2
|
||||
// since half_t atomic_add require scalar_per_x_vector % 2 == 0
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Presumably doesn't produce correct results for f16
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Presumably doesn't produce correct results for f16
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_bilinear_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline |
|
||||
//#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version |
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
// other instances
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Verification failure
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Verification failure
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,117 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <index_t NDSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
//#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
// blocksize=256
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 8, 8, 16, 16, 8, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 256, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 16>, 4>,
|
||||
// blocksize=128
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 256, 8, 8, 16, 16, 1, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 32, 8, 8, 16, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// blocksize=64
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 8, 8, 16, 16, 4, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 8, 8, 16, 16, 2, 4, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 8, 8, 16, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 8, 8, 16, 16, 1, 8, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
// blocksize=32
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 32, 8, 8, 16, 16, 1, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 64, 8, 8, 16, 16, 1, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 64, 8, 8, 16, 16, 2, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 32, 8, 8, 16, 16, 2, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 32, 8, 8, 16, 16, 4, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 16, 8, 8, 16, 16, 4, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 16, 8, 8, 16, 16, 2, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_i8_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
//#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
// blocksize=256
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 256, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
// blocksize=128
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 256, 8, 8, 16, 16, 4, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 256, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 256, 8, 8, 16, 16, 1, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 32, 8, 8, 16, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 64, 8, 8, 16, 16, 8, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 128, 8, 8, 16, 16, 8, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// blocksize=64
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 8, 8, 16, 16, 1, 8, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 128, 8, 8, 16, 16, 2, 8, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 64, 8, 8, 16, 16, 8, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 8, 8, 16, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
// blocksize=32
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 64, 8, 8, 16, 16, 1, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 64, 8, 8, 16, 16, 4, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 32, 8, 8, 16, 16, 2, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 16, 8, 8, 16, 16, 4, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,96 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_scale_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline |
|
||||
//#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version |
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
// for fp16 conv.K and conv.C must be divisible by 2
|
||||
// since half_t atomic_add require scalar_per_x_vector % 2 == 0
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Presumably doesn't produce correct results for fp16
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Presumably doesn't produce correct results for fp16
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_scale_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline |
|
||||
//#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version |
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
// other instances
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Verification failure
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Verification failure
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -21,6 +21,7 @@
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
#include "grouped_convolution_backward_weight_wmma.inc"
|
||||
#include "grouped_convolution_backward_weight_explicit_wmma.inc"
|
||||
#endif
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -414,21 +415,24 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -471,23 +475,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
@@ -678,21 +682,24 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -735,23 +742,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
@@ -850,35 +857,53 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, GNDHWK>)
|
||||
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeA, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeB, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
@@ -889,26 +914,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeA, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeB, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -17,6 +17,39 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<GKZYXC>,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16>,
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<GKZYXC>,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
Tuple<F32>,
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
@@ -148,6 +181,35 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeA, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeB, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// 2D
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// 3D
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,7 +10,7 @@ namespace instance {
|
||||
// 2D
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -22,7 +22,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -34,7 +34,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_ins
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -46,7 +46,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -58,7 +58,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_i
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -70,7 +70,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -82,7 +82,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_i
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -94,7 +94,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -106,7 +106,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -121,7 +121,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -133,7 +133,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -145,7 +145,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -157,7 +157,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -169,7 +169,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -181,7 +181,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -193,7 +193,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -205,7 +205,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -217,7 +217,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -232,7 +232,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
// 3D
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -244,7 +244,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -256,7 +256,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_ins
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -268,7 +268,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -280,7 +280,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_i
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -292,7 +292,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -304,7 +304,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_i
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -316,7 +316,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -328,7 +328,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -343,7 +343,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -355,7 +355,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -367,7 +367,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -379,7 +379,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -391,7 +391,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -403,7 +403,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -415,7 +415,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -427,7 +427,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -439,7 +439,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
|
||||
@@ -17,6 +17,40 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<>,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<>,
|
||||
PassThrough,
|
||||
Scale,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
PassThrough,
|
||||
Scale,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
@@ -147,6 +181,34 @@ struct DeviceOperationInstanceFactory<
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef CK_USE_WMMA
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeA, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeB, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
|
||||
@@ -8,32 +8,61 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv2d backward weight
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// conv3d backward weight
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
@@ -46,7 +75,7 @@ void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -58,51 +87,28 @@ void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_AND_DL_KERNELS
|
||||
# XDL_DL_WMMA_KERNELS
|
||||
set(GROUPED_CONV1D_BWD_WEIGHT
|
||||
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_AND_DL_KERNELS
|
||||
# XDL_DL_WMMA_KERNELS
|
||||
set(GROUPED_CONV2D_BWD_WEIGHT
|
||||
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
|
||||
@@ -72,4 +72,11 @@ if(DL_KERNELS)
|
||||
dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp)
|
||||
endif()
|
||||
|
||||
list(APPEND GROUPED_CONV2D_BWD_WEIGHT
|
||||
wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
|
||||
wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv2d_bwd_weight_instance ${GROUPED_CONV2D_BWD_WEIGHT})
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -2,17 +2,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
@@ -20,13 +22,17 @@ void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_f16_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# XDL_DL_WMMA_KERNELS
|
||||
# XDL_DL_WMMA_KERNELS
|
||||
set(GROUPED_CONV3D_BWD_WEIGHT
|
||||
xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
@@ -69,14 +69,11 @@ if(DL_KERNELS)
|
||||
endif()
|
||||
|
||||
list(APPEND GROUPED_CONV3D_BWD_WEIGHT
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp)
|
||||
wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp
|
||||
wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp
|
||||
)
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
list(APPEND GROUPED_CONV3D_BWD_WEIGHT
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_f16_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,35 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_i8_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,35 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_i8_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -2,31 +2,37 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_i8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -2,13 +2,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -20,13 +22,17 @@ void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -2,31 +2,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_i8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -2,12 +2,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
@@ -20,13 +22,14 @@ void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_CONV3D_BWD_WEIGHT_BILINEAR
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
@@ -13,4 +13,9 @@ if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
|
||||
endif()
|
||||
|
||||
list(APPEND GROUPED_CONV3D_BWD_WEIGHT_BILINEAR
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_bwd_weight_bilinear_instance ${GROUPED_CONV3D_BWD_WEIGHT_BILINEAR})
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<GKZYXC>,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
Tuple<F32>,
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_bilinear_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_bilinear_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,50 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<GKZYXC>,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16>,
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_bilinear_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_bilinear_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_CONV3D_BWD_WEIGHT_SCALE
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
@@ -13,4 +13,9 @@ if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
|
||||
endif()
|
||||
|
||||
list(APPEND GROUPED_CONV3D_BWD_WEIGHT_SCALE
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_bwd_weight_scale_instance ${GROUPED_CONV3D_BWD_WEIGHT_SCALE})
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
PassThrough,
|
||||
Scale,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_scale_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_scale_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<>,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<>,
|
||||
PassThrough,
|
||||
Scale,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_scale_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_scale_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,29 +1,37 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_CONVND_EXP_BWD_WEIGHT
|
||||
# Explicit instances are common for 2d and 3d
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp
|
||||
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instance.cpp
|
||||
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp
|
||||
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instance.cpp
|
||||
|
||||
explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp
|
||||
explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp
|
||||
explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp
|
||||
|
||||
explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp
|
||||
explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp
|
||||
explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp
|
||||
)
|
||||
add_instance_library(device_grouped_convnd_bwd_weight_instance ${GROUPED_CONVND_EXP_BWD_WEIGHT})
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances<BF16>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances<BF16>>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_bf16_instances>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_bf16_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances<BF16, Intrawave>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances<BF16, Intrawave>>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances<F16>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances<F16>>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_f16_instances>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_f16_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances<F16, Intrawave>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances<F16, Intrawave>>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instan
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmDefault>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_ins
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_inst
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmDefault>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_i
|
||||
instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_inst
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmDefault>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_i
|
||||
instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -37,7 +37,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -37,7 +37,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmDefault>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instan
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instanc
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmDefault>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_inst
|
||||
instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -35,7 +35,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instanc
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmDefault>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_inst
|
||||
instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -37,7 +37,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -37,7 +37,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -36,7 +36,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -214,12 +214,30 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
|
||||
auto split_k_value = split_k_list[split_k_id];
|
||||
auto split_k_param_str = std::to_string(split_k_value);
|
||||
auto* split_k_arg =
|
||||
dynamic_cast<ck::tensor_operation::device::ArgumentSplitK*>(argument_ptr.get());
|
||||
if(split_k_arg && split_k_value < 0)
|
||||
|
||||
// If split_k was determined by the device implementation, get the resulting value.
|
||||
if(split_k_value < 0)
|
||||
{
|
||||
split_k_value = split_k_arg->k_batch_;
|
||||
split_k_param_str = std::to_string(split_k_value) + " (best occupancy)";
|
||||
auto* split_k_arg =
|
||||
dynamic_cast<ck::tensor_operation::device::ArgumentSplitK*>(argument_ptr.get());
|
||||
if(split_k_arg)
|
||||
{
|
||||
split_k_value = split_k_arg->k_batch_;
|
||||
split_k_param_str = std::to_string(split_k_value) + " (best occupancy)";
|
||||
}
|
||||
else
|
||||
{
|
||||
// We may have an implementation whose argument is not derived from
|
||||
// ArgumentSplitK, which means we can not determine the splitK value. Warn.
|
||||
printf("Warning: Unable to determine split_k value for this instance!\n");
|
||||
}
|
||||
}
|
||||
|
||||
// Not all device implementation actually do anything with the passed split_k value but
|
||||
// it needs to be positive to determine error tolerances.
|
||||
if(split_k_value < 0)
|
||||
{
|
||||
split_k_value = 1;
|
||||
}
|
||||
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
@@ -297,12 +315,13 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
"Error: Incorrect results!",
|
||||
rtol,
|
||||
atol);
|
||||
std::cout << "Relative error threshold: " << rtol
|
||||
<< " Absolute error threshold: " << atol << std::endl;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl;
|
||||
std::cout << "Relative error threshold: " << rtol
|
||||
<< " Absolute error threshold: " << atol << std::endl;
|
||||
std::cout << "Fail info: splitK: " << split_k_value << " "
|
||||
<< op_ptr->GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
all_pass &= pass;
|
||||
@@ -330,6 +349,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
}
|
||||
}
|
||||
|
||||
printf("\033[36mvalids: %d\033[0m\n", num_kernel);
|
||||
|
||||
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
|
||||
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
|
||||
<< "\nGB/s: " << best_gb_per_sec << ", SplitK " << best_split_k << std::endl;
|
||||
|
||||
@@ -209,9 +209,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
|
||||
list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
|
||||
@@ -238,7 +235,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
@@ -251,6 +247,10 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
endif()
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
|
||||
@@ -5,16 +5,19 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance)
|
||||
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight_bilinear test_grouped_convnd_bwd_weight_bilinear.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_bilinear PRIVATE utility device_grouped_conv3d_bwd_weight_bilinear_instance)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight_scale test_grouped_convnd_bwd_weight_scale.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_scale PRIVATE utility device_grouped_conv3d_bwd_weight_scale_instance)
|
||||
|
||||
add_executable(test_grouped_convnd_bwd_weight_dataset_xdl test_grouped_convnd_bwd_weight_dataset_xdl.cpp)
|
||||
target_compile_options(test_grouped_convnd_bwd_weight_dataset_xdl PRIVATE -Wno-global-constructors -Wno-undef)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_dataset_xdl PRIVATE gtest_main getopt::getopt utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance)
|
||||
elseif(DL_KERNELS)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
|
||||
elseif(GPU_TARGETS MATCHES "gfx11")
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight_interface_xdl test_grouped_convnd_bwd_weight_interface_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_interface_xdl PRIVATE utility)
|
||||
@@ -27,7 +30,3 @@ add_gtest_executable(test_grouped_convnd_bwd_weight_interface_wmma test_grouped_
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_interface_wmma PRIVATE utility)
|
||||
endif()
|
||||
add_gtest_executable(test_grouped_conv_bwd_weight_xdl_bilinear test_grouped_conv_bwd_weight_xdl_bilinear.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_conv_bwd_weight_xdl_bilinear PRIVATE utility device_grouped_conv3d_bwd_weight_bilinear_instance)
|
||||
endif()
|
||||
|
||||
@@ -46,44 +46,6 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
// on gfx11 only support for 3d is implemented
|
||||
if constexpr(NDimSpatial{} != 3)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
// on gfx11 only support for i8 and fp16 is implemented
|
||||
if constexpr(!((std::is_same_v<InDataType, int8_t> &&
|
||||
std::is_same_v<WeiDataType, int8_t> &&
|
||||
std::is_same_v<OutDataType, int8_t>) ||
|
||||
(std::is_same_v<InDataType, ck::half_t> &&
|
||||
std::is_same_v<WeiDataType, ck::half_t> &&
|
||||
std::is_same_v<OutDataType, ck::half_t>)))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
// WMMA kernel is only supported for split_k=1
|
||||
if(split_k != 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
// Skip due to the lack of kernels for NGCDHW
|
||||
if constexpr(std::is_same_v<InLayout, NGCW> || std::is_same_v<InLayout, NGCHW> ||
|
||||
std::is_same_v<InLayout, NGCDHW>)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// support for i8 is only implemented on gfx11
|
||||
if constexpr(std::is_same_v<InDataType, int8_t> &&
|
||||
std::is_same_v<WeiDataType, int8_t> && std::is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -212,7 +212,34 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
}
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr});
|
||||
wei_device_buf.FromDevice(wei_device.mData.data());
|
||||
passed &= ck::utils::check_err(wei_device, wei_host, "Error: incorrect results!");
|
||||
|
||||
using AccDataType = float;
|
||||
float max_accumulated_value =
|
||||
*std::max_element(wei_host.mData.begin(), wei_host.mData.end());
|
||||
|
||||
const ck::index_t num_accums = out.GetElementSize() / conv_param.K_;
|
||||
const ck::index_t num_accums_split_k = split_k;
|
||||
double rtol =
|
||||
ck::utils::get_relative_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
num_accums / num_accums_split_k);
|
||||
double atol =
|
||||
ck::utils::get_absolute_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
max_accumulated_value / num_accums_split_k,
|
||||
num_accums / num_accums_split_k);
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
auto rtol_split_k =
|
||||
ck::utils::get_relative_threshold<WeiDataType, WeiDataType, WeiDataType>(
|
||||
num_accums_split_k);
|
||||
auto atol_split_k =
|
||||
ck::utils::get_absolute_threshold<WeiDataType, WeiDataType, WeiDataType>(
|
||||
max_accumulated_value, num_accums_split_k);
|
||||
// Use higher threshold
|
||||
rtol = std::max(rtol, rtol_split_k);
|
||||
atol = std::max(atol, atol_split_k);
|
||||
|
||||
passed &= ck::utils::check_err(
|
||||
wei_device, wei_host, "Error: incorrect results!", rtol, atol);
|
||||
|
||||
std::size_t flop =
|
||||
conv_param.GetFlops() +
|
||||
@@ -236,6 +263,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
std::cout << "grouped_conv_bwd_weight_instance (" << instance_index << "/" << num_kernel
|
||||
<< "): Passed" << std::endl;
|
||||
}
|
||||
printf("\033[36mvalids: %d\033[0m\n", num_kernel);
|
||||
return passed;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,294 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <typeinfo>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using InDataType = std::tuple_element_t<0, Tuple>;
|
||||
using WeiDataType = std::tuple_element_t<1, Tuple>;
|
||||
using OutDataType = std::tuple_element_t<2, Tuple>;
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr ck::index_t NDimSpatial = std::tuple_element_t<3, Tuple>{};
|
||||
static constexpr float alpha = 2.f;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
std::vector<ck::index_t> split_ks{1, 2};
|
||||
|
||||
void RunReference(ck::utils::conv::ConvParam& conv_param,
|
||||
ck::Tensor<InDataType>& in,
|
||||
ck::Tensor<WeiDataType>& wei_host,
|
||||
ck::Tensor<OutDataType>& out)
|
||||
{
|
||||
auto ref_conv =
|
||||
ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
0, /*Num A Elementwise Tensors*/
|
||||
0, /*Num B Elementwise Tensors*/
|
||||
0> /*Num D Elementwise Tensors*/
|
||||
{};
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(in,
|
||||
wei_host,
|
||||
out,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{alpha},
|
||||
OutElementOp{},
|
||||
{},
|
||||
{},
|
||||
{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
|
||||
bool PerformConvWeightScale(ck::utils::conv::ConvParam& conv_param, const ck::index_t split_k)
|
||||
{
|
||||
bool passed = true;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
ck::Tensor<InDataType> in(in_g_n_c_wis_desc);
|
||||
ck::Tensor<OutDataType> out(out_g_n_k_wos_desc);
|
||||
ck::Tensor<WeiDataType> wei_host(wei_g_k_c_xs_desc);
|
||||
ck::Tensor<WeiDataType> wei_device(wei_g_k_c_xs_desc);
|
||||
|
||||
std::cout << "in: " << in.mDesc << std::endl;
|
||||
std::cout << "wei: " << wei_host.mDesc << std::endl;
|
||||
std::cout << "out: " << out.mDesc << std::endl;
|
||||
|
||||
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
|
||||
out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
|
||||
ck::DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
|
||||
ck::DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
|
||||
ck::DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_device.mDesc.GetElementSpaceSize());
|
||||
in_device_buf.ToDevice(in.mData.data());
|
||||
wei_device_buf.ToDevice(wei_device.mData.data());
|
||||
out_device_buf.ToDevice(out.mData.data());
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_n_c_wis_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_n_c_wis_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e_g_k_c_xs_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e_g_k_c_xs_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_k_wos_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_k_wos_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
|
||||
|
||||
copy(in_g_n_c_wis_desc.GetLengths(), b_g_n_c_wis_lengths);
|
||||
copy(in_g_n_c_wis_desc.GetStrides(), b_g_n_c_wis_strides);
|
||||
copy(wei_g_k_c_xs_desc.GetLengths(), e_g_k_c_xs_lengths);
|
||||
copy(wei_g_k_c_xs_desc.GetStrides(), e_g_k_c_xs_strides);
|
||||
copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides);
|
||||
copy(conv_param.conv_filter_strides_, conv_filter_strides);
|
||||
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
RunReference(conv_param, in, wei_host, out);
|
||||
|
||||
using DeviceOp =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
ck::Tuple<>,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::Tuple<>,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
int num_kernel = 0;
|
||||
|
||||
for(std::size_t i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
std::array<const void*, 0>{},
|
||||
b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides,
|
||||
e_g_k_c_xs_lengths,
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{},
|
||||
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{},
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{alpha},
|
||||
OutElementOp{},
|
||||
split_k);
|
||||
|
||||
ck::DeviceMem workspace_buf(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_buf.GetDeviceBuffer());
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
num_kernel++;
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr});
|
||||
wei_device_buf.FromDevice(wei_device.mData.data());
|
||||
|
||||
using AccDataType = float;
|
||||
float max_accumulated_value =
|
||||
*std::max_element(wei_host.mData.begin(), wei_host.mData.end());
|
||||
|
||||
const ck::index_t num_accums = out.GetElementSize() / conv_param.K_;
|
||||
const ck::index_t num_accums_split_k = split_k;
|
||||
double rtol =
|
||||
ck::utils::get_relative_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
num_accums / num_accums_split_k);
|
||||
double atol =
|
||||
ck::utils::get_absolute_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
max_accumulated_value / num_accums_split_k,
|
||||
num_accums / num_accums_split_k);
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
auto rtol_split_k =
|
||||
ck::utils::get_relative_threshold<WeiDataType, WeiDataType, WeiDataType>(
|
||||
num_accums_split_k);
|
||||
auto atol_split_k =
|
||||
ck::utils::get_absolute_threshold<WeiDataType, WeiDataType, WeiDataType>(
|
||||
max_accumulated_value, num_accums_split_k);
|
||||
// Use higher threshold
|
||||
rtol = std::max(rtol, rtol_split_k);
|
||||
atol = std::max(atol, atol_split_k);
|
||||
|
||||
passed &= ck::utils::check_err(
|
||||
wei_device, wei_host, "Error: incorrect results!", rtol, atol);
|
||||
|
||||
std::size_t flop =
|
||||
conv_param.GetFlops() +
|
||||
3 * conv_param.GetOutputByte<WeiDataType>() / sizeof(WeiDataType);
|
||||
std::size_t num_bytes = conv_param.GetByte<InDataType, WeiDataType, OutDataType>() +
|
||||
conv_param.GetOutputByte<WeiDataType>();
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
printf("\033[36mvalids: %d\033[0m\n", num_kernel);
|
||||
return passed;
|
||||
}
|
||||
|
||||
void Run()
|
||||
{
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
|
||||
for(auto split_k : split_ks)
|
||||
{
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
pass = pass && PerformConvWeightScale(param, split_k);
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
using KernelTypes3d =
|
||||
::testing::Types<std::tuple<float, float, float, ck::Number<3>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, ck::Number<3>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, ck::Number<3>>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 16, 128, 128, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 2, 128, 128, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 128, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 4, 4, {3, 3, 3}, {14, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->Run();
|
||||
}
|
||||
Reference in New Issue
Block a user