This commit is contained in:
Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions

View File

@@ -0,0 +1,25 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
    # Common compile options for all grouped-convolution example targets.
    # -enable-noalias-to-md-conversion=0 disables the LLVM noalias-metadata
    # conversion for these kernels.
    set(EXAMPLE_CONV_COMPILE_OPTIONS)
    list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
    add_executable(tile_example_grouped_conv_fwd grouped_convolution_forward.cpp)
    target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
    add_executable(tile_example_grouped_conv_fwd_large_tensor grouped_convolution_forward_large_tensor.cpp)
    target_compile_options(tile_example_grouped_conv_fwd_large_tensor PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
    add_executable(tile_example_grouped_conv_fwd_bias_clamp grouped_convolution_forward_bias_clamp.cpp)
    # Fixed: previously referenced the undefined EXAMPLE_GEMM_COMPILE_OPTIONS,
    # which expanded to nothing and silently dropped the -mllvm workaround
    # for this target only.
    target_compile_options(tile_example_grouped_conv_fwd_bias_clamp PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
    add_executable(tile_example_grouped_conv_bwd_weight grouped_convolution_backward_weight.cpp)
    target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
    add_executable(tile_example_grouped_conv_bwd_weight_two_stage grouped_convolution_backward_weight_two_stage.cpp)
    target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
    add_executable(tile_example_grouped_conv_bwd_data grouped_convolution_backward_data.cpp)
    target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
endif()

View File

@@ -0,0 +1,349 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include <variant>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"
// Defaults shared by every example conv configuration below. Derived configs
// shadow individual members (tile sizes, Pipeline, Scheduler, ...) to override.
struct ConvConfigBase
{
    // Vector access sizes (in elements) for the A/B/C operands.
    static constexpr ck_tile::index_t VectorSizeA = 4;
    static constexpr ck_tile::index_t VectorSizeB = 8;
    static constexpr ck_tile::index_t VectorSizeC = 8;
    // Occupancy hint: blocks per compute unit, forwarded to make_kernel.
    static constexpr int kBlockPerCu = 1;
    // Default pipeline/scheduler pairing; overridden by several configs below.
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    // Number of convolution groups fused into a single implicit GEMM.
    static constexpr ck_tile::index_t NumGroupsToMerge = 1;
};
// MEMORY pipeline with the Interwave scheduler.
template <typename PrecType>
struct ConvConfigMemoryInterwave : public ConvConfigBase
{
    // Memory friendly for Interwave scheduler
    // Block tile (M/N/K); K_Tile scales with element size so the K tile
    // stays 128 bytes wide regardless of precision.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 32;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
    // Warps per block along M/N/K.
    static constexpr ck_tile::index_t M_Warp = 4;
    static constexpr ck_tile::index_t N_Warp = 1;
    static constexpr ck_tile::index_t K_Warp = 1;
    // Per-warp tile sizes.
    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
// MEMORY pipeline with the (inherited) Intrawave scheduler; same tiling as
// ConvConfigMemoryInterwave.
template <typename PrecType>
struct ConvConfigMemoryIntrawave : public ConvConfigBase
{
    // Block tile (M/N/K); K tile is 128 bytes wide for any precision.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 32;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
    // Warps per block along M/N/K.
    static constexpr ck_tile::index_t M_Warp = 4;
    static constexpr ck_tile::index_t N_Warp = 1;
    static constexpr ck_tile::index_t K_Warp = 1;
    // Per-warp tile sizes.
    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
};
// Small-tile COMPUTE_V3 config (default example config on non-WMMA targets).
template <typename PrecType>
struct ConvConfigComputeV3 : public ConvConfigBase
{
    // Compute V3 only support Intrawave scheduler
    // Block tile (M/N/K).
    static constexpr ck_tile::index_t M_Tile = 16;
    static constexpr ck_tile::index_t N_Tile = 64;
    static constexpr ck_tile::index_t K_Tile = 64;
    // Warps per block along M/N/K.
    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;
    // Per-warp tile sizes.
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 32;
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
};
// Large-tile COMPUTE_V3 variant (256x256 block tile, 2x2 warp grid).
template <typename PrecType>
struct ConvConfigComputeV3_1 : public ConvConfigBase
{
    // Block tile (M/N/K); K tile is 128 bytes wide for any precision.
    static constexpr ck_tile::index_t M_Tile = 256;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
    // Warps per block along M/N/K.
    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;
    // Per-warp tile sizes.
    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
};
// Mid-size COMPUTE_V3 variant (128x128) running 2 blocks per CU.
template <typename PrecType>
struct ConvConfigComputeV3_2 : public ConvConfigBase
{
    // Block tile (M/N/K); K tile is 128 bytes wide for any precision.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
    // Warps per block along M/N/K.
    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;
    // Per-warp tile sizes.
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 32;
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
    // Occupancy override: two blocks per CU instead of the base's one.
    static constexpr int kBlockPerCu = 2;
};
// COMPUTE_V3 config used on WMMA targets (selected under CK_TILE_USE_WMMA
// in the example mains).
template <typename PrecType>
struct ConvConfigComputeV3_WMMA : public ConvConfigBase
{
    // Block tile (M/N/K); K tile is 64 bytes wide for any precision.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
    // Warps per block along M/N/K.
    static constexpr ck_tile::index_t M_Warp = 4;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;
    // Per-warp tile sizes (16x16x16).
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
    // Occupancy override: two blocks per CU.
    static constexpr int kBlockPerCu = 2;
};
// COMPUTE_V4 config: double LDS buffering ("ping pong" at the LDS level).
template <typename PrecType>
struct ConvConfigComputeV4 : public ConvConfigBase
{
    // Compute V4 only support Intrawave scheduler
    // Using the ping pong reader in the lds level
    // Block tile (M/N/K); K tile is 64 bytes wide for any precision.
    static constexpr ck_tile::index_t M_Tile = 256;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
    // Warps per block along M/N/K.
    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;
    // Per-warp tile sizes.
    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
    // V4 requires the double shared-memory buffer.
    static constexpr bool DoubleSmemBuffer = true;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
};
// COMPUTE_V4 variant with a wider (128-byte) K tile.
template <typename PrecType>
struct ConvConfigComputeV4_1 : public ConvConfigBase
{
    // Block tile (M/N/K); K tile is 128 bytes wide for any precision.
    static constexpr ck_tile::index_t M_Tile = 256;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
    // Warps per block along M/N/K.
    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;
    // Per-warp tile sizes.
    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
    // V4 requires the double shared-memory buffer.
    static constexpr bool DoubleSmemBuffer = true;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
};
// COMPUTE_V5 config: single warp per M/N with K split across 2 warps and
// two wave groups.
template <typename PrecType>
struct ConvConfigComputeV5 : public ConvConfigBase
{
    // Block tile (M/N/K); K tile is 64 bytes wide for any precision.
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
    // Warps per block along M/N/K (note: K is split across warps here).
    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 1;
    static constexpr ck_tile::index_t K_Warp = 2;
    // Per-warp tile sizes.
    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
    // V5 runs with two wave groups (base default is one).
    static constexpr ck_tile::index_t NumWaveGroups = 2;
};
// COMPUTE_V6 config: large 256x256 tile with a fixed K tile of 32 elements.
template <typename PrecType>
struct ConvConfigComputeV6 : public ConvConfigBase
{
    // Block tile (M/N/K); K tile is a fixed element count here, not scaled
    // by sizeof(PrecType) like most other configs.
    static constexpr ck_tile::index_t M_Tile = 256;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 32;
    // Warps per block along M/N/K.
    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;
    // Per-warp tile sizes.
    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6;
    // Same as the base default; restated explicitly for this pipeline.
    static constexpr ck_tile::index_t NumWaveGroups = 1;
};
// Small-tile COMPUTE_V3 config that merges two convolution groups into one
// implicit GEMM (NumGroupsToMerge = 2).
template <typename PrecType>
struct ConvConfigComputeV3_merged_groups : public ConvConfigBase
{
    // NOTE(review): the VectorSizeA/B/C redeclarations were removed; they
    // duplicated the identical values (4/8/8) already inherited from
    // ConvConfigBase.
    // Block tile (M/N/K).
    static constexpr ck_tile::index_t M_Tile = 16;
    static constexpr ck_tile::index_t N_Tile = 32;
    static constexpr ck_tile::index_t K_Tile = 32;
    // Warps per block along M/N/K.
    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;
    // Per-warp tile sizes.
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 32;
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
    // Fuse two groups per GEMM (base default is one).
    static constexpr ck_tile::index_t NumGroupsToMerge = 2;
};
// Maps an element type (or In/Wei/Out triple) to the data-type bundle used by
// the examples. Only the specializations below are defined, so unsupported
// type combinations fail at compile time. Wei and Out default to In.
template <typename InDataType, typename WeiDataType = InDataType, typename OutDataType = InDataType>
struct ConvTypeConfig;
// fp16 in/wei/out with float accumulation.
template <>
struct ConvTypeConfig<ck_tile::half_t>
{
    using InDataType  = ck_tile::half_t;
    using WeiDataType = ck_tile::half_t;
    // Accumulate in float to preserve precision over the K reduction.
    using AccDataType = float;
    using OutDataType = ck_tile::half_t;
    // ToDo: Add more bias config to support different categories of GEMM.
};
// bf16 in/wei/out with float accumulation.
// Spelled with a single argument — the primary template's defaults fill in
// Wei/Out, naming the same specialization as <bf16_t, bf16_t, bf16_t> while
// matching the style of the fp16 specialization.
template <>
struct ConvTypeConfig<ck_tile::bf16_t>
{
    using InDataType  = ck_tile::bf16_t;
    using WeiDataType = ck_tile::bf16_t;
    // Accumulate in float to preserve precision over the K reduction.
    using AccDataType = float;
    using OutDataType = ck_tile::bf16_t;
};
// Selects the GEMM pipeline implementation for a ck_tile::GemmPipeline enum
// value. Each specialization exposes two alias templates: GemmPipeline (the
// conv-policy pipeline) and UniversalGemmPipeline (the base pipeline).
template <ck_tile::GemmPipeline PipelineId>
struct PipelineTypeTraits;
// BASIC_V1: A/B from global memory, C in registers (V1 pipeline).
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::BASIC_V1>
{
    template <typename PipelineProblem>
    using GemmPipeline =
        ck_tile::GemmPipelineAGmemBGmemCRegV1<PipelineProblem,
                                              ck_tile::GroupedConvUniversalPipelineAgBgCrPolicy>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
};
// BASIC_V2: A/B from global memory, C in registers (V2 pipeline).
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::BASIC_V2>
{
    template <typename PipelineProblem>
    using GemmPipeline =
        ck_tile::GemmPipelineAGmemBGmemCRegV2<PipelineProblem,
                                              ck_tile::GroupedConvUniversalPipelineAgBgCrPolicy>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAGmemBGmemCRegV2<PipelineProblem>;
};
// MEMORY: memory-bound oriented pipeline.
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
{
    template <typename PipelineProblem>
    using GemmPipeline =
        ck_tile::GemmPipelineAgBgCrMem<PipelineProblem,
                                       ck_tile::GroupedConvUniversalPipelineAgBgCrPolicy>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
// COMPUTE_V3: compute-bound oriented pipeline, version 3.
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
{
    template <typename PipelineProblem>
    using GemmPipeline =
        ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem,
                                          ck_tile::GroupedConvUniversalPipelineAgBgCrPolicy>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
// COMPUTE_V4: compute pipeline with LDS double buffering; uses the pipeline's
// default policy (no explicit conv policy argument).
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
// COMPUTE_V5: compute pipeline, version 5 (default policy).
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V5>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
};
// COMPUTE_V6: compute pipeline, version 6 (default policy).
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V6>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
    template <typename PipelineProblem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6<PipelineProblem>;
};

View File

@@ -0,0 +1,58 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_backward_data_invoker.hpp"
#include "run_grouped_convolution_bwd_data_example.inc"
// Parses the CLI arguments and dispatches the backward-data example to the
// implementation matching the requested precision ("prec").
// Returns the example's exit status, -1 on argument-parsing failure; throws
// std::runtime_error for an unsupported precision string.
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
{
    using Invoker = GroupedConvolutionBackwardDataInvoker;
    auto [parsed_ok, arg_parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        return -1;
    }
    // Read precision and tensor layouts from the parsed arguments.
    const std::string precision  = arg_parser.get_str("prec");
    const std::string in_layout  = arg_parser.get_str("in_layout");
    const std::string wei_layout = arg_parser.get_str("wei_layout");
    const std::string out_layout = arg_parser.get_str("out_layout");
    if(precision == "fp16")
    {
        return run_grouped_conv_bwd_data_example_prec_type<Invoker,
                                                           ConvConfig<ck_tile::half_t>,
                                                           ck_tile::half_t>(
            in_layout, wei_layout, out_layout, argc, argv);
    }
    if(precision == "bf16")
    {
        return run_grouped_conv_bwd_data_example_prec_type<Invoker,
                                                           ConvConfig<ck_tile::bf16_t>,
                                                           ck_tile::bf16_t>(
            in_layout, wei_layout, out_layout, argc, argv);
    }
    throw std::runtime_error("Unsupported data type for this operation!");
}
// Entry point for the backward-data example. Wraps the run in a try/catch —
// run_grouped_conv_bwd_data_example throws std::runtime_error for unsupported
// precisions/arguments; previously that exception escaped main and terminated
// the process. This matches the bwd-weight example mains.
int main(int argc, char* argv[])
{
    try
    {
#if CK_TILE_USE_WMMA
        return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
        return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3>(argc, argv);
#endif
    }
    catch(const std::runtime_error& e)
    {
        std::cerr << "Runtime error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
}

View File

@@ -0,0 +1,138 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "grouped_convolution_utils.hpp"
// Host-side helper that assembles the full ck_tile type stack for a grouped
// convolution backward-data kernel (implicit-GEMM formulation) and launches it.
struct GroupedConvolutionBackwardDataInvoker
{
    // Builds traits -> tile partitioner -> pipeline -> epilogue -> kernel from
    // ConvConfig, validates the arguments, and times the launch.
    //
    // Operand roles for bwd-data (per the UniversalGemmProblem instantiation
    // below): A = OutDataType, B = WeiDataType, C/E = InDataType — i.e. the
    // kernel writes the input gradient.
    //
    // Returns the time reported by ck_tile::launch_kernel_time_mask.
    // Throws std::runtime_error if Kernel::IsSupportedArgument rejects kargs.
    template <ck_tile::index_t NDimSpatial,
              typename ConvConfig,
              typename InDataType,
              typename WeiDataType,
              typename AccDataType,
              typename OutDataType,
              typename InLayout,
              typename WeiLayout,
              typename OutLayout,
              typename DsDataType = ck_tile::tuple<>,
              typename DsLayout = ck_tile::tuple<>,
              typename CDEElementWise = ck_tile::element_wise::PassThrough>
    static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
                                       const ck_tile::stream_config& s)
    {
        // Implicit GEMM Traits
        // Block / warp / warp-tile shape taken from the ConvConfig constants.
        using GemmShape = ck_tile::TileGemmShape<
            ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
            ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
            ck_tile::sequence<ConvConfig::M_Warp_Tile,
                              ConvConfig::N_Warp_Tile,
                              ConvConfig::K_Warp_Tile>>;
        constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
        // Conv-specific traits: layouts, vector sizes and fixed GEMM parameters.
        using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
                                                                 ConvSpec,
                                                                 InLayout,
                                                                 WeiLayout,
                                                                 DsLayout,
                                                                 OutLayout,
                                                                 ConvConfig::VectorSizeA,
                                                                 ConvConfig::VectorSizeB,
                                                                 ConvConfig::VectorSizeC>;
        using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
            GemmShape,
            GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
            GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
        // Note the BwdData layout aliases — bwd-data uses its own A/B/C layouts.
        using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
            GroupedConvTraitsType::FixedGemmParams::kPadM,
            GroupedConvTraitsType::FixedGemmParams::kPadN,
            GroupedConvTraitsType::FixedGemmParams::kPadK,
            ConvConfig::DoubleSmemBuffer,
            typename GroupedConvTraitsType::AsLayoutBwdData,
            typename GroupedConvTraitsType::BsLayoutBwdData,
            typename GroupedConvTraitsType::CLayoutBwdData,
            GroupedConvTraitsType::FixedGemmParams::TransposeC,
            GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
            GroupedConvTraitsType::FixedGemmParams::Persistent,
            ConvConfig::NumWaveGroups>;
        constexpr auto scheduler = ConvConfig::Scheduler;
        // A = Out (gradient), B = Wei, C = In (gradient) for backward-data.
        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
            OutDataType,
            WeiDataType,
            AccDataType,
            GemmShape,
            GemmUniversalTraits,
            scheduler,
            ck_tile::element_wise::PassThrough,
            ck_tile::element_wise::PassThrough,
            InDataType,
            GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
            GroupedConvTraitsType::VectorSizeA,
            GroupedConvTraitsType::VectorSizeB>;
        // Pipeline implementation chosen by the ConvConfig::Pipeline enum.
        using GemmPipeline = typename PipelineTypeTraits<
            ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
        // C-shuffle epilogue writing the InDataType result (input gradient).
        using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
            OutDataType,
            WeiDataType,
            DsDataType,
            AccDataType,
            InDataType,
            typename GroupedConvTraitsType::ImplicitGemmDsLayout,
            typename GroupedConvTraitsType::FixedGemmParams::ELayout,
            CDEElementWise,
            TilePartitioner::MPerBlock,
            TilePartitioner::NPerBlock,
            ConvConfig::M_Warp,
            ConvConfig::N_Warp,
            ConvConfig::M_Warp_Tile,
            ConvConfig::N_Warp_Tile,
            ConvConfig::K_Warp_Tile,
            GroupedConvTraitsType::FixedGemmParams::TransposeC,
            ConvConfig::NumWaveGroups,
            GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
            GroupedConvTraitsType::VectorSizeC>>;
        using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
                                                                     TilePartitioner,
                                                                     GemmPipeline,
                                                                     ConvEpilogue>;
        auto kargs = Kernel::MakeKernelArgs(args);
        const dim3 grids = Kernel::GridSize(args);
        const dim3 blocks = Kernel::BlockSize();
        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
        }
        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                      << "shape: " << GemmShape::GetName() << '\n'
                      << "problem: " << UniversalGemmProblem::GetName() << '\n'
                      << "pipeline: " << GemmPipeline::GetName() << '\n'
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << '\n'
                      << "Vector size A: " << GemmPipeline::GetVectorSizeA()
                      << ", Vector size B: " << GemmPipeline::GetVectorSizeB()
                      << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
        }
        // Zero the destination (input-gradient) buffer before each timed
        // launch — presumably because the kernel accumulates into it; confirm
        // against the kernel implementation.
        auto preprocess = [&]() {
            ck_tile::hip_check_error(hipMemsetAsync(
                kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
        };
        return ck_tile::launch_kernel_time_mask(
            s,
            preprocess,
            ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
    }
};

View File

@@ -0,0 +1,67 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_backward_weight_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
// Dispatches the backward-weight example to the implementation matching the
// requested precision ("prec").
// Returns the example's exit status; throws std::runtime_error for an
// unsupported precision string.
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
    using Invoker = GroupedConvolutionBackwardWeightInvoker;
    // Read precision and tensor layouts from the parsed arguments.
    const std::string precision  = arg_parser.get_str("prec");
    const std::string in_layout  = arg_parser.get_str("in_layout");
    const std::string wei_layout = arg_parser.get_str("wei_layout");
    const std::string out_layout = arg_parser.get_str("out_layout");
    if(precision == "fp16")
    {
        return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
                                                             ConvConfig<ck_tile::half_t>,
                                                             ck_tile::half_t>(
            in_layout, wei_layout, out_layout, arg_parser);
    }
    if(precision == "bf16")
    {
        return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
                                                             ConvConfig<ck_tile::bf16_t>,
                                                             ck_tile::bf16_t>(
            in_layout, wei_layout, out_layout, arg_parser);
    }
    throw std::runtime_error("Unsupported data type for this operation!");
}
// Entry point for the backward-weight example: parse the CLI once, then run
// the example inside a try/catch so unsupported configurations exit cleanly.
int main(int argc, char* argv[])
{
    auto [parse_ok, arg_parser] = create_args(argc, argv);
    if(!parse_ok)
    {
        return -1;
    }
    try
    {
#if CK_TILE_USE_WMMA
        return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_WMMA>(arg_parser);
#else
        return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3>(arg_parser);
#endif
    }
    catch(const std::runtime_error& e)
    {
        std::cerr << "Runtime error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
}

View File

@@ -0,0 +1,143 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "grouped_convolution_utils.hpp"
// Host-side helper that assembles the full ck_tile type stack for a grouped
// convolution backward-weight kernel (implicit-GEMM formulation) and launches it.
struct GroupedConvolutionBackwardWeightInvoker
{
    // Builds traits -> tile partitioner -> pipeline -> epilogue -> kernel from
    // ConvConfig, validates the arguments, and times the launch.
    //
    // Operand roles for bwd-weight (per the UniversalGemmProblem instantiation
    // below): A = OutDataType, B = InDataType, C/E = WeiDataType — i.e. the
    // kernel writes the weight gradient.
    //
    // Returns an InvokerResult carrying the measured time and args.k_batch.
    // Throws std::runtime_error if Kernel::IsSupportedArgument rejects kargs.
    template <ck_tile::index_t NDimSpatial,
              typename ConvConfig,
              typename InDataType,
              typename WeiDataType,
              typename AccDataType,
              typename OutDataType,
              typename InLayout,
              typename WeiLayout,
              typename OutLayout,
              typename DsDataType = ck_tile::tuple<>,
              typename DsLayout = ck_tile::tuple<>,
              typename CDEElementWise = ck_tile::element_wise::PassThrough>
    static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
                                                 const ck_tile::stream_config& s)
    {
        // Implicit GEMM Traits
        // Block / warp / warp-tile shape taken from the ConvConfig constants.
        using GemmShape = ck_tile::TileGemmShape<
            ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
            ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
            ck_tile::sequence<ConvConfig::M_Warp_Tile,
                              ConvConfig::N_Warp_Tile,
                              ConvConfig::K_Warp_Tile>>;
        constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
        // Conv traits; unlike bwd-data, this also forwards NumGroupsToMerge.
        using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
                                                                 ConvSpec,
                                                                 InLayout,
                                                                 WeiLayout,
                                                                 DsLayout,
                                                                 OutLayout,
                                                                 ConvConfig::VectorSizeA,
                                                                 ConvConfig::VectorSizeB,
                                                                 ConvConfig::VectorSizeC,
                                                                 ConvConfig::NumGroupsToMerge>;
        using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
            GemmShape,
            GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
            GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
        // Note the BwdWeight layout aliases — bwd-weight has its own A/B/C layouts.
        using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
            GroupedConvTraitsType::FixedGemmParams::kPadM,
            GroupedConvTraitsType::FixedGemmParams::kPadN,
            GroupedConvTraitsType::FixedGemmParams::kPadK,
            ConvConfig::DoubleSmemBuffer,
            typename GroupedConvTraitsType::AsLayoutBwdWeight,
            typename GroupedConvTraitsType::BsLayoutBwdWeight,
            typename GroupedConvTraitsType::CLayoutBwdWeight,
            GroupedConvTraitsType::FixedGemmParams::TransposeC,
            GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
            GroupedConvTraitsType::FixedGemmParams::Persistent,
            ConvConfig::NumWaveGroups>;
        constexpr auto scheduler = ConvConfig::Scheduler;
        // A = Out (gradient), B = In, C = Wei (gradient) for backward-weight.
        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
            OutDataType,
            InDataType,
            AccDataType,
            GemmShape,
            GemmUniversalTraits,
            scheduler,
            ck_tile::element_wise::PassThrough,
            ck_tile::element_wise::PassThrough,
            WeiDataType,
            GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
            GroupedConvTraitsType::VectorSizeA,
            GroupedConvTraitsType::VectorSizeB>;
        // Pipeline implementation chosen by the ConvConfig::Pipeline enum.
        using GemmPipeline = typename PipelineTypeTraits<
            ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
        // C-shuffle epilogue writing the WeiDataType result (weight gradient).
        using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
            OutDataType,
            InDataType,
            DsDataType,
            AccDataType,
            WeiDataType,
            typename GroupedConvTraitsType::ImplicitGemmDsLayout,
            typename GroupedConvTraitsType::FixedGemmParams::ELayout,
            CDEElementWise,
            TilePartitioner::MPerBlock,
            TilePartitioner::NPerBlock,
            ConvConfig::M_Warp,
            ConvConfig::N_Warp,
            ConvConfig::M_Warp_Tile,
            ConvConfig::N_Warp_Tile,
            ConvConfig::K_Warp_Tile,
            GroupedConvTraitsType::FixedGemmParams::TransposeC,
            ConvConfig::NumWaveGroups,
            GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
            GroupedConvTraitsType::VectorSizeC>>;
        using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
                                                                       TilePartitioner,
                                                                       GemmPipeline,
                                                                       ConvEpilogue>;
        auto kargs = Kernel::MakeKernelArgs(args);
        const dim3 grids = Kernel::GridSize(args);
        const dim3 blocks = Kernel::BlockSize();
        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
        }
        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                      << "shape: " << GemmShape::GetName() << '\n'
                      << "problem: " << UniversalGemmProblem::GetName() << '\n'
                      << "pipeline: " << GemmPipeline::GetName() << '\n'
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << '\n'
                      << "Vector size A: " << GemmPipeline::GetVectorSizeA()
                      << ", Vector size B: " << GemmPipeline::GetVectorSizeB()
                      << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
        }
        // With split-K (k_batch > 1) the weight buffer is zeroed before each
        // timed launch — NOTE(review): presumably because split-K batches
        // accumulate into it; confirm against the kernel implementation.
        auto preprocess = [&]() {
            if(args.k_batch > 1)
            {
                ck_tile::hip_check_error(hipMemsetAsync(
                    kargs.wei_ptr, 0, args.template GetWeightByte<WeiDataType>(), s.stream_id_));
            }
        };
        float ave_time = ck_tile::launch_kernel_time_mask(
            s,
            preprocess,
            ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        return InvokerResult{ave_time, args.k_batch};
    }
};

View File

@@ -0,0 +1,68 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_backward_weight_two_stage_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
#include "conv_configs.hpp"
// Dispatches the two-stage backward-weight example to the implementation
// matching the requested precision ("prec").
// Returns the example's exit status; throws std::runtime_error for an
// unsupported precision string (caught in main below).
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
    // Two-stage variant: workspace GEMM followed by an elementwise conversion.
    using Invoker = GroupedConvolutionBackwardWeightTwoStageInvoker;
    // Read precision and tensor layouts from the parsed arguments.
    std::string data_type = arg_parser.get_str("prec");
    std::string in_layout = arg_parser.get_str("in_layout");
    std::string wei_layout = arg_parser.get_str("wei_layout");
    std::string out_layout = arg_parser.get_str("out_layout");
    if(data_type == "fp16")
    {
        return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
                                                             ConvConfig<ck_tile::half_t>,
                                                             ck_tile::half_t>(
            in_layout, wei_layout, out_layout, arg_parser);
    }
    else if(data_type == "bf16")
    {
        return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
                                                             ConvConfig<ck_tile::bf16_t>,
                                                             ck_tile::bf16_t>(
            in_layout, wei_layout, out_layout, arg_parser);
    }
    else
    {
        throw std::runtime_error("Unsupported data type for this operation!");
    }
}
// Entry point for the two-stage backward-weight example: parse the CLI once,
// then run the example inside a try/catch so unsupported configurations exit
// cleanly with EXIT_FAILURE.
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;
    try
    {
// Pick the config matching the GPU's matrix-instruction flavor.
#if CK_TILE_USE_WMMA
        return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_WMMA>(arg_parser);
#else
        return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3>(arg_parser);
#endif
    }
    catch(const std::runtime_error& e)
    {
        std::cerr << "Runtime error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
}

View File

@@ -0,0 +1,207 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "grouped_convolution_utils.hpp"
struct GroupedConvolutionBackwardWeightTwoStageInvoker
{
template <ck_tile::index_t NDimSpatial,
typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
static InvokerResult grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
const ck_tile::stream_config& s)
{
using WorkspaceDataType = float;
// Force Vector Size C to 1 for two stage to check main
// two stage use case
constexpr ck_tile::index_t VectorSizeC = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::sequence<ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
constexpr auto scheduler = ConvConfig::Scheduler;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType, // A: Out
InDataType, // B: In
DsDataType,
AccDataType,
WorkspaceDataType, // C: Workspace normally Out
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
const ck_tile::index_t spatial_lengths_accum =
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<ck_tile::index_t>());
ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum *
sizeof(WorkspaceDataType));
ck_tile::GroupedConvBwdWeightHostArgs ws_args = ck_tile::GroupedConvBwdWeightHostArgs(args);
auto c_ptr = ws_args.wei_ptr;
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
const auto kargs = Kernel::MakeKernelArgs(ws_args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
using BlockTile = ck_tile::sequence<2048>;
using BlockWarps = ck_tile::sequence<8>;
using WarpTile = ck_tile::sequence<64>;
using ElementwiseShape =
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceDataType,
WorkspaceDataType,
WeiDataType,
ElementwiseShape,
XElementwiseOperation>;
using ElementwiseKernel =
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
ck_tile::index_t total_elements = 1;
std::vector<ck_tile::index_t> shape = {
static_cast<ck_tile::index_t>(args.G_ * args.K_),
static_cast<ck_tile::index_t>(args.C_ * spatial_lengths_accum)};
for(auto d : shape)
total_elements *= d;
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
auto input_size = ck_tile::make_tuple(shape[0], shape[1]);
// Check if the kernel configuration is supported
if(!ElementwiseKernel::IsSupportedArgument(input_size))
{
throw std::runtime_error(
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
auto preprocess = [&]() {
if(args.k_batch > 1)
ck_tile::hip_check_error(
hipMemsetAsync(ws_args.wei_ptr,
0,
shape[0] * shape[1] * sizeof(WorkspaceDataType),
s.stream_id_));
};
float ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
return InvokerResult{ave_time, kargs.k_batch};
}
};

View File

@@ -0,0 +1,66 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_forward_invoker.hpp"
#include "run_grouped_convolution_fwd_example.inc"
// Parses the command line, then dispatches the forward-conv example to the
// requested precision. Returns the example's exit code, -1 on a bad command
// line; throws std::runtime_error for an unsupported precision string.
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_fwd_example(int argc, char* argv[])
{
    auto [parsed_ok, parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        return -1;
    }

    const std::string precision  = parser.get_str("prec");
    const std::string layout_in  = parser.get_str("in_layout");
    const std::string layout_wei = parser.get_str("wei_layout");
    const std::string layout_out = parser.get_str("out_layout");

    using Invoker = GroupedConvolutionForwardInvoker;

    if(precision == "fp16")
    {
        return run_grouped_conv_fwd_example_prec_type<Invoker,
                                                      ConvConfig<ck_tile::half_t>,
                                                      ck_tile::half_t>(
            layout_in, layout_wei, layout_out, argc, argv);
    }
    if(precision == "bf16")
    {
        return run_grouped_conv_fwd_example_prec_type<Invoker,
                                                      ConvConfig<ck_tile::bf16_t>,
                                                      ck_tile::bf16_t>(
            layout_in, layout_wei, layout_out, argc, argv);
    }
    throw std::runtime_error("Unsupported data type for this operation !!!");
}
// Entry point: run the example and fold its status into the process exit code.
// Catch the full std::exception hierarchy (not just std::runtime_error) so that
// e.g. std::bad_alloc or argument-parsing exceptions are reported cleanly
// instead of calling std::terminate.
int main(int argc, char* argv[])
{
    try
    {
#if CK_TILE_USE_WMMA
        return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
        return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
#endif
    }
    catch(const std::exception& e)
    {
        std::cerr << "Runtime error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
}

View File

@@ -0,0 +1,58 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_forward_invoker.hpp"
#include "run_grouped_convolution_fwd_bias_clamp_example.inc"
// Parses the command line and dispatches the bias+clamp forward-conv example
// to the requested precision. Returns the example's exit code, -1 on a bad
// command line; throws std::runtime_error for an unsupported precision string.
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_fwd_bias_clamp_example(int argc, char* argv[])
{
    auto [parsed_ok, parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        return -1;
    }

    const std::string precision  = parser.get_str("prec");
    const std::string layout_in  = parser.get_str("in_layout");
    const std::string layout_wei = parser.get_str("wei_layout");
    const std::string layout_out = parser.get_str("out_layout");

    using Invoker = GroupedConvolutionForwardInvoker;

    if(precision == "fp16")
    {
        return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
                                                                 ConvConfig<ck_tile::half_t>,
                                                                 ck_tile::half_t>(
            layout_in, layout_wei, layout_out, argc, argv);
    }
    if(precision == "bf16")
    {
        return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
                                                                 ConvConfig<ck_tile::bf16_t>,
                                                                 ck_tile::bf16_t>(
            layout_in, layout_wei, layout_out, argc, argv);
    }
    throw std::runtime_error("Unsupported data type for this operation !!!");
}
// Entry point. The example throws std::runtime_error (e.g. for unsupported
// precisions or unsupported kernel arguments); without a handler that would
// call std::terminate. Wrap the call so errors are reported and turned into a
// clean non-zero exit code instead.
int main(int argc, char* argv[])
{
    try
    {
#if CK_TILE_USE_WMMA
        return !run_grouped_conv_fwd_bias_clamp_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
        return !run_grouped_conv_fwd_bias_clamp_example<ConvConfigComputeV3>(argc, argv);
#endif
    }
    catch(const std::exception& e)
    {
        std::cerr << "Runtime error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
}

View File

@@ -0,0 +1,143 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Regular grouped convolution invoker (no split-image)
// This invoker demonstrates regular convolution without split-image.
// It always uses Kernel<false> (split-image disabled).
// For large images that require split-image, use
// grouped_convolution_forward_split_image_invoker.hpp
#pragma once
#include "grouped_convolution_utils.hpp"
// Regular grouped convolution forward invoker (no split-image).
//
// Maps the convolution onto ck_tile's implicit-GEMM stack:
//   GroupedConvTraits -> GemmSpatiallyLocalTilePartitioner
//   -> universal GEMM pipeline -> CShuffle epilogue
//   -> GroupedConvolutionForwardKernel.
struct GroupedConvolutionForwardInvoker
{
    // Instantiates the kernel types for the given config/layouts, validates the
    // runtime arguments, and launches the kernel.
    //
    // Template parameters:
    //   NDimSpatial              - number of spatial dimensions (1, 2 or 3)
    //   ConvConfig               - compile-time tile/warp/pipeline configuration
    //   InDataType / WeiDataType - input / weight element types (GEMM A / B operands)
    //   AccDataType              - accumulator element type
    //   OutDataType              - output element type (GEMM C operand)
    //   In/Wei/OutLayout         - tensor memory layouts
    //   DsDataType / DsLayout    - optional fused D-tensor operands (empty by default)
    //   CDElementWise            - elementwise op fused into the epilogue
    //
    // Returns the time reported by ck_tile::launch_kernel (consumed as
    // milliseconds by the callers). Throws std::runtime_error when the runtime
    // arguments are rejected by the instantiated kernel.
    template <ck_tile::index_t NDimSpatial,
              typename ConvConfig,
              typename InDataType,
              typename WeiDataType,
              typename AccDataType,
              typename OutDataType,
              typename InLayout,
              typename WeiLayout,
              typename OutLayout,
              typename DsDataType    = ck_tile::tuple<>,
              typename DsLayout      = ck_tile::tuple<>,
              typename CDElementWise = ck_tile::element_wise::PassThrough>
    static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDElementWise>& args,
                                  const ck_tile::stream_config& s)
    {
        // Implicit GEMM Traits: block tile, warp grid, and per-warp tile sizes.
        using GemmShape = ck_tile::TileGemmShape<
            ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
            ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
            ck_tile::sequence<ConvConfig::M_Warp_Tile,
                              ConvConfig::N_Warp_Tile,
                              ConvConfig::K_Warp_Tile>>;
        constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
        // Conv-specific traits: layouts, vector access widths, group merging.
        using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
                                                                 ConvSpec,
                                                                 InLayout,
                                                                 WeiLayout,
                                                                 DsLayout,
                                                                 OutLayout,
                                                                 ConvConfig::VectorSizeA,
                                                                 ConvConfig::VectorSizeB,
                                                                 ConvConfig::VectorSizeC,
                                                                 ConvConfig::NumGroupsToMerge>;
        using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
            GemmShape,
            GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
            GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
        using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
            GroupedConvTraitsType::FixedGemmParams::kPadM,
            GroupedConvTraitsType::FixedGemmParams::kPadN,
            GroupedConvTraitsType::FixedGemmParams::kPadK,
            ConvConfig::DoubleSmemBuffer,
            typename GroupedConvTraitsType::AsLayoutFwd,
            typename GroupedConvTraitsType::BsLayoutFwd,
            typename GroupedConvTraitsType::CLayoutFwd,
            GroupedConvTraitsType::FixedGemmParams::TransposeC,
            GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
            GroupedConvTraitsType::FixedGemmParams::Persistent,
            ConvConfig::NumWaveGroups>;
        constexpr auto scheduler = ConvConfig::Scheduler;
        // =====================================================================
        // Regular Convolution: Simple, no split-image
        // =====================================================================
        // GEMM problem: A = input, B = weight, C = output.
        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
            InDataType,
            WeiDataType,
            AccDataType,
            GemmShape,
            GemmUniversalTraits,
            scheduler,
            ck_tile::element_wise::PassThrough,
            ck_tile::element_wise::PassThrough,
            OutDataType,
            GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
            GroupedConvTraitsType::VectorSizeA,
            GroupedConvTraitsType::VectorSizeB>;
        using GemmPipeline = typename PipelineTypeTraits<
            ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
        // Epilogue: CShuffle writes C (optionally fusing Ds and CDElementWise).
        using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
            InDataType,
            WeiDataType,
            DsDataType,
            AccDataType,
            OutDataType,
            typename GroupedConvTraitsType::ImplicitGemmDsLayout,
            typename GroupedConvTraitsType::FixedGemmParams::ELayout,
            CDElementWise,
            TilePartitioner::MPerBlock,
            TilePartitioner::NPerBlock,
            ConvConfig::M_Warp,
            ConvConfig::N_Warp,
            ConvConfig::M_Warp_Tile,
            ConvConfig::N_Warp_Tile,
            ConvConfig::K_Warp_Tile,
            GroupedConvTraitsType::FixedGemmParams::TransposeC,
            ConvConfig::NumWaveGroups,
            GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
            GroupedConvTraitsType::VectorSizeC>>;
        using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
                                                                TilePartitioner,
                                                                GemmPipeline,
                                                                ConvEpilogue>;
        auto kargs        = Kernel::MakeKernelArgs(args);
        const dim3 grids  = Kernel::GridSize(kargs);
        const dim3 blocks = Kernel::BlockSize();
        // Reject problem shapes/strides the instantiated kernel cannot handle.
        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
        }
        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                      << "shape: " << GemmShape::GetName() << '\n'
                      << "problem: " << UniversalGemmProblem::GetName() << '\n'
                      << "pipeline: " << GemmPipeline::GetName() << '\n'
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << '\n'
                      << "Vector size A: " << GemmPipeline::GetVectorSizeA()
                      << ", Vector size B: " << GemmPipeline::GetVectorSizeB()
                      << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
        }
        return ck_tile::launch_kernel(
            s, ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
    }
};

View File

@@ -0,0 +1,63 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Large tensor grouped convolution example
// This example demonstrates convolution for large tensors that exceed memory limits.
// It uses automatic tensor splitting when needed to handle large images.
// For regular convolution without tensor splitting, use grouped_convolution_forward.cpp
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_forward_large_tensor_invoker.hpp"
#include "run_grouped_convolution_fwd_example.inc"
// Parses the command line and dispatches the large-tensor forward-conv example
// to the requested precision. Returns the example's exit code, -1 on a bad
// command line; throws std::runtime_error for an unsupported precision string.
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_fwd_example(int argc, char* argv[])
{
    auto [parsed_ok, parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        return -1;
    }

    const std::string precision  = parser.get_str("prec");
    const std::string layout_in  = parser.get_str("in_layout");
    const std::string layout_wei = parser.get_str("wei_layout");
    const std::string layout_out = parser.get_str("out_layout");

    using Invoker = GroupedConvolutionForwardInvoker;

    if(precision == "fp16")
    {
        return run_grouped_conv_fwd_example_prec_type<Invoker,
                                                      ConvConfig<ck_tile::half_t>,
                                                      ck_tile::half_t>(
            layout_in, layout_wei, layout_out, argc, argv);
    }
    if(precision == "bf16")
    {
        return run_grouped_conv_fwd_example_prec_type<Invoker,
                                                      ConvConfig<ck_tile::bf16_t>,
                                                      ck_tile::bf16_t>(
            layout_in, layout_wei, layout_out, argc, argv);
    }
    throw std::runtime_error("Unsupported data type for this operation !!!");
}
// Entry point. The example throws std::runtime_error (e.g. for unsupported
// precisions or unsupported kernel arguments); without a handler that would
// call std::terminate. Wrap the call so errors are reported and turned into a
// clean non-zero exit code instead.
int main(int argc, char* argv[])
{
    try
    {
#if CK_TILE_USE_WMMA
        return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
        return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
#endif
    }
    catch(const std::exception& e)
    {
        std::cerr << "Runtime error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
}

View File

@@ -0,0 +1,340 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "grouped_convolution_utils.hpp"
// Large-tensor grouped convolution forward invoker.
//
// At runtime this invoker chooses between two instantiations of the same
// kernel template:
//   * the regular kernel (EnableSplitImage = false), and
//   * a split-image kernel (EnableSplitImage = true) that processes the output
//     image in spatial pieces when the tensor is large.
// The decision is made from the layout (only NWGC/NHWGC/NDHWGC inputs support
// split-image) and from TransformConvFwdToGemm::GetSplitImageInfo.
struct GroupedConvolutionForwardInvoker
{
    // Instantiates and launches the forward convolution kernel, optionally in
    // split-image mode.
    //
    // Template parameters:
    //   NDimSpatial              - number of spatial dimensions (1, 2 or 3)
    //   ConvConfig               - compile-time tile/warp/pipeline configuration
    //   InDataType / WeiDataType - input / weight element types (GEMM A / B operands)
    //   AccDataType              - accumulator element type
    //   OutDataType              - output element type (GEMM C operand)
    //   In/Wei/OutLayout         - tensor memory layouts
    //   DsDataType / DsLayout    - optional fused D-tensor operands (empty by default)
    //   CDEElementWise           - elementwise op fused into the epilogue
    //
    // Returns the time reported by ck_tile::launch_kernel (consumed as
    // milliseconds by the callers). Throws std::runtime_error when the runtime
    // arguments are rejected by the instantiated kernel.
    template <ck_tile::index_t NDimSpatial,
              typename ConvConfig,
              typename InDataType,
              typename WeiDataType,
              typename AccDataType,
              typename OutDataType,
              typename InLayout,
              typename WeiLayout,
              typename OutLayout,
              typename DsDataType     = ck_tile::tuple<>,
              typename DsLayout       = ck_tile::tuple<>,
              typename CDEElementWise = ck_tile::element_wise::PassThrough>
    static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDEElementWise>& args,
                                  const ck_tile::stream_config& s)
    {
        // Implicit GEMM Traits: block tile, warp grid, and per-warp tile sizes.
        using GemmShape = ck_tile::TileGemmShape<
            ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
            ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
            ck_tile::sequence<ConvConfig::M_Warp_Tile,
                              ConvConfig::N_Warp_Tile,
                              ConvConfig::K_Warp_Tile>>;
        constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
        // Traits for the regular (single-pass) kernel.
        using GroupedConvTraitsTypeDefault =
            ck_tile::GroupedConvTraits<NDimSpatial,
                                       ConvSpec,
                                       InLayout,
                                       WeiLayout,
                                       DsLayout,
                                       OutLayout,
                                       ConvConfig::VectorSizeA,
                                       ConvConfig::VectorSizeB,
                                       ConvConfig::VectorSizeC,
                                       ConvConfig::NumGroupsToMerge>;
        // Same traits with split-image enabled for the large-tensor path.
        using GroupedConvTraitsTypeLargeTensor =
            ck_tile::GroupedConvTraits<NDimSpatial,
                                       ConvSpec,
                                       InLayout,
                                       WeiLayout,
                                       DsLayout,
                                       OutLayout,
                                       ConvConfig::VectorSizeA,
                                       ConvConfig::VectorSizeB,
                                       ConvConfig::VectorSizeC,
                                       ConvConfig::NumGroupsToMerge,
                                       true /*EnableSplitImage*/>;
        // Partitioner/traits are shared by both paths (built from the default traits).
        using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
            GemmShape,
            GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerGroupNum,
            GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerM01>;
        using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
            GroupedConvTraitsTypeDefault::FixedGemmParams::kPadM,
            GroupedConvTraitsTypeDefault::FixedGemmParams::kPadN,
            GroupedConvTraitsTypeDefault::FixedGemmParams::kPadK,
            ConvConfig::DoubleSmemBuffer,
            typename GroupedConvTraitsTypeDefault::AsLayoutFwd,
            typename GroupedConvTraitsTypeDefault::BsLayoutFwd,
            typename GroupedConvTraitsTypeDefault::CLayoutFwd,
            GroupedConvTraitsTypeDefault::FixedGemmParams::TransposeC,
            GroupedConvTraitsTypeDefault::FixedGemmParams::UseStructuredSparsity,
            GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent,
            ConvConfig::NumWaveGroups>;
        // Transform used only to query split-image info on the host side.
        using TransformType =
            ck_tile::TransformConvFwdToGemm<NDimSpatial,
                                            ck_tile::ConvolutionSpecialization::Default,
                                            GroupedConvTraitsTypeDefault::VectorSizeA,
                                            GroupedConvTraitsTypeDefault::VectorSizeB,
                                            GroupedConvTraitsTypeDefault::VectorSizeC,
                                            1, // NumGroupsToMerge
                                            false, // SplitN
                                            InDataType,
                                            OutDataType>;
        // =====================================================================
        // Step 1: Check if layout supports split-image kernel
        // =====================================================================
        // Split-image requires specific memory layouts:
        // 1D: NWGC (input), GKXC (weight), NWGK (output)
        // 2D: NHWGC (input), GKYXC (weight), NHWGK (output)
        // 3D: NDHWGC (input), GKZYXC (weight), NDHWGK (output)
        constexpr bool is_supported_layout =
            std::is_same<InLayout, ck_tile::tensor_layout::convolution::NWGC>::value ||
            std::is_same<InLayout, ck_tile::tensor_layout::convolution::NHWGC>::value ||
            std::is_same<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>::value;
        // =====================================================================
        // Step 2: Calculate split-image info (if layout supports it)
        // =====================================================================
        // Extract output spatial dimensions (missing dims default to 1).
        const ck_tile::index_t total_d =
            (NDimSpatial == 3) ? args.output_spatial_lengths_[NDimSpatial - 3] : 1;
        const ck_tile::index_t total_h =
            (NDimSpatial >= 2) ? args.output_spatial_lengths_[NDimSpatial - 2] : 1;
        const ck_tile::index_t total_w = args.output_spatial_lengths_[NDimSpatial - 1];
        auto split_info = TransformType::GetSplitImageInfo(
            args.G_, args.N_, args.C_, args.K_, total_d, total_h, total_w);
        // =====================================================================
        // Decide: Split-image or regular kernel?
        // =====================================================================
        const bool use_split_image = is_supported_layout && split_info.should_split;
        if(s.log_level_ > 0)
        {
            if(!is_supported_layout)
            {
                std::cout << "[INVOKER] Layout not supported for split-image. "
                          << "Using regular kernel (Kernel<false>).\n";
            }
            else if(!split_info.should_split)
            {
                std::cout << "[INVOKER] Image is small (" << total_h << "×" << total_w
                          << "), split-image not necessary.\n";
                std::cout << "[INVOKER] Using regular kernel (Kernel<false>).\n";
            }
        }
        // =====================================================================
        // Step 3: Calculate split-image pieces (only if using split-image)
        // =====================================================================
        ck_tile::index_t num_d_pieces = 1;
        ck_tile::index_t num_h_pieces = 1;
        ck_tile::index_t num_w_pieces = 1;
        ck_tile::index_t total_pieces = 1;
        ck_tile::index_t base_piece_d = total_d;
        ck_tile::index_t base_piece_h = total_h;
        ck_tile::index_t base_piece_w = total_w;
        // NOTE(review): fixed capacity of 64 pieces with no bounds check against
        // total_pieces — assumes GetSplitImageInfo never requests more than 64
        // pieces (and that kargs.split_image.pieces has matching capacity);
        // TODO confirm that guarantee upstream.
        std::array<ck_tile::SplitImagePieceInfo, 64> temp_pieces{};
        ck_tile::index_t total_blocks = 0;
        if(use_split_image)
        {
            num_d_pieces = split_info.num_d_pieces;
            num_h_pieces = split_info.num_h_pieces;
            num_w_pieces = split_info.num_w_pieces;
            total_pieces = num_d_pieces * num_h_pieces * num_w_pieces;
            if(s.log_level_ > 0)
            {
                std::cout << "\n========================================\n";
                std::cout << "[SPLIT-IMAGE ENABLED] Large tensor detected\n";
                std::cout << "========================================\n";
                if(NDimSpatial == 3)
                {
                    std::cout << "Total dimensions: D=" << total_d << " H=" << total_h
                              << " W=" << total_w << "\n";
                    std::cout << "Split into pieces: D=" << num_d_pieces << " × H=" << num_h_pieces
                              << " × W=" << num_w_pieces << " = " << total_pieces
                              << " total pieces\n";
                    std::cout << "Base piece size: D=" << (total_d / num_d_pieces)
                              << " H=" << (total_h / num_h_pieces)
                              << " W=" << (total_w / num_w_pieces) << "\n";
                }
                else if(NDimSpatial == 2)
                {
                    std::cout << "Total dimensions: H=" << total_h << " W=" << total_w << "\n";
                    std::cout << "Split into pieces: H=" << num_h_pieces << " × W=" << num_w_pieces
                              << " = " << total_pieces << " total pieces\n";
                    std::cout << "Base piece size: H=" << (total_h / num_h_pieces)
                              << " W=" << (total_w / num_w_pieces) << "\n";
                }
                else
                {
                    std::cout << "Total dimensions: W=" << total_w << "\n";
                    std::cout << "Split into pieces: W=" << num_w_pieces << " = " << total_pieces
                              << " total pieces\n";
                    std::cout << "Base piece size: W=" << (total_w / num_w_pieces) << "\n";
                }
                std::cout << "========================================\n\n";
            }
            // Base piece size (non-overlapping division)
            base_piece_d = total_d / num_d_pieces;
            base_piece_h = total_h / num_h_pieces;
            base_piece_w = total_w / num_w_pieces;
            // Calculate piece info for all pieces using library utility function.
            // Each piece records its block range and spatial offset/extent;
            // block ranges are laid out consecutively via total_blocks.
            for(ck_tile::index_t piece = 0; piece < total_pieces; piece++)
            {
                temp_pieces[piece] =
                    ck_tile::calculate_spatial_piece<TilePartitioner>(piece,
                                                                      num_d_pieces,
                                                                      num_h_pieces,
                                                                      num_w_pieces,
                                                                      base_piece_d,
                                                                      base_piece_h,
                                                                      base_piece_w,
                                                                      total_d,
                                                                      total_h,
                                                                      total_w,
                                                                      args.N_,
                                                                      args.K_,
                                                                      total_blocks);
                total_blocks = temp_pieces[piece].block_end;
            }
        }
        constexpr auto scheduler = ConvConfig::Scheduler;
        // =====================================================================
        // Kernel launch lambda: Uses EnableSplitImage based on layout support
        // =====================================================================
        const auto Run = [&](const auto enable_split_image_) {
            constexpr bool EnableSplitImage = enable_split_image_.value;
            using GroupedConvTraitsType     = std::conditional_t<EnableSplitImage,
                                                             GroupedConvTraitsTypeLargeTensor,
                                                             GroupedConvTraitsTypeDefault>;
            // GEMM problem: A = input, B = weight, C = output.
            using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
                InDataType,
                WeiDataType,
                AccDataType,
                GemmShape,
                GemmUniversalTraits,
                scheduler,
                ck_tile::element_wise::PassThrough,
                ck_tile::element_wise::PassThrough,
                OutDataType,
                GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
                GroupedConvTraitsType::VectorSizeA,
                GroupedConvTraitsType::VectorSizeB>;
            using GemmPipeline = typename PipelineTypeTraits<
                ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
            // Epilogue: CShuffle writes C (optionally fusing Ds and CDEElementWise).
            using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
                InDataType,
                WeiDataType,
                DsDataType,
                AccDataType,
                OutDataType,
                typename GroupedConvTraitsType::ImplicitGemmDsLayout,
                typename GroupedConvTraitsType::FixedGemmParams::ELayout,
                CDEElementWise,
                TilePartitioner::MPerBlock,
                TilePartitioner::NPerBlock,
                ConvConfig::M_Warp,
                ConvConfig::N_Warp,
                ConvConfig::M_Warp_Tile,
                ConvConfig::N_Warp_Tile,
                ConvConfig::K_Warp_Tile,
                GroupedConvTraitsType::FixedGemmParams::TransposeC,
                ConvConfig::NumWaveGroups,
                GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
                GroupedConvTraitsType::VectorSizeC>>;
            // Use split-image kernel if layout supports it, otherwise use regular kernel
            using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
                                                                    TilePartitioner,
                                                                    GemmPipeline,
                                                                    ConvEpilogue>;
            // Create kargs
            auto kargs = Kernel::MakeKernelArgs(args);
            // Populate split-image metadata ONLY if using split-image kernel
            if constexpr(EnableSplitImage)
            {
                kargs.num_spatial_pieces          = total_pieces;
                kargs.split_image.total_d         = total_d;
                kargs.split_image.total_h         = total_h;
                kargs.split_image.total_w         = total_w;
                kargs.split_image.total_spatial   = total_d * total_h * total_w; // Pre-calculate
                kargs.split_image.num_d_pieces    = num_d_pieces;
                kargs.split_image.num_h_pieces    = num_h_pieces;
                kargs.split_image.num_w_pieces    = num_w_pieces;
                for(ck_tile::index_t i = 0; i < total_pieces; i++)
                {
                    kargs.split_image.pieces[i] = {temp_pieces[i].block_start,
                                                   temp_pieces[i].block_end,
                                                   temp_pieces[i].d_start,
                                                   temp_pieces[i].h_start,
                                                   temp_pieces[i].w_start,
                                                   temp_pieces[i].d_size,
                                                   temp_pieces[i].h_size,
                                                   temp_pieces[i].w_size};
                }
            }
            // Calculate grid: use total_blocks for split-image, or normal GridSize for regular
            const dim3 grids = [&]() {
                if constexpr(EnableSplitImage)
                    return dim3(total_blocks, kargs.GemmBatch, kargs.n_splits);
                else
                    return Kernel::GridSize(kargs);
            }();
            const dim3 blocks = Kernel::BlockSize();
            // Reject problem shapes/strides the instantiated kernel cannot handle.
            if(!Kernel::IsSupportedArgument(kargs))
            {
                throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
            }
            if(s.log_level_ > 0)
            {
                std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                          << "shape: " << GemmShape::GetName() << '\n'
                          << "problem: " << UniversalGemmProblem::GetName() << '\n'
                          << "pipeline: " << GemmPipeline::GetName() << '\n'
                          << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                          << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
                          << "}" << '\n'
                          << "Vector size A: " << GemmPipeline::GetVectorSizeA()
                          << ", Vector size B: " << GemmPipeline::GetVectorSizeB()
                          << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
            }
            return ck_tile::launch_kernel(
                s,
                ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        };
        // =====================================================================
        // Step 4: Dispatch kernel (split-image or regular based on decision)
        // =====================================================================
        if(use_split_image)
        {
            return Run(ck_tile::bool_constant<true>{});
        }
        else
        {
            return Run(ck_tile::bool_constant<false>{});
        }
    }
};

View File

@@ -0,0 +1,135 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "conv_configs.hpp"
// Derives relative/absolute comparison thresholds for validating a GEMM-style
// reduction of length GemmK performed with `kbatch`-way split-K.
// Returns ck_tile::make_tuple(rtol, atol), each the looser of the per-batch
// accumulation bound and the split-K accumulation bound.
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // The narrower of the two operand types bounds the compute precision.
    using ComputeType =
        std::conditional_t<sizeof(InDataType) < sizeof(WeiDataType), InDataType, WeiDataType>;

    // Error accumulated within one split-K batch of ceil(GemmK / kbatch) steps.
    const auto k_per_batch = ck_tile::integer_divide_ceil(GemmK, kbatch);
    const auto rtol =
        ck_tile::get_relative_threshold<ComputeType, OutDataType, AccDataType>(k_per_batch);
    const auto atol = ck_tile::get_absolute_threshold<ComputeType, OutDataType, AccDataType>(
        max_accumulated_value / kbatch, k_per_batch);

    // Extra error from combining the kbatch partial results in OutDataType.
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<OutDataType, OutDataType, OutDataType>(kbatch);
    const auto atol_split_k =
        ck_tile::get_absolute_threshold<OutDataType, OutDataType, OutDataType>(
            max_accumulated_value, kbatch);

    // Be permissive: report the higher (looser) of the two estimates.
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
// Fills the spatial conv parameters (filter/image sizes, strides, dilations,
// pads) from the parsed command line, ordered outermost-first (d, h, w).
// The spatial rank is inferred from the in_layout string length (layout has
// exactly three non-spatial letters). Returns the number of spatial dims;
// throws std::runtime_error if it is outside [1, 3].
ck_tile::index_t fill_spatial_dimensions(std::vector<ck_tile::index_t>& filter_spatial_lengths,
                                         std::vector<ck_tile::index_t>& image_spatial_lengths,
                                         std::vector<ck_tile::index_t>& strides,
                                         std::vector<ck_tile::index_t>& dilations,
                                         std::vector<ck_tile::index_t>& lpads,
                                         std::vector<ck_tile::index_t>& rpads,
                                         ck_tile::ArgParser& arg_parser)
{
    constexpr ck_tile::index_t non_sp_dims = 3;
    const ck_tile::index_t n_dim_sp = arg_parser.get_str("in_layout").size() - non_sp_dims;
    if(n_dim_sp < 1 || n_dim_sp > 3)
    {
        throw std::runtime_error("Wrong layout!\n");
    }
    // Append one spatial dimension's worth of parameters from the given options.
    const auto append_dim = [&](const char* filter_key,
                                const char* image_key,
                                const char* stride_key,
                                const char* dilation_key,
                                const char* lpad_key,
                                const char* rpad_key) {
        filter_spatial_lengths.push_back(arg_parser.get_int(filter_key));
        image_spatial_lengths.push_back(arg_parser.get_int(image_key));
        strides.push_back(arg_parser.get_int(stride_key));
        dilations.push_back(arg_parser.get_int(dilation_key));
        lpads.push_back(arg_parser.get_int(lpad_key));
        rpads.push_back(arg_parser.get_int(rpad_key));
    };
    if(n_dim_sp == 3)
    {
        append_dim("z", "d", "stride_d", "dilation_d", "lpad_d", "rpad_d");
    }
    if(n_dim_sp >= 2)
    {
        append_dim("y", "h", "stride_h", "dilation_h", "lpad_h", "rpad_h");
    }
    append_dim("x", "w", "stride_w", "dilation_w", "lpad_w", "rpad_w");
    return n_dim_sp;
}
// Defines the command-line schema for the conv examples and parses argv.
// Every option has a default, so the examples run with no arguments.
// Returns (parse_succeeded, parser).
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;

    // Problem sizes
    arg_parser.insert("g", "2", "group dimension");
    arg_parser.insert("n", "32", "n dimension");
    arg_parser.insert("k", "32", "k dimension");
    arg_parser.insert("c", "32", "c dimension");
    arg_parser.insert("d", "64", "d dimension");
    arg_parser.insert("h", "64", "h dimension");
    arg_parser.insert("w", "64", "w dimension");
    arg_parser.insert("z", "4", "z dimension");
    arg_parser.insert("y", "4", "y dimension");
    arg_parser.insert("x", "4", "x dimension");

    // Strides, dilations, paddings (per spatial dimension)
    arg_parser.insert("stride_d", "1", "d stride");
    arg_parser.insert("stride_h", "1", "h stride");
    arg_parser.insert("stride_w", "1", "w stride");
    arg_parser.insert("dilation_d", "1", "d dilation");
    arg_parser.insert("dilation_h", "1", "h dilation");
    arg_parser.insert("dilation_w", "1", "w dilation");
    arg_parser.insert("lpad_d", "0", "left pad for d dimension");
    arg_parser.insert("lpad_h", "0", "left pad for h dimension");
    arg_parser.insert("lpad_w", "0", "left pad for w dimension");
    arg_parser.insert("rpad_d", "0", "right pad for d dimension");
    arg_parser.insert("rpad_h", "0", "right pad for h dimension");
    arg_parser.insert("rpad_w", "0", "right pad for w dimension");

    // Layouts and precision
    arg_parser.insert("in_layout", "NHWGC", "Input image layout - NHWGC by default");
    arg_parser.insert("wei_layout", "GKYXC", "Weight layout - GKYXC by default");
    arg_parser.insert("out_layout", "NHWGK", "Output image layout - NHWGK by default");
    arg_parser.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8");

    // Benchmark / validation controls
    arg_parser.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU");
    arg_parser.insert("warmup", "50", "number of iterations before benchmark the kernel");
    arg_parser.insert("repeat", "100", "number of iterations to benchmark the kernel");
    arg_parser.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
    arg_parser.insert("split_k", "1", "splitK value");
    arg_parser.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
    arg_parser.insert("json", "0", "0: No Json, 1: Dump Results in Json format");

    const bool parsed = arg_parser.parse(argc, argv);
    return std::make_tuple(parsed, arg_parser);
}
// Aggregated outcome of a single invoker run.
struct InvokerResult
{
    float ave_time;           // averaged kernel time in milliseconds
    ck_tile::index_t split_k; // split-K factor (k_batch) the kernel actually used
};

View File

@@ -0,0 +1,291 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp"
// Runs the backward-data convolution through the given Invoker, prints the
// measured time together with derived TFLOPS and GB/s figures, and returns
// the averaged kernel time in milliseconds.
template <ck_tile::index_t NDimSpatial,
          typename ConvConfig,
          typename Invoker,
          typename InDataType,
          typename WeiDataType,
          typename AccDataType,
          typename OutDataType,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout>
float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
                                   int n_warmup,
                                   int n_repeat)
{
    const float elapsed_ms =
        Invoker::template grouped_conv_bwd_data<NDimSpatial,
                                                ConvConfig,
                                                InDataType,
                                                WeiDataType,
                                                AccDataType,
                                                OutDataType,
                                                InLayout,
                                                WeiLayout,
                                                OutLayout>(
            args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

    // Derive throughput figures from the problem size and measured time.
    const std::size_t flops = args.GetFlops();
    const std::size_t bytes = args.GetByte<InDataType, WeiDataType, OutDataType>();
    const float tflops      = static_cast<float>(flops) / 1.E9 / elapsed_ms;
    const float gb_per_sec  = bytes / 1.E6 / elapsed_ms;

    std::cout << elapsed_ms << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;
    return elapsed_ms;
}
// End-to-end example driver for grouped backward-data convolution: parses CLI
// arguments, builds the problem descriptor, initializes host/device tensors,
// launches the kernel via `Invoker`, and optionally verifies the computed
// input gradient against a CPU (-v=1) or naive GPU (-v=2) reference.
//
// Returns -1 on argument-parse failure; otherwise the verification result
// (1 = pass, 0 = mismatch; 1 when verification is disabled).
template <ck_tile::index_t NDimSpatial,
          typename ConvConfig,
          typename Invoker,
          typename InDataType,
          typename WeiDataType = InDataType,
          typename OutDataType = InDataType,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout>
int run_grouped_conv_bwd_data_example_with_layouts(
    int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;
    using AccDataType = float;
    // Spatial geometry (filter/image sizes, strides, dilations, pads) comes
    // from the argument parser; fill_spatial_dimensions also reports the
    // runtime spatial rank.
    std::vector<ck_tile::index_t> filter_spatial_lengths;
    std::vector<ck_tile::index_t> image_spatial_lengths;
    std::vector<ck_tile::index_t> strides;
    std::vector<ck_tile::index_t> dilations;
    std::vector<ck_tile::index_t> lpads;
    std::vector<ck_tile::index_t> rpads;
    const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
                                                                image_spatial_lengths,
                                                                strides,
                                                                dilations,
                                                                lpads,
                                                                rpads,
                                                                arg_parser);
    ck_tile::conv::ConvParam conv_param{num_dim_sp,
                                        arg_parser.get_int("g"),
                                        arg_parser.get_int("n"),
                                        arg_parser.get_int("k"),
                                        arg_parser.get_int("c"),
                                        filter_spatial_lengths,
                                        image_spatial_lengths,
                                        strides,
                                        dilations,
                                        lpads,
                                        rpads};
    ck_tile::index_t kbatch = arg_parser.get_int("split_k");
    int n_warmup = arg_parser.get_int("warmup");
    int n_repeat = arg_parser.get_int("repeat");
    ck_tile::index_t init_method = arg_parser.get_int("init");
    // Host tensor descriptors for the packed layouts selected by the caller.
    const auto in_g_n_c_wis_desc =
        ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
    const auto wei_g_k_c_xs_desc =
        ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
    const auto out_g_n_k_wos_desc =
        ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
    ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
    ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
    ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
    // Backward data computes `input` from `weight` and `output`, so only the
    // latter two are initialized here:
    //   0 - uniform [-1, 1], 1 - monotonic sequence, 2 - constant 1,
    //   anything else - zeros.
    if(init_method == 0)
    {
        ck_tile::FillUniformDistribution<WeiDataType>{-1.f, 1.f}(weight);
        ck_tile::FillUniformDistribution<OutDataType>{-1.f, 1.f}(output);
    }
    else if(init_method == 1)
    {
        ck_tile::FillMonotonicSeq<WeiDataType>{}(weight);
        ck_tile::FillMonotonicSeq<OutDataType>{}(output);
    }
    else if(init_method == 2)
    {
        ck_tile::FillUniformDistribution<WeiDataType>{1.f, 1.f}(weight);
        ck_tile::FillUniformDistribution<OutDataType>{1.f, 1.f}(output);
    }
    else
    {
        weight.SetZero();
        output.SetZero();
    }
    ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
    ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
    ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
    // The input-gradient buffer is the kernel's destination; zero it so stale
    // device memory cannot mask incorrect results.
    input_dev_buf.SetZero();
    weight_dev_buf.ToDevice(weight.data());
    output_dev_buf.ToDevice(output.data());
    // `{}` - no extra D tensors for the plain backward-data example.
    ck_tile::GroupedConvBwdDataHostArgs args(conv_param,
                                             input_dev_buf.GetDeviceBuffer(),
                                             weight_dev_buf.GetDeviceBuffer(),
                                             {},
                                             output_dev_buf.GetDeviceBuffer(),
                                             kbatch);
    std::cout << "Run Grouped Conv Bwd Data kernel" << std::endl;
    std::cout << "input: " << input.mDesc << std::endl;
    std::cout << "weight: " << weight.mDesc << std::endl;
    std::cout << "output: " << output.mDesc << std::endl;
    invoke_grouped_conv_bwd_data<NDimSpatial,
                                 ConvConfig,
                                 Invoker,
                                 InDataType,
                                 WeiDataType,
                                 AccDataType,
                                 OutDataType,
                                 InLayout,
                                 WeiLayout,
                                 OutLayout>(args, n_warmup, n_repeat);
    input_dev_buf.FromDevice(input.data());
    bool pass = true;
    if(arg_parser.get_int("v") == 1)
    {
        // CPU reference verification.
        ck_tile::HostTensor<InDataType> input_host_ref(in_g_n_c_wis_desc);
        input_host_ref.SetZero();
        ck_tile::reference_grouped_conv_bwd_data<NDimSpatial, InDataType, WeiDataType, OutDataType>(
            input_host_ref,
            weight,
            output,
            conv_param.conv_filter_strides_,
            conv_param.conv_filter_dilations_,
            conv_param.input_left_pads_,
            conv_param.input_right_pads_);
        // Tolerances scale with the GEMM reduction length and the largest
        // reference value. NOTE(review): weight elems / (G * K) equals
        // C * prod(filter), while the bwd-data accumulation length is
        // K * prod(filter) - confirm this heuristic is intentional.
        const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
        const float max_accumulated_value =
            *std::max_element(input_host_ref.mData.begin(), input_host_ref.mData.end());
        const auto rtol_atol =
            calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
                GemmK, kbatch, max_accumulated_value);
        pass = ck_tile::check_err(input,
                                  input_host_ref,
                                  "Error: Incorrect results!",
                                  rtol_atol.at(ck_tile::number<0>{}),
                                  rtol_atol.at(ck_tile::number<1>{}));
        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                  << std::endl;
        std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
    }
    else if(arg_parser.get_int("v") == 2)
    {
        // GPU reference verification
        ck_tile::DeviceMem input_ref_dev_buf(input.get_element_space_size_in_bytes());
        input_ref_dev_buf.SetZero();
        // Launch GPU reference kernel
        std::cout << "Run GPU reference kernel..." << std::endl;
        // NOTE(review): only the left pads are forwarded to the naive kernel -
        // presumably it derives the right padding itself; verify against its
        // signature.
        ck_tile::naive_grouped_conv_bwd_data<NDimSpatial, InDataType, WeiDataType, OutDataType>(
            reinterpret_cast<InDataType*>(input_ref_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<const WeiDataType*>(weight_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<const OutDataType*>(output_dev_buf.GetDeviceBuffer()),
            conv_param.G_,
            conv_param.N_,
            conv_param.K_,
            conv_param.C_,
            conv_param.input_spatial_lengths_,
            conv_param.filter_spatial_lengths_,
            conv_param.output_spatial_lengths_,
            conv_param.conv_filter_strides_,
            conv_param.conv_filter_dilations_,
            conv_param.input_left_pads_);
        // Copy GPU reference result to host for comparison
        ck_tile::HostTensor<InDataType> input_gpu_ref(in_g_n_c_wis_desc);
        input_ref_dev_buf.FromDevice(input_gpu_ref.data());
        const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
        const float max_accumulated_value =
            *std::max_element(input_gpu_ref.mData.begin(), input_gpu_ref.mData.end());
        const auto rtol_atol =
            calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
                GemmK, kbatch, max_accumulated_value);
        pass = ck_tile::check_err(input,
                                  input_gpu_ref,
                                  "Error: Incorrect results!",
                                  rtol_atol.at(ck_tile::number<0>{}),
                                  rtol_atol.at(ck_tile::number<1>{}));
        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                  << std::endl;
        std::cout << "The GPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
    }
    return pass;
}
template <typename Invoker,
typename ConvConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_bwd_data_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<1>{},
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NWGC{}, GKXC{}, NWGK{});
}
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<2>{},
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
}
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<3>{},
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
}
else
{
throw std::runtime_error("Unsupported memory layout!");
}
}

