[CK_TILE] Add conv bwd weight two stage support (#2855)

* resolved conflicts

* add conv bwd weight twostage

* fix one file

* fixes after review

* fixes

* fixes

* Fix

---------

Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
jakpiase
2025-09-22 15:31:25 +02:00
committed by GitHub
parent 4363a82bd6
commit 624c46866e
16 changed files with 864 additions and 361 deletions

View File

@@ -7,5 +7,8 @@ target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMP
add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_grouped_conv_bwd_weight_two_stage EXCLUDE_FROM_ALL grouped_convolution_backward_weight_two_stage.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_grouped_conv_bwd_data EXCLUDE_FROM_ALL grouped_convolution_backward_data.cpp)
target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -41,8 +41,8 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeA = 1;
constexpr ck_tile::index_t VectorSizeB = 1;
constexpr ck_tile::index_t VectorSizeC = 8;
// Implicit GEMM Traits
@@ -51,20 +51,29 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
using GroupedConvTraitsType =
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
CodegenShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
CodegenShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
true,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
const auto Run = [&](const auto memory_operation_) {
@@ -90,7 +99,7 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
memory_operation,
1,
true,
VectorSizeC>>;
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,

View File

@@ -11,195 +11,13 @@
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
// Builds, configures and launches a single-stage grouped convolution
// backward-weight kernel (implicit GEMM) for the given host arguments.
// Returns the time reported by ck_tile::launch_kernel_time_mask.
// Throws std::runtime_error if the kernel rejects the arguments.
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
// Block-level GEMM tile sizes (M/N/K per workgroup).
constexpr ck_tile::index_t M_Tile = 64;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 64;
// Warp layout within a block: 2x2x1 warps.
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
// Per-warp tile sizes come from the MFMA/WMMA warp configuration.
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
// Vectorized global-memory access widths for the A/B loads and C store.
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
// Implicit GEMM Traits
using CodegenShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
using GroupedConvTraitsType =
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
CodegenShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
// Run is parameterized on the C-tensor memory operation (set vs atomic_add)
// so the epilogue type can be instantiated with it as a compile-time value.
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using ConvEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
1,
true,
VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
TilePartitioner,
CodegenPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
// Reject unsupported problem configurations before launching.
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< '\n'
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
// Preprocess runs before the timed kernel launch (same stream config).
float ave_time = ck_tile::launch_kernel_time_mask(
s,
Kernel::Preprocess(kargs, s),
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
// k_batch == 1: each output element is written once -> plain store ("set").
// k_batch > 1: multiple K-slices contribute to the same weight element, so
// the epilogue must accumulate with atomic_add (presumably split-K — confirm).
if(args.k_batch == 1)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
}
#include "grouped_convolution_backward_weight_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
// Dispatches the backward-weight example to a layout-specific instantiation
// chosen from the (input, weight, output) layout strings.
// Supported channel-last combinations: 1D NWGC/GKXC/NWGK,
// 2D NHWGC/GKYXC/NHWGK, 3D NDHWGC/GKZYXC/NDHWGK; anything else throws.
template <typename GemmWarpConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_bwd_weight_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{
// Shorthand aliases for the supported convolution tensor layouts.
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
// 1 spatial dimension.
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
GemmWarpConfig,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NWGC{}, GKXC{}, NWGK{});
}
// 2 spatial dimensions.
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
GemmWarpConfig,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
}
// 3 spatial dimensions.
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
GemmWarpConfig,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
}
else
{
throw std::runtime_error("Unsupported memory layout!");
}
}
template <typename GemmWarpConfig>
int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Invoker = GroupedConvolutionBackwardWeightInvoker;
std::string data_type = arg_parser.get_str("prec");
std::string in_layout = arg_parser.get_str("in_layout");
@@ -208,13 +26,17 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmWarpConfig,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmWarpConfig,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else
{
@@ -224,9 +46,22 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
try
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(argc, argv);
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(arg_parser);
#else
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(argc, argv);
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -0,0 +1,145 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "grouped_convolution_utils.hpp"
// Single-stage invoker for grouped convolution backward-weight: the kernel
// writes its result directly into the user's weight buffer (no workspace).
// Used by run_grouped_convolution_bwd_weight_example.inc via the Invoker
// template parameter.
struct GroupedConvolutionBackwardWeightInvoker
{
// Configures and launches the backward-weight kernel; returns the time
// reported by ck_tile::launch_kernel_time_mask. Throws std::runtime_error
// if the kernel rejects the arguments.
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
// Block-level GEMM tile sizes.
constexpr ck_tile::index_t M_Tile = 64;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 64;
// 2x2x1 warps per block; per-warp tiles from the warp configuration.
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
// Scalar A/B loads here (vector size 1), vectorized C store (8).
constexpr ck_tile::index_t VectorSizeA = 1;
constexpr ck_tile::index_t VectorSizeB = 1;
constexpr ck_tile::index_t VectorSizeC = 8;
// Implicit GEMM Traits
using CodegenShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
// Vector sizes are carried by the traits type so the pipeline and
// epilogue read them from one place.
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
CodegenShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
true,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
// Run is parameterized on the C memory operation (set vs atomic_add).
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
1,
true,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
TilePartitioner,
CodegenPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
// Reject unsupported configurations before launch.
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
// Preprocess runs before the timed kernel launch.
float ave_time = ck_tile::launch_kernel_time_mask(
s,
Kernel::Preprocess(kargs, s),
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
// k_batch == 1 writes each output once ("set"); k_batch > 1 accumulates
// partial results with atomic_add (presumably split-K — confirm).
if(args.k_batch == 1)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
}
};

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_backward_weight_two_stage_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
// Two-stage backward-weight example driver: reads precision and layout
// strings from the parsed arguments and dispatches to the precision-typed
// runner using the two-stage invoker. Only fp16 and bf16 are supported;
// other precisions throw.
template <typename GemmWarpConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
// Selects the two-stage (workspace + elementwise convert) implementation.
using Invoker = GroupedConvolutionBackwardWeightTwoStageInvoker;
std::string data_type = arg_parser.get_str("prec");
std::string in_layout = arg_parser.get_str("in_layout");
std::string wei_layout = arg_parser.get_str("wei_layout");
std::string out_layout = arg_parser.get_str("out_layout");
if(data_type == "fp16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmWarpConfig,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmWarpConfig,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported data type for this operation!");
}
}
// Entry point: parse CLI arguments, pick the warp configuration at compile
// time (WMMA vs MFMA), run the example, and map any runtime_error to a
// failure exit code.
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
try
{
// The example returns nonzero on success (its pass flag), so invert it
// to produce the conventional 0-on-success process exit code.
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(arg_parser);
#else
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -0,0 +1,215 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "grouped_convolution_utils.hpp"
// Two-stage invoker for grouped convolution backward-weight:
//   stage 1: the conv kernel accumulates weight gradients into a float
//            workspace buffer (instead of the user's weight buffer);
//   stage 2: an elementwise kernel converts the float workspace to
//            WeiDataType and writes it into the real weight buffer.
struct GroupedConvolutionBackwardWeightTwoStageInvoker
{
// Configures and launches both stages; returns the time reported by
// ck_tile::launch_kernel_time_mask over the launched kernels.
// Throws std::runtime_error if either kernel rejects its arguments.
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
const ck_tile::stream_config& s)
{
// Stage-1 results are held in fp32 so atomic accumulation keeps precision.
using WorkspaceDataType = float;
constexpr int kBlockPerCu = 1;
// Block-level GEMM tile sizes.
constexpr ck_tile::index_t M_Tile = 64;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 64;
// 2x2x1 warps per block; per-warp tiles from the warp configuration.
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
// Scalar (vector size 1) accesses on all three tensors in this path.
constexpr ck_tile::index_t VectorSizeA = 1;
constexpr ck_tile::index_t VectorSizeB = 1;
constexpr ck_tile::index_t VectorSizeC = 1;
// Implicit GEMM Traits
using CodegenShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
OutDataType, // A: Out
InDataType, // B: In
AccDataType,
CodegenShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
true,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
// Run is parameterized on the C memory operation (set vs atomic_add).
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType, // A: Out
InDataType, // B: In
DsDataType,
AccDataType,
WorkspaceDataType, // C: Workspace normally Out
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
1,
true,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
TilePartitioner,
CodegenPipeline,
ConvEpilogue>;
// Product of the filter's spatial extents (X, or Y*X, or Z*Y*X).
const ck_tile::index_t spatial_lengths_accum =
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<ck_tile::index_t>());
// Workspace sized for the whole weight tensor (G*K*C*spatial) in fp32.
ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum *
sizeof(WorkspaceDataType));
// Redirect stage 1 to write into the workspace; remember the user's
// weight pointer (c_ptr) as the stage-2 destination.
ck_tile::GroupedConvBwdWeightHostArgs ws_args =
ck_tile::GroupedConvBwdWeightHostArgs(args);
auto c_ptr = ws_args.wei_ptr;
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
auto kargs = Kernel::MakeKernelArgs(ws_args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
// Stage 2: 1-D elementwise kernel that converts the fp32 workspace to
// WeiDataType (UnaryConvert) while copying into the weight buffer.
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
using BlockTile = ck_tile::sequence<2048>;
using BlockWarps = ck_tile::sequence<8>;
using WarpTile = ck_tile::sequence<64>;
using ElementwiseShape =
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceDataType,
WorkspaceDataType,
WeiDataType,
ElementwiseShape,
XElementwiseOperation>;
using ElementwiseKernel =
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
// Flatten the weight tensor to a (G*K) x (C*spatial) 2-D view.
ck_tile::index_t total_elements = 1;
std::vector<ck_tile::index_t> shape = {
static_cast<ck_tile::index_t>(args.G_ * args.K_),
static_cast<ck_tile::index_t>(args.C_ * spatial_lengths_accum)};
for(auto d : shape)
total_elements *= d;
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
// Ceiling division: enough blocks to cover every element.
ck_tile::index_t kGridSize =
(total_elements + elements_per_block - 1) / elements_per_block;
auto input_tensors =
ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
auto input_size = ck_tile::make_tuple(shape[0], shape[1]);
// Check if the kernel configuration is supported
if(!ElementwiseKernel::IsSupportedArgument(input_size))
{
throw std::runtime_error(
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
// With k_batch > 1 stage 1 accumulates via atomic_add, so the workspace
// must start zeroed; with k_batch == 1 every element is overwritten.
auto preprocess = [&]() {
if(args.k_batch > 1)
ck_tile::hip_check_error(
hipMemsetAsync(ws_args.wei_ptr,
0,
shape[0] * shape[1] * sizeof(WorkspaceDataType),
s.stream_id_));
};
// Launch stage 1 (conv into workspace) then stage 2 (convert to c_ptr).
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
};
// Choose the stage-1 C memory operation from the split factor.
if(args.k_batch == 1)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
}
};

