mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
[CK_TILE] Add conv bwd weight two stage support (#2855)
* resolved conflicts
* add conv bwd weight twostage
* fix one file
* fixes after review
* fixes
* fixes
* Fix
---------
Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
[ROCm/composable_kernel commit: 624c46866e]
This commit is contained in:
@@ -7,5 +7,8 @@ target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMP
|
||||
add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_bwd_weight_two_stage EXCLUDE_FROM_ALL grouped_convolution_backward_weight_two_stage.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_bwd_data EXCLUDE_FROM_ALL grouped_convolution_backward_data.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -41,8 +41,8 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeA = 1;
|
||||
constexpr ck_tile::index_t VectorSizeB = 1;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
@@ -51,20 +51,29 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType =
|
||||
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
@@ -90,7 +99,7 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
VectorSizeC>>;
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
|
||||
@@ -11,195 +11,13 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using CodegenShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType =
|
||||
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
CodegenPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
Kernel::Preprocess(kargs, s),
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
|
||||
#include "grouped_convolution_backward_weight_invoker.hpp"
|
||||
#include "run_grouped_convolution_bwd_weight_example.inc"
|
||||
|
||||
template <typename GemmWarpConfig,
|
||||
typename InPrecType,
|
||||
typename WeiPrecType = InPrecType,
|
||||
typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_bwd_weight_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
|
||||
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
|
||||
GemmWarpConfig,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
}
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
|
||||
GemmWarpConfig,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
|
||||
}
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
|
||||
GemmWarpConfig,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout!");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmWarpConfig>
|
||||
int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
|
||||
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
using Invoker = GroupedConvolutionBackwardWeightInvoker;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string in_layout = arg_parser.get_str("in_layout");
|
||||
@@ -208,13 +26,17 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
|
||||
GemmWarpConfig,
|
||||
ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
|
||||
GemmWarpConfig,
|
||||
ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -224,9 +46,22 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(argc, argv);
|
||||
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(arg_parser);
|
||||
#else
|
||||
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(argc, argv);
|
||||
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
|
||||
struct GroupedConvolutionBackwardWeightInvoker
|
||||
{
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 1;
|
||||
constexpr ck_tile::index_t VectorSizeB = 1;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using CodegenShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
CodegenPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
Kernel::Preprocess(kargs, s),
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
#include "grouped_convolution_backward_weight_two_stage_invoker.hpp"
|
||||
#include "run_grouped_convolution_bwd_weight_example.inc"
|
||||
|
||||
template <typename GemmWarpConfig>
|
||||
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Invoker = GroupedConvolutionBackwardWeightTwoStageInvoker;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string in_layout = arg_parser.get_str("in_layout");
|
||||
std::string wei_layout = arg_parser.get_str("wei_layout");
|
||||
std::string out_layout = arg_parser.get_str("out_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
|
||||
GemmWarpConfig,
|
||||
ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
|
||||
GemmWarpConfig,
|
||||
ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(arg_parser);
|
||||
#else
|
||||
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,215 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
|
||||
struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
{
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using WorkspaceDataType = float;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 1;
|
||||
constexpr ck_tile::index_t VectorSizeB = 1;
|
||||
constexpr ck_tile::index_t VectorSizeC = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using CodegenShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
OutDataType, // A: Out
|
||||
InDataType, // B: In
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType, // A: Out
|
||||
InDataType, // B: In
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WorkspaceDataType, // C: Workspace normally Out
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
CodegenPipeline,
|
||||
ConvEpilogue>;
|
||||
|
||||
const ck_tile::index_t spatial_lengths_accum =
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum *
|
||||
sizeof(WorkspaceDataType));
|
||||
ck_tile::GroupedConvBwdWeightHostArgs ws_args =
|
||||
ck_tile::GroupedConvBwdWeightHostArgs(args);
|
||||
auto c_ptr = ws_args.wei_ptr;
|
||||
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
auto kargs = Kernel::MakeKernelArgs(ws_args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using BlockTile = ck_tile::sequence<2048>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceDataType,
|
||||
WorkspaceDataType,
|
||||
WeiDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
using ElementwiseKernel =
|
||||
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {
|
||||
static_cast<ck_tile::index_t>(args.G_ * args.K_),
|
||||
static_cast<ck_tile::index_t>(args.C_ * spatial_lengths_accum)};
|
||||
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize =
|
||||
(total_elements + elements_per_block - 1) / elements_per_block;
|
||||
|
||||
auto input_tensors =
|
||||
ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
|
||||
auto input_size = ck_tile::make_tuple(shape[0], shape[1]);
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseKernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(ws_args.wei_ptr,
|
||||
0,
|
||||
shape[0] * shape[1] * sizeof(WorkspaceDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(shape[1], 1), // Input Stride
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -50,20 +50,29 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType =
|
||||
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
@@ -89,7 +98,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
VectorSizeC>>;
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -15,43 +16,34 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = grouped_conv_bwd_weight<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
float ave_time = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = args.GetFlops();
|
||||
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
int run_grouped_conv_bwd_weight_example_with_layouts(
|
||||
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
|
||||
int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
const InLayout,
|
||||
const WeiLayout,
|
||||
const OutLayout)
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using AccDataType = float;
|
||||
|
||||
std::vector<ck_tile::index_t> filter_spatial_lengths;
|
||||
@@ -138,17 +130,27 @@ int run_grouped_conv_bwd_weight_example_with_layouts(
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_bwd_weight<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
float ave_time = invoke_grouped_conv_bwd_weight<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
Invoker,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
|
||||
weight_dev_buf.FromDevice(weight.data());
|
||||
|
||||
std::size_t flop = args.GetFlops();
|
||||
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
@@ -189,3 +191,61 @@ int run_grouped_conv_bwd_weight_example_with_layouts(
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename Invoker,
|
||||
typename GemmWarpConfig,
|
||||
typename InPrecType,
|
||||
typename WeiPrecType = InPrecType,
|
||||
typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
|
||||
std::string wei_layout,
|
||||
std::string out_layout,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
|
||||
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
|
||||
GemmWarpConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
arg_parser, NWGC{}, GKXC{}, NWGK{});
|
||||
}
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
|
||||
GemmWarpConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
arg_parser, NHWGC{}, GKYXC{}, NHWGK{});
|
||||
}
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
|
||||
GemmWarpConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
arg_parser, NDHWGC{}, GKZYXC{}, NDHWGK{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout!");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -455,7 +455,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* __restrict__ p,
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
|
||||
@@ -24,7 +24,10 @@ struct GroupedConvBwdDataKernelArgs
|
||||
|
||||
using ConvToGemmTransformer =
|
||||
TransformConvBwdDataToGemm<GroupedConvTraitsType_::NDimSpatial,
|
||||
GroupedConvTraitsType_::ConvSpecialization>;
|
||||
GroupedConvTraitsType_::ConvSpecialization,
|
||||
GroupedConvTraitsType_::VectorSizeA,
|
||||
GroupedConvTraitsType_::VectorSizeB,
|
||||
GroupedConvTraitsType_::VectorSizeC>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
@@ -468,6 +471,10 @@ template <typename GroupedConvTraitsType_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionBackwardDataKernel
|
||||
{
|
||||
// Todo: Enable Vector Load Size > 1
|
||||
static_assert(GroupedConvTraitsType_::VectorSizeA == 1 &&
|
||||
GroupedConvTraitsType_::VectorSizeB == 1);
|
||||
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType_::ConvSpecialization;
|
||||
@@ -509,10 +516,13 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
|
||||
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>,
|
||||
"Not supported A GEMM layout!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>,
|
||||
"Not supported B GEMM layout!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
// TODO: Change to and enable vector load
|
||||
// static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>,
|
||||
// "Not supported A GEMM layout!");
|
||||
// static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>,
|
||||
// "Not supported B GEMM layout!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
|
||||
"Not supported C GEMM layout!");
|
||||
|
||||
@@ -548,7 +558,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
|
||||
!IsSplitKSupported)
|
||||
{
|
||||
@@ -625,7 +635,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
std::is_same_v<InLayout, ctc::NDHWGC>)
|
||||
{
|
||||
// Check access per C
|
||||
if(ConvC % GemmPipeline::GetVectorSizeB() != 0)
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
|
||||
return false;
|
||||
@@ -637,13 +647,12 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of B
|
||||
// FIXME: layout
|
||||
if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKYXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKZYXC>)
|
||||
{
|
||||
if(ConvC % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
|
||||
return false;
|
||||
@@ -655,12 +664,11 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of E
|
||||
if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NHWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NDHWGK>)
|
||||
{
|
||||
if(ConvK % GemmPipeline::GetVectorSizeA() != 0)
|
||||
if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
|
||||
return false;
|
||||
@@ -957,7 +965,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
@@ -975,7 +983,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
|
||||
|
||||
@@ -23,7 +23,10 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
|
||||
using ConvToGemmTransformer =
|
||||
TransformConvBwdWeightToGemm<GroupedConvTraitsType_::NDimSpatial,
|
||||
GroupedConvTraitsType_::ConvSpecialization>;
|
||||
GroupedConvTraitsType_::ConvSpecialization,
|
||||
GroupedConvTraitsType_::VectorSizeA,
|
||||
GroupedConvTraitsType_::VectorSizeB,
|
||||
GroupedConvTraitsType_::VectorSizeC>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
template <
|
||||
@@ -335,6 +338,10 @@ template <typename GroupedConvTraitsType_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionBackwardWeightKernel
|
||||
{
|
||||
// Todo: Enable Vector Load Size > 1
|
||||
static_assert(GroupedConvTraitsType_::VectorSizeA == 1 &&
|
||||
GroupedConvTraitsType_::VectorSizeB == 1);
|
||||
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType_::ConvSpecialization;
|
||||
@@ -355,11 +362,10 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using OutDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using InDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using WeiDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using GroupedConvBwdWeightKernelArgsSpecialized =
|
||||
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;
|
||||
@@ -376,6 +382,10 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
// TODO: Change to and enable vector load
|
||||
// static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not
|
||||
// supported!"); static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not
|
||||
// supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
@@ -453,8 +463,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
|
||||
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value) ||
|
||||
!IsSplitKSupported)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
@@ -525,7 +535,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
std::is_same_v<InLayout, ctc::NDHWGC>)
|
||||
{
|
||||
// Check access per C
|
||||
if(ConvC % GemmPipeline::GetVectorSizeB() != 0)
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
|
||||
return false;
|
||||
@@ -537,13 +547,11 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of B
|
||||
// FIXME: layout
|
||||
if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKYXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKZYXC>)
|
||||
{
|
||||
if(ConvC % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
|
||||
return false;
|
||||
@@ -555,12 +563,11 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of E
|
||||
if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NHWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NDHWGK>)
|
||||
{
|
||||
if(ConvK % GemmPipeline::GetVectorSizeA() != 0)
|
||||
if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
|
||||
return false;
|
||||
@@ -596,9 +603,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
}();
|
||||
|
||||
const auto& c_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
kargs.c_grid_desc_m_n); // B: in
|
||||
return make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr,
|
||||
kargs.c_grid_desc_m_n);
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
@@ -607,11 +613,11 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, WeiDataType>,
|
||||
"Not supported!");
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
|
||||
static_cast<WeiDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
@@ -829,8 +835,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
@@ -848,8 +854,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<WeiDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(
|
||||
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k);
|
||||
|
||||
@@ -24,6 +24,9 @@ struct GroupedConvFwdKernelArgs
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
|
||||
GroupedConvTraitsType_::ConvSpecialization,
|
||||
GroupedConvTraitsType_::VectorSizeA,
|
||||
GroupedConvTraitsType_::VectorSizeB,
|
||||
GroupedConvTraitsType_::VectorSizeC,
|
||||
true>; // Split N enabled
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
@@ -467,7 +470,7 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
|
||||
!IsSplitKSupported)
|
||||
{
|
||||
@@ -550,7 +553,7 @@ struct GroupedConvolutionForwardKernel
|
||||
std::is_same_v<InLayout, ctc::NDHWGC>)
|
||||
{
|
||||
// Check access per C
|
||||
if(ConvC % GemmPipeline::GetVectorSizeA() != 0)
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
|
||||
return false;
|
||||
@@ -568,7 +571,7 @@ struct GroupedConvolutionForwardKernel
|
||||
std::is_same_v<WeiLayout, ctc::GKYXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKZYXC>)
|
||||
{
|
||||
if(ConvC % GemmPipeline::GetVectorSizeB() != 0)
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
|
||||
return false;
|
||||
@@ -585,7 +588,7 @@ struct GroupedConvolutionForwardKernel
|
||||
std::is_same_v<OutLayout, ctc::NHWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NDHWGK>)
|
||||
{
|
||||
if(ConvK % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
|
||||
return false;
|
||||
@@ -858,7 +861,7 @@ struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(
|
||||
@@ -868,7 +871,7 @@ struct GroupedConvolutionForwardKernel
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
|
||||
|
||||
@@ -49,7 +49,10 @@ template <index_t NDimSpatial_,
|
||||
typename InLayout_,
|
||||
typename WeiLayout_,
|
||||
typename DsLayout_,
|
||||
typename OutLayout_>
|
||||
typename OutLayout_,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
@@ -67,14 +70,38 @@ struct GroupedConvTraits
|
||||
using WeiLayout = WeiLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
using GroupedConvImplicitGemmTraits = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
static constexpr index_t NumDTensor = DsLayout::size();
|
||||
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
|
||||
using GroupedConvImplicitGemmTraitsFwd =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
using GroupedConvImplicitGemmTraitsBwdData =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
// TODO: Change to and enable vector load
|
||||
// ck_tile::tensor_layout::gemm::RowMajor,
|
||||
// ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
using GroupedConvImplicitGemmTraitsBwdWeight =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
// TODO: Change to and enable vector load
|
||||
// ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
// ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
|
||||
static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t NumDTensor = DsLayout::size();
|
||||
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -10,6 +10,9 @@ namespace ck_tile {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
ConvolutionSpecialization ConvolutionSpecialization,
|
||||
index_t VectorSizeA,
|
||||
index_t VectorSizeB,
|
||||
index_t VectorSizeC,
|
||||
bool SplitN = false,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
@@ -442,14 +445,17 @@ struct TransformConvBwdDataToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wo_, K_),
|
||||
make_tuple(NStride, WoStride, KStride));
|
||||
make_tuple(NStride, WoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
CK_TILE_HOST auto make_wei_grid_desc() const
|
||||
{
|
||||
// GKXC
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, X_, C_));
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, X_, C_), make_tuple(X_ * C_, C_, I1), number<VectorSizeB>{}, I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
@@ -462,7 +468,9 @@ struct TransformConvBwdDataToGemm
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStride, WiStride, CStride));
|
||||
make_tuple(NStride, WiStride, CStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -477,7 +485,9 @@ struct TransformConvBwdDataToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, K_),
|
||||
make_tuple(NStride, HoStride, WoStride, KStride));
|
||||
make_tuple(NStride, HoStride, WoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -491,14 +501,19 @@ struct TransformConvBwdDataToGemm
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, HiStride, WiStride, CStride));
|
||||
make_tuple(NStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
CK_TILE_HOST auto make_wei_grid_desc() const
|
||||
{
|
||||
// GKYXC
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_));
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Y_, X_, C_),
|
||||
make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -514,7 +529,9 @@ struct TransformConvBwdDataToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Do_, Ho_, Wo_, K_),
|
||||
make_tuple(NStride, DoStride, HoStride, WoStride, KStride));
|
||||
make_tuple(NStride, DoStride, HoStride, WoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -529,14 +546,20 @@ struct TransformConvBwdDataToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
CK_TILE_HOST auto make_wei_grid_desc() const
|
||||
{
|
||||
// GKZYXC
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_));
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, Z_, Y_, X_, C_),
|
||||
make_tuple(C_ * X_ * Y_ * Z_, C_ * X_ * Y_, C_ * X_, C_, I1),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
// TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
|
||||
// properties
|
||||
|
||||
@@ -10,6 +10,9 @@ namespace ck_tile {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
ConvolutionSpecialization ConvolutionSpecialization,
|
||||
index_t VectorSizeA,
|
||||
index_t VectorSizeB,
|
||||
index_t VectorSizeC,
|
||||
bool SplitN = false,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
@@ -420,7 +423,9 @@ struct TransformConvBwdWeightToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride));
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
@@ -433,7 +438,9 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStride, WiStride, CStride));
|
||||
make_tuple(NStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
@@ -444,7 +451,8 @@ struct TransformConvBwdWeightToGemm
|
||||
constexpr auto CXStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride));
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride), number<VectorSizeC>{}, I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -457,7 +465,9 @@ struct TransformConvBwdWeightToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride));
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -471,7 +481,9 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, HiStride, WiStride, CStride));
|
||||
make_tuple(NStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -482,8 +494,8 @@ struct TransformConvBwdWeightToGemm
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_),
|
||||
make_tuple(KStride, CStride));
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, Y_ * X_ * C_), make_tuple(KStride, CStride), number<VectorSizeC>{}, I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -496,7 +508,9 @@ struct TransformConvBwdWeightToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride));
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -511,7 +525,9 @@ struct TransformConvBwdWeightToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -523,7 +539,9 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
|
||||
make_tuple(KStride, CStride));
|
||||
make_tuple(KStride, CStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
// TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
|
||||
|
||||
@@ -10,6 +10,9 @@ namespace ck_tile {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
ConvolutionSpecialization ConvSpecialization,
|
||||
index_t VectorSizeA,
|
||||
index_t VectorSizeB,
|
||||
index_t VectorSizeC,
|
||||
bool SplitN = false,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
@@ -446,7 +449,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wo_, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
|
||||
@@ -458,7 +463,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wo_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_groups_gemmk_desc,
|
||||
@@ -473,8 +480,11 @@ struct TransformConvFwdToGemm
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_));
|
||||
const auto in_n_wi_c_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Wi_),
|
||||
make_tuple(NStrideTensorA_, WiStride_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
@@ -502,7 +512,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, NumGroupsToMerge),
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
@@ -535,7 +547,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
@@ -556,7 +570,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
@@ -581,7 +597,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
@@ -611,7 +629,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
@@ -661,7 +681,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Ho_, Wo_, C_),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
@@ -675,7 +697,9 @@ struct TransformConvFwdToGemm
|
||||
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_),
|
||||
make_tuple(
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_groups_gemmk_desc,
|
||||
@@ -689,8 +713,11 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_));
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_desc,
|
||||
@@ -721,7 +748,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_groups_c_desc,
|
||||
@@ -757,7 +786,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, C_),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_desc,
|
||||
@@ -780,7 +811,9 @@ struct TransformConvFwdToGemm
|
||||
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_groups_c_desc,
|
||||
@@ -808,7 +841,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, C_),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_desc,
|
||||
@@ -843,7 +878,9 @@ struct TransformConvFwdToGemm
|
||||
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_groups_c_desc,
|
||||
@@ -904,7 +941,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Do_, Ho_, Wo_, C_),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
@@ -922,7 +961,9 @@ struct TransformConvFwdToGemm
|
||||
HiStride_,
|
||||
WiStride_,
|
||||
GStrideTensorA_,
|
||||
CStrideTensorA_));
|
||||
CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_groups_gemmk_desc,
|
||||
@@ -939,7 +980,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_));
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
@@ -975,7 +1018,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
@@ -1022,7 +1067,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
@@ -1052,7 +1099,9 @@ struct TransformConvFwdToGemm
|
||||
HiStride_,
|
||||
WiStride_,
|
||||
GStrideTensorA_,
|
||||
CStrideTensorA_));
|
||||
CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
@@ -1090,7 +1139,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
@@ -1138,7 +1189,9 @@ struct TransformConvFwdToGemm
|
||||
HiStride_,
|
||||
WiStride_,
|
||||
GStrideTensorA_,
|
||||
CStrideTensorA_));
|
||||
CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
@@ -1217,14 +1270,19 @@ struct TransformConvFwdToGemm
|
||||
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{}));
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, FilterSizeNumType{}),
|
||||
make_tuple(FilterSizeNumType{}, I1),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
|
||||
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
|
||||
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
return transform_tensor_descriptor(
|
||||
wei_gemmn_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
|
||||
@@ -1237,13 +1295,18 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_));
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, ZYX_ * C_),
|
||||
make_tuple(ZYX_ * C_, I1),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K_, NumGroupsToMerge, ZYX_ * C_),
|
||||
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
|
||||
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
return transform_tensor_descriptor(
|
||||
wei_gemmn_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
|
||||
@@ -1270,14 +1333,18 @@ struct TransformConvFwdToGemm
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
|
||||
make_tuple(WoStride_, KStrideTensorC_));
|
||||
make_tuple(WoStride_, KStrideTensorC_),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
|
||||
make_tuple(
|
||||
NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
|
||||
NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
// Padd 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
@@ -1328,7 +1395,9 @@ struct TransformConvFwdToGemm
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
|
||||
make_tuple(WoStride_, KStrideTensorC_));
|
||||
make_tuple(WoStride_, KStrideTensorC_),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1339,7 +1408,9 @@ struct TransformConvFwdToGemm
|
||||
WoStride_,
|
||||
GStrideTensorC_,
|
||||
KStrideTensorC_,
|
||||
GStrideTensorC_));
|
||||
GStrideTensorC_),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
// Padd 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
@@ -1390,7 +1461,9 @@ struct TransformConvFwdToGemm
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
|
||||
make_tuple(WoStride_, KStrideTensorC_));
|
||||
make_tuple(WoStride_, KStrideTensorC_),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1402,7 +1475,9 @@ struct TransformConvFwdToGemm
|
||||
WoStride_,
|
||||
GStrideTensorC_,
|
||||
KStrideTensorC_,
|
||||
GStrideTensorC_));
|
||||
GStrideTensorC_),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
// Padd 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
|
||||
Reference in New Issue
Block a user