View File

@@ -0,0 +1,300 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp"
// Times the grouped backward-weight convolution kernel and forwards the
// invoker's result (which carries the measured average time and the split-k
// value the invoker actually used).
template <ck_tile::index_t NDimSpatial,
          typename ConvConfig,
          typename Invoker,
          typename InDataType,
          typename WeiDataType,
          typename AccDataType,
          typename OutDataType,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout>
InvokerResult invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
                                             int n_warmup,
                                             int n_repeat)
{
    // Timing configuration; warm-up and timed-repeat counts come from the
    // caller, the remaining fields follow the file-wide stream_config usage.
    const ck_tile::stream_config timing_cfg{nullptr, true, 1, n_warmup, n_repeat};
    return Invoker::template grouped_conv_bwd_weight<NDimSpatial,
                                                     ConvConfig,
                                                     InDataType,
                                                     WeiDataType,
                                                     AccDataType,
                                                     OutDataType,
                                                     InLayout,
                                                     WeiLayout,
                                                     OutLayout>(args, timing_cfg);
}
// End-to-end example driver for grouped backward-weight convolution: builds
// the problem from the already-parsed arguments, initializes host/device
// tensors, launches the kernel via `Invoker`, reports throughput, and
// optionally verifies the computed weight gradient against a CPU (-v=1) or
// naive GPU (-v=2) reference.
//
// Returns the verification result (1 = pass, 0 = mismatch; 1 when
// verification is disabled).
template <ck_tile::index_t NDimSpatial,
          typename ConvConfig,
          typename Invoker,
          typename InDataType,
          typename WeiDataType = InDataType,
          typename OutDataType = InDataType,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout>