View File

@@ -50,20 +50,29 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
using GroupedConvTraitsType =
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
CodegenShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
CodegenShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
true,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
const auto Run = [&](const auto memory_operation_) {
@@ -89,7 +98,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
memory_operation,
1,
true,
VectorSizeC>>;
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
TilePartitioner,

View File

@@ -4,6 +4,7 @@
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename Invoker,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -15,43 +16,34 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
int n_warmup,
int n_repeat)
{
float ave_time = grouped_conv_bwd_weight<NDimSpatial,
GemmWarpConfig,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
float ave_time = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
GemmWarpConfig,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = args.GetFlops();
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
template <ck_tile::index_t NDimSpatial,
typename GemmWarpConfig,
typename Invoker,
typename InDataType,
typename WeiDataType = InDataType,
typename OutDataType = InDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
int run_grouped_conv_bwd_weight_example_with_layouts(
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_parser,
const InLayout,
const WeiLayout,
const OutLayout)
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using AccDataType = float;
std::vector<ck_tile::index_t> filter_spatial_lengths;
@@ -138,17 +130,27 @@ int run_grouped_conv_bwd_weight_example_with_layouts(
std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_bwd_weight<NDimSpatial,
GemmWarpConfig,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
float ave_time = invoke_grouped_conv_bwd_weight<NDimSpatial,
GemmWarpConfig,
Invoker,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
weight_dev_buf.FromDevice(weight.data());
std::size_t flop = args.GetFlops();
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
bool pass = true;
if(arg_parser.get_int("v") == 1)
@@ -189,3 +191,61 @@ int run_grouped_conv_bwd_weight_example_with_layouts(
return pass;
}
// Dispatches the backward-weight example to a layout-specific instantiation
// chosen from the (input, weight, output) layout strings, forwarding the
// Invoker (single-stage or two-stage) to the layout runner.
// Supported channel-last combinations: 1D NWGC/GKXC/NWGK,
// 2D NHWGC/GKYXC/NHWGK, 3D NDHWGC/GKZYXC/NDHWGK; anything else throws.
template <typename Invoker,
typename GemmWarpConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
std::string wei_layout,
std::string out_layout,
ck_tile::ArgParser& arg_parser)
{
// Shorthand aliases for the supported convolution tensor layouts.
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
// 1 spatial dimension.
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
GemmWarpConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
arg_parser, NWGC{}, GKXC{}, NWGK{});
}
// 2 spatial dimensions.
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
GemmWarpConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
arg_parser, NHWGC{}, GKYXC{}, NHWGK{});
}
// 3 spatial dimensions.
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
GemmWarpConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
arg_parser, NDHWGC{}, GKZYXC{}, NDHWGK{});
}
else
{
throw std::runtime_error("Unsupported memory layout!");
}
}