int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_parser,
                                                     const InLayout,
                                                     const WeiLayout,
                                                     const OutLayout)
{
    using AccDataType = float;
    // Spatial geometry (filter/image sizes, strides, dilations, pads) comes
    // from the argument parser; fill_spatial_dimensions also reports the
    // runtime spatial rank.
    std::vector<ck_tile::index_t> filter_spatial_lengths;
    std::vector<ck_tile::index_t> image_spatial_lengths;
    std::vector<ck_tile::index_t> strides;
    std::vector<ck_tile::index_t> dilations;
    std::vector<ck_tile::index_t> lpads;
    std::vector<ck_tile::index_t> rpads;
    const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
                                                                image_spatial_lengths,
                                                                strides,
                                                                dilations,
                                                                lpads,
                                                                rpads,
                                                                arg_parser);
    ck_tile::conv::ConvParam conv_param{num_dim_sp,
                                        arg_parser.get_int("g"),
                                        arg_parser.get_int("n"),
                                        arg_parser.get_int("k"),
                                        arg_parser.get_int("c"),
                                        filter_spatial_lengths,
                                        image_spatial_lengths,
                                        strides,
                                        dilations,
                                        lpads,
                                        rpads};
    ck_tile::index_t kbatch = arg_parser.get_int("split_k");
    int n_warmup = arg_parser.get_int("warmup");
    int n_repeat = arg_parser.get_int("repeat");
    ck_tile::index_t init_method = arg_parser.get_int("init");
    // Host tensor descriptors for the packed layouts selected by the caller.
    const auto in_g_n_c_wis_desc =
        ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
    const auto wei_g_k_c_xs_desc =
        ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
    const auto out_g_n_k_wos_desc =
        ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
    ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
    ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
    ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
    // Backward weight computes `weight` from `input` and `output`, so only the
    // latter two are initialized here:
    //   0 - uniform [-1, 1], 1 - monotonic sequence, 2 - constant 1,
    //   anything else - zeros.
    if(init_method == 0)
    {
        ck_tile::FillUniformDistribution<InDataType>{-1.f, 1.f}(input);
        ck_tile::FillUniformDistribution<OutDataType>{-1.f, 1.f}(output);
    }
    else if(init_method == 1)
    {
        ck_tile::FillMonotonicSeq<InDataType>{}(input);
        ck_tile::FillMonotonicSeq<OutDataType>{}(output);
    }
    else if(init_method == 2)
    {
        ck_tile::FillUniformDistribution<InDataType>{1.f, 1.f}(input);
        ck_tile::FillUniformDistribution<OutDataType>{1.f, 1.f}(output);
    }
    else
    {
        input.SetZero();
        output.SetZero();
    }
    ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
    ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
    ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
    input_dev_buf.ToDevice(input.data());
    // The weight-gradient buffer is the kernel's destination; zero it so stale
    // device memory cannot mask incorrect results.
    weight_dev_buf.SetZero();
    output_dev_buf.ToDevice(output.data());
    // `{}` - no extra D tensors for the plain backward-weight example.
    ck_tile::GroupedConvBwdWeightHostArgs args(conv_param,
                                               input_dev_buf.GetDeviceBuffer(),
                                               weight_dev_buf.GetDeviceBuffer(),
                                               {},
                                               output_dev_buf.GetDeviceBuffer(),
                                               kbatch);
    std::cout << "Run Grouped Conv Bwd Weight kernel" << std::endl;
    std::cout << "input: " << input.mDesc << std::endl;
    std::cout << "weight: " << weight.mDesc << std::endl;
    std::cout << "output: " << output.mDesc << std::endl;
    // `res` carries the measured average time and the split-k the invoker
    // actually used (which may differ from the requested `kbatch`).
    auto res = invoke_grouped_conv_bwd_weight<NDimSpatial,
                                              ConvConfig,
                                              Invoker,
                                              InDataType,
                                              WeiDataType,
                                              AccDataType,
                                              OutDataType,
                                              InLayout,
                                              WeiLayout,
                                              OutLayout>(args, n_warmup, n_repeat);
    const float ave_time = res.ave_time;
    weight_dev_buf.FromDevice(weight.data());
    // flop / 1e9 per millisecond == TFlop/s; bytes / 1e6 per millisecond == GB/s.
    std::size_t flop = args.GetFlops();
    std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;
    std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;
    bool pass = true;
    if(arg_parser.get_int("v") == 1)
    {
        // CPU reference verification.
        ck_tile::HostTensor<WeiDataType> weight_host_ref(wei_g_k_c_xs_desc);
        weight_host_ref.SetZero();
        ck_tile::
            reference_grouped_conv_bwd_weight<NDimSpatial, InDataType, WeiDataType, OutDataType>(
                input,
                weight_host_ref,
                output,
                conv_param.conv_filter_strides_,
                conv_param.conv_filter_dilations_,
                conv_param.input_left_pads_,
                conv_param.input_right_pads_);
        // NOTE(review): this path derives GemmK as weight elems / (G * K)
        // (= C * prod(filter)) and uses the invoker-chosen res.split_k; the
        // -v=2 path below computes N * prod(output spatial) and uses the
        // requested kbatch instead - confirm the divergence is intended.
        const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
        const float max_accumulated_value =
            *std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end());
        const ck_tile::index_t split_k = res.split_k;
        const auto rtol_atol =
            calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
                GemmK, split_k, max_accumulated_value);
        pass = ck_tile::check_err(weight,
                                  weight_host_ref,
                                  "Error: Incorrect results!",
                                  rtol_atol.at(ck_tile::number<0>{}),
                                  rtol_atol.at(ck_tile::number<1>{}));
        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                  << std::endl;
        std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
    }
    else if(arg_parser.get_int("v") == 2)
    {
        // GPU reference verification
        ck_tile::DeviceMem weight_ref_dev_buf(weight.get_element_space_size_in_bytes());
        weight_ref_dev_buf.SetZero();
        // Launch GPU reference kernel
        std::cout << "Run GPU reference kernel..." << std::endl;
        // NOTE(review): only the left pads are forwarded to the naive kernel -
        // presumably it derives the right padding itself; verify against its
        // signature.
        ck_tile::naive_grouped_conv_bwd_weight<NDimSpatial, InDataType, WeiDataType, OutDataType>(
            reinterpret_cast<const InDataType*>(input_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<WeiDataType*>(weight_ref_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<const OutDataType*>(output_dev_buf.GetDeviceBuffer()),
            conv_param.G_,
            conv_param.N_,
            conv_param.K_,
            conv_param.C_,
            conv_param.input_spatial_lengths_,
            conv_param.filter_spatial_lengths_,
            conv_param.output_spatial_lengths_,
            conv_param.conv_filter_strides_,
            conv_param.conv_filter_dilations_,
            conv_param.input_left_pads_);
        // Copy GPU reference result to host for comparison
        ck_tile::HostTensor<WeiDataType> weight_gpu_ref(wei_g_k_c_xs_desc);
        weight_ref_dev_buf.FromDevice(weight_gpu_ref.data());
        // GemmK here is the reduction length N * prod(output spatial).
        ck_tile::index_t GemmK = conv_param.N_;
        for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
        {
            GemmK *= conv_param.output_spatial_lengths_[i];
        }
        const float max_accumulated_value =
            *std::max_element(weight_gpu_ref.mData.begin(), weight_gpu_ref.mData.end());
        const auto rtol_atol =
            calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
                GemmK, kbatch, max_accumulated_value);
        pass = ck_tile::check_err(weight,
                                  weight_gpu_ref,
                                  "Error: Incorrect results!",
                                  rtol_atol.at(ck_tile::number<0>{}),
                                  rtol_atol.at(ck_tile::number<1>{}));
        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                  << std::endl;
        std::cout << "The GPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
    }
    return pass;
}
template <typename Invoker,
typename ConvConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
std::string wei_layout,
std::string out_layout,
ck_tile::ArgParser& arg_parser)
{
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
arg_parser, NWGC{}, GKXC{}, NWGK{});
}
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
arg_parser, NHWGC{}, GKYXC{}, NHWGK{});
}
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
arg_parser, NDHWGC{}, GKZYXC{}, NDHWGK{});
}
else
{
throw std::runtime_error("Unsupported memory layout!");
}
}

View File

@@ -0,0 +1,297 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
// Fused epilogue element-wise operation: composes MultiDAdd (accumulate the
// extra D "bias" tensors) with Clamp (limit the result to a caller-supplied
// range). NOTE(review): the host reference in this file applies add-then-clamp;
// confirm Compose's evaluation order matches.
using BiasAndClamp = ck_tile::element_wise::
    Compose<ck_tile::element_wise::MultiDAdd, ck_tile::element_wise::Clamp, true>;
// Times the grouped forward convolution kernel with the fused bias+clamp
// epilogue, prints the achieved throughput, and returns the measured average
// kernel time in milliseconds.
template <ck_tile::index_t NDimSpatial,
          typename GemmWarpConfig,
          typename Invoker,
          typename InDataType,
          typename WeiDataType,
          typename AccDataType,
          typename OutDataType,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout>
float invoke_grouped_conv_fwd_bias_clamp(const ck_tile::GroupedConvFwdHostArgs<BiasAndClamp>& args,
                                         int n_warmup,
                                         int n_repeat)
{
    // Timing configuration; warm-up and timed-repeat counts come from the
    // caller, the remaining fields follow the file-wide stream_config usage.
    const ck_tile::stream_config timing_cfg{nullptr, true, 1, n_warmup, n_repeat};
    // One extra D tensor (with the output's type and layout) plus the fused
    // BiasAndClamp operation extend the plain forward kernel's template list.
    const float avg_ms = Invoker::template grouped_conv_fwd<NDimSpatial,
                                                            GemmWarpConfig,
                                                            InDataType,
                                                            WeiDataType,
                                                            AccDataType,
                                                            OutDataType,
                                                            InLayout,
                                                            WeiLayout,
                                                            OutLayout,
                                                            ck_tile::tuple<OutDataType>,
                                                            ck_tile::tuple<OutLayout>,
                                                            BiasAndClamp>(args, timing_cfg);
    // flop / 1e9 per millisecond == TFlop/s; bytes / 1e6 per millisecond == GB/s.
    const std::size_t total_flop = args.GetFlops();
    const std::size_t total_bytes = args.GetByte<InDataType, WeiDataType, OutDataType>();
    const float tflops = static_cast<float>(total_flop) / 1.E9 / avg_ms;
    const float gb_per_sec = total_bytes / 1.E6 / avg_ms;
    std::cout << avg_ms << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;
    return avg_ms;
}
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename Invoker,
typename InDataType,
typename WeiDataType = InDataType,
typename OutDataType = InDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
int run_grouped_conv_fwd_bias_clamp_example_with_layouts(
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using AccDataType = float;
std::vector<ck_tile::index_t> filter_spatial_lengths;
std::vector<ck_tile::index_t> image_spatial_lengths;
std::vector<ck_tile::index_t> strides;
std::vector<ck_tile::index_t> dilations;
std::vector<ck_tile::index_t> lpads;
std::vector<ck_tile::index_t> rpads;
const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
image_spatial_lengths,
strides,
dilations,
lpads,
rpads,
arg_parser);
ck_tile::conv::ConvParam conv_param{num_dim_sp,
arg_parser.get_int("g"),
arg_parser.get_int("n"),
arg_parser.get_int("k"),
arg_parser.get_int("c"),
filter_spatial_lengths,
image_spatial_lengths,
strides,
dilations,
lpads,
rpads};
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
const float floor = -100.f;
const float ceil = 100.f;
const ck_tile::element_wise::MultiDAdd bias_op{};
const ck_tile::element_wise::Clamp clamp_op{floor, ceil};
const BiasAndClamp bias_clamp_op{bias_op, clamp_op};
const auto in_g_n_c_wis_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
ck_tile::HostTensor<OutDataType> bias(out_g_n_k_wos_desc);
std::string bias_str = "";
if(init_method == 0)
{
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(input);
ck_tile::FillUniformDistribution<WeiDataType>{-5.f, 5.f}(weight);
ck_tile::FillUniformDistribution<OutDataType>{-5.f, 5.f}(bias);
bias_str = " (Uniform(-5,5))";
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<InDataType>{}(input);
ck_tile::FillMonotonicSeq<WeiDataType>{}(weight);
ck_tile::FillMonotonicSeq<OutDataType>{}(bias);
bias_str = " (Monotonic)";
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<InDataType>{1.f, 1.f}(input);
ck_tile::FillUniformDistribution<WeiDataType>{1.f, 1.f}(weight);
ck_tile::FillUniformDistribution<OutDataType>{1.f, 1.f}(bias);
bias_str = " (Constant 1)";
}
else
{
input.SetZero();
weight.SetZero();
bias.SetZero();
}
ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_dev_buf(bias.get_element_space_size_in_bytes());
input_dev_buf.ToDevice(input.data());
weight_dev_buf.ToDevice(weight.data());
output_dev_buf.SetZero();
bias_dev_buf.ToDevice(bias.data());
ck_tile::GroupedConvFwdHostArgs<BiasAndClamp> args(conv_param,
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{bias_dev_buf.GetDeviceBuffer()},
output_dev_buf.GetDeviceBuffer(),
kbatch,
bias_clamp_op);
std::cout << "Run Grouped Conv Fwd kernel with bias" << bias_str << " and clamp (" << floor
<< ", " << ceil << ")." << std::endl;
std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_fwd_bias_clamp<NDimSpatial,
GemmWarpConfig,
Invoker,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
output_dev_buf.FromDevice(output.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<OutDataType> output_host_ref(out_g_n_k_wos_desc);
output_host_ref.SetZero();
auto bias_clamp_host = [floor,
ceil](float& y, const float& x, const OutDataType& element_bias) {
float x_float = ck_tile::type_convert<float>(x);
x_float += ck_tile::type_convert<float>(element_bias);
if(x_float < floor)
x_float = floor;
else if(x_float > ceil)
x_float = ceil;
y = x_float;
};
auto bias_tuple = ck_tile::make_tuple(bias);
ck_tile::reference_grouped_conv_fwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
decltype(bias_clamp_host)>(
input,
weight,
output_host_ref,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
bias_clamp_host,
bias_tuple);
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
const float max_accumulated_value =
*std::max_element(output_host_ref.mData.begin(), output_host_ref.mData.end());
const auto rtol_atol =
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
GemmK, kbatch, max_accumulated_value);
pass = ck_tile::check_err(output,
output_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
// GPU verification for fused operation (Conv + Bias + Clamp) is complex
// For now, we only support GPU verification for basic convolution operations
// The bias+clamp fused variant can use CPU verification (-v=1) or no verification (-v=0)
throw std::runtime_error("GPU verification not yet supported for fused operations! Use "
"-v=1 for CPU verification.");
}
return pass;
}
template <typename Invoker,
typename GemmWarpConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_fwd_bias_clamp_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<1>{},
GemmWarpConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NWGC{}, GKXC{}, NWGK{});
}
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<2>{},
GemmWarpConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
}
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<3>{},
GemmWarpConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
}
else
{
throw std::runtime_error("Unsupported memory layout!");
}
}

View File

@@ -0,0 +1,294 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp"
// Times the grouped forward convolution kernel, prints the achieved
// throughput, and returns the measured average kernel time in milliseconds.
template <ck_tile::index_t NDimSpatial,
          typename ConvConfig,
          typename Invoker,
          typename InDataType,
          typename WeiDataType,
          typename AccDataType,
          typename OutDataType,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout>
float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<>& args,
                              int n_warmup,
                              int n_repeat)
{
    // Timing configuration; warm-up and timed-repeat counts come from the
    // caller, the remaining fields follow the file-wide stream_config usage.
    const ck_tile::stream_config timing_cfg{nullptr, true, 1, n_warmup, n_repeat};
    const float avg_ms = Invoker::template grouped_conv_fwd<NDimSpatial,
                                                            ConvConfig,
                                                            InDataType,
                                                            WeiDataType,
                                                            AccDataType,
                                                            OutDataType,
                                                            InLayout,
                                                            WeiLayout,
                                                            OutLayout>(args, timing_cfg);
    // flop / 1e9 per millisecond == TFlop/s; bytes / 1e6 per millisecond == GB/s.
    const std::size_t total_flop = args.GetFlops();
    const std::size_t total_bytes = args.GetByte<InDataType, WeiDataType, OutDataType>();
    const float tflops = static_cast<float>(total_flop) / 1.E9 / avg_ms;
    const float gb_per_sec = total_bytes / 1.E6 / avg_ms;
    std::cout << avg_ms << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;
    return avg_ms;
}
// End-to-end example driver for grouped forward convolution: parses CLI
// arguments, builds the problem descriptor, initializes host/device tensors,
// launches the kernel via `Invoker`, and optionally verifies the computed
// output against a CPU (-v=1) or naive GPU (-v=2) reference.
//
// Returns -1 on argument-parse failure; otherwise the verification result
// (1 = pass, 0 = mismatch; 1 when verification is disabled).
template <ck_tile::index_t NDimSpatial,
          typename ConvConfig,
          typename Invoker,
          typename InDataType,
          typename WeiDataType = InDataType,
          typename OutDataType = InDataType,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout>
int run_grouped_conv_fwd_example_with_layouts(
    int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;
    using AccDataType = float;
    // Spatial geometry (filter/image sizes, strides, dilations, pads) comes
    // from the argument parser; fill_spatial_dimensions also reports the
    // runtime spatial rank.
    std::vector<ck_tile::index_t> filter_spatial_lengths;
    std::vector<ck_tile::index_t> image_spatial_lengths;
    std::vector<ck_tile::index_t> strides;
    std::vector<ck_tile::index_t> dilations;
    std::vector<ck_tile::index_t> lpads;
    std::vector<ck_tile::index_t> rpads;
    const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
                                                                image_spatial_lengths,
                                                                strides,
                                                                dilations,
                                                                lpads,
                                                                rpads,
                                                                arg_parser);
    ck_tile::conv::ConvParam conv_param{num_dim_sp,
                                        arg_parser.get_int("g"),
                                        arg_parser.get_int("n"),
                                        arg_parser.get_int("k"),
                                        arg_parser.get_int("c"),
                                        filter_spatial_lengths,
                                        image_spatial_lengths,
                                        strides,
                                        dilations,
                                        lpads,
                                        rpads};
    ck_tile::index_t kbatch = arg_parser.get_int("split_k");
    int n_warmup = arg_parser.get_int("warmup");
    int n_repeat = arg_parser.get_int("repeat");
    ck_tile::index_t init_method = arg_parser.get_int("init");
    // Host tensor descriptors for the packed layouts selected by the caller.
    const auto in_g_n_c_wis_desc =
        ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
    const auto wei_g_k_c_xs_desc =
        ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
    const auto out_g_n_k_wos_desc =
        ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
    ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
    ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
    ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
    // Forward computes `output` from `input` and `weight`, so only the latter
    // two are initialized here:
    //   0 - uniform [-5, 5], 1 - monotonic sequence, 2 - constant 1,
    //   anything else - zeros.
    if(init_method == 0)
    {
        ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(input);
        ck_tile::FillUniformDistribution<WeiDataType>{-5.f, 5.f}(weight);
    }
    else if(init_method == 1)
    {
        ck_tile::FillMonotonicSeq<InDataType>{}(input);
        ck_tile::FillMonotonicSeq<WeiDataType>{}(weight);
    }
    else if(init_method == 2)
    {
        ck_tile::FillUniformDistribution<InDataType>{1.f, 1.f}(input);
        ck_tile::FillUniformDistribution<WeiDataType>{1.f, 1.f}(weight);
    }
    else
    {
        input.SetZero();
        weight.SetZero();
    }
    ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
    ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
    ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
    input_dev_buf.ToDevice(input.data());
    weight_dev_buf.ToDevice(weight.data());
    // The output buffer is the kernel's destination; zero it so stale device
    // memory cannot mask incorrect results.
    output_dev_buf.SetZero();
    // `{}` - no extra D tensors for the plain forward example.
    ck_tile::GroupedConvFwdHostArgs<> args(conv_param,
                                           input_dev_buf.GetDeviceBuffer(),
                                           weight_dev_buf.GetDeviceBuffer(),
                                           {},
                                           output_dev_buf.GetDeviceBuffer(),
                                           kbatch);
    std::cout << "Run Grouped Conv Fwd kernel" << std::endl;
    std::cout << "input: " << input.mDesc << std::endl;
    std::cout << "weight: " << weight.mDesc << std::endl;
    std::cout << "output: " << output.mDesc << std::endl;
    invoke_grouped_conv_fwd<NDimSpatial,
                            ConvConfig,
                            Invoker,
                            InDataType,
                            WeiDataType,
                            AccDataType,
                            OutDataType,
                            InLayout,
                            WeiLayout,
                            OutLayout>(args, n_warmup, n_repeat);
    output_dev_buf.FromDevice(output.data());
    bool pass = true;
    if(arg_parser.get_int("v") == 1)
    {
        // CPU reference verification.
        ck_tile::HostTensor<OutDataType> output_host_ref(out_g_n_k_wos_desc);
        output_host_ref.SetZero();
        ck_tile::reference_grouped_conv_fwd<NDimSpatial, InDataType, WeiDataType, OutDataType>(
            input,
            weight,
            output_host_ref,
            conv_param.conv_filter_strides_,
            conv_param.conv_filter_dilations_,
            conv_param.input_left_pads_,
            conv_param.input_right_pads_);
        // Tolerances scale with the forward GEMM reduction length
        // (weight elems / (G * K) = C * prod(filter)) and the largest
        // reference value.
        const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
        const float max_accumulated_value =
            *std::max_element(output_host_ref.mData.begin(), output_host_ref.mData.end());
        const auto rtol_atol =
            calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
                GemmK, kbatch, max_accumulated_value);
        pass = ck_tile::check_err(output,
                                  output_host_ref,
                                  "Error: Incorrect results!",
                                  rtol_atol.at(ck_tile::number<0>{}),
                                  rtol_atol.at(ck_tile::number<1>{}));
        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                  << std::endl;
        std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
    }
    else if(arg_parser.get_int("v") == 2)
    {
        // GPU reference verification
        ck_tile::DeviceMem output_ref_dev_buf(output.get_element_space_size_in_bytes());
        output_ref_dev_buf.SetZero();
        // GPU reference uses conv_param vectors directly (they are already long_index_t)
        // Launch GPU reference kernel
        std::cout << "Run GPU reference kernel..." << std::endl;
        // NOTE(review): only the left pads are forwarded to the naive kernel -
        // presumably it derives the right padding itself; verify against its
        // signature.
        ck_tile::naive_grouped_conv_fwd<NDimSpatial, InDataType, WeiDataType, OutDataType>(
            reinterpret_cast<const InDataType*>(input_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<const WeiDataType*>(weight_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<OutDataType*>(output_ref_dev_buf.GetDeviceBuffer()),
            conv_param.G_,
            conv_param.N_,
            conv_param.K_,
            conv_param.C_,
            conv_param.input_spatial_lengths_,
            conv_param.filter_spatial_lengths_,
            conv_param.output_spatial_lengths_,
            conv_param.conv_filter_strides_,
            conv_param.conv_filter_dilations_,
            conv_param.input_left_pads_);
        // Copy GPU reference result to host for comparison
        ck_tile::HostTensor<OutDataType> output_gpu_ref(out_g_n_k_wos_desc);
        output_ref_dev_buf.FromDevice(output_gpu_ref.data());
        const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
        const float max_accumulated_value =
            *std::max_element(output_gpu_ref.mData.begin(), output_gpu_ref.mData.end());
        const auto rtol_atol =
            calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
                GemmK, kbatch, max_accumulated_value);
        pass = ck_tile::check_err(output,
                                  output_gpu_ref,
                                  "Error: Incorrect results!",
                                  rtol_atol.at(ck_tile::number<0>{}),
                                  rtol_atol.at(ck_tile::number<1>{}));
        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                  << std::endl;
        std::cout << "The GPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
    }
    return pass;
}
template <typename Invoker,
typename ConvConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_fwd_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<1>{},
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NWGC{}, GKXC{}, NWGK{});
}
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<2>{},
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
}
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<3>{},
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
}
else
{
throw std::runtime_error("Unsupported memory layout!");
}
}