mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Grouped Convolution Backward Weight Kernel (#2357)
* [CK TILE] Grouped Convolution Forward Kernel
* custom vector size
* fixes
* refactor
* resolved conflicts
* rebase fixes
* fixes
* tmp
* add working support for splitk
* minor fix
* fixes
* fixes
* minor fix
* small fix
* Split K and preprocessing fixes
---------
Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
[ROCm/composable_kernel commit: 6681593864]
This commit is contained in:
@@ -1,4 +1,8 @@
|
||||
add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
|
||||
set(EXAMPLE_CONV_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -0,0 +1,218 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using CodegenShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType =
|
||||
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
CodegenPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
Kernel::Preprocess(kargs, s),
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_grouped_convolution_bwd_weight_example.inc"
|
||||
|
||||
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_bwd_weight_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
|
||||
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
}
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
|
||||
}
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout!");
|
||||
}
|
||||
}
|
||||
|
||||
int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string in_layout = arg_parser.get_str("in_layout");
|
||||
std::string wei_layout = arg_parser.get_str("wei_layout");
|
||||
std::string out_layout = arg_parser.get_str("out_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
|
||||
@@ -23,7 +23,7 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s)
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
@@ -97,7 +97,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
@@ -129,7 +129,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
|
||||
#include "run_grouped_convolution_example.inc"
|
||||
#include "run_grouped_convolution_fwd_example.inc"
|
||||
|
||||
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_fwd_example_prec_type(
|
||||
@@ -185,7 +185,7 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string in_layout = arg_parser.get_str("in_layout");
|
||||
std::string wei_layout = arg_parser.get_str("weight_layout");
|
||||
std::string wei_layout = arg_parser.get_str("wei_layout");
|
||||
std::string out_layout = arg_parser.get_str("out_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
|
||||
@@ -12,6 +12,28 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(InDataType) < sizeof(WeiDataType), InDataType, WeiDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, OutDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(GemmK, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, OutDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(GemmK, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<OutDataType, OutDataType, OutDataType>(kbatch);
|
||||
const auto atol_split_k =
|
||||
ck_tile::get_absolute_threshold<OutDataType, OutDataType, OutDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
ck_tile::index_t fill_spatial_dimensions(std::vector<ck_tile::index_t>& filter_spatial_lengths,
|
||||
std::vector<ck_tile::index_t>& image_spatial_lengths,
|
||||
std::vector<ck_tile::index_t>& strides,
|
||||
@@ -90,7 +112,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("rpad_w", "0", "right pad for w dimension")
|
||||
|
||||
.insert("in_layout", "NHWGC", "Input image layout - NHWGC by default")
|
||||
.insert("weight_layout", "GKYXC", "Weight layout - GKYXC by default")
|
||||
.insert("wei_layout", "GKYXC", "Weight layout - GKYXC by default")
|
||||
.insert("out_layout", "NHWGK", "Output image layout - NHWGK by default")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
@@ -105,4 +127,5 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// host API
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s);
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
|
||||
const ck_tile::stream_config& s);
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = grouped_conv_bwd_weight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = args.GetFlops();
|
||||
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
int run_grouped_conv_bwd_weight_example_with_layouts(
|
||||
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using AccDataType = float;
|
||||
|
||||
std::vector<ck_tile::index_t> filter_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> image_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> strides;
|
||||
std::vector<ck_tile::index_t> dilations;
|
||||
std::vector<ck_tile::index_t> lpads;
|
||||
std::vector<ck_tile::index_t> rpads;
|
||||
|
||||
const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads,
|
||||
arg_parser);
|
||||
|
||||
ck_tile::conv::ConvParam conv_param{num_dim_sp,
|
||||
arg_parser.get_int("g"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("c"),
|
||||
filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads};
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
|
||||
|
||||
ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
|
||||
ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InDataType>{-1.f, 1.f}(input);
|
||||
ck_tile::FillUniformDistribution<OutDataType>{-1.f, 1.f}(output);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<InDataType>{}(input);
|
||||
ck_tile::FillMonotonicSeq<OutDataType>{}(output);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InDataType>{1.f, 1.f}(input);
|
||||
ck_tile::FillUniformDistribution<OutDataType>{1.f, 1.f}(output);
|
||||
}
|
||||
else
|
||||
{
|
||||
input.SetZero();
|
||||
output.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
|
||||
|
||||
input_dev_buf.ToDevice(input.data());
|
||||
weight_dev_buf.SetZero();
|
||||
output_dev_buf.ToDevice(output.data());
|
||||
|
||||
ck_tile::GroupedConvBwdWeightHostArgs args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
|
||||
std::cout << "Run Grouped Conv Fwd kernel" << std::endl;
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_bwd_weight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
|
||||
weight_dev_buf.FromDevice(weight.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<WeiDataType> weight_host_ref(wei_g_k_c_xs_desc);
|
||||
weight_host_ref.SetZero();
|
||||
|
||||
ck_tile::
|
||||
reference_grouped_conv_bwd_weight<NDimSpatial, InDataType, WeiDataType, OutDataType>(
|
||||
input,
|
||||
weight_host_ref,
|
||||
output,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_);
|
||||
const ck_tile::index_t GemmK =
|
||||
weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
|
||||
GemmK, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(weight,
|
||||
weight_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
throw std::runtime_error("Unsupported gpu verification !!!");
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -2,28 +2,6 @@
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(InDataType) < sizeof(WeiDataType), InDataType, WeiDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, OutDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(GemmK, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, OutDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(GemmK, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<OutDataType, OutDataType, OutDataType>(kbatch);
|
||||
const auto atol_split_k =
|
||||
ck_tile::get_absolute_threshold<OutDataType, OutDataType, OutDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
@@ -32,7 +10,9 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_fwd(ck_tile::GroupedConvHostArgs& args, int n_warmup, int n_repeat)
|
||||
float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = grouped_conv_fwd<NDimSpatial,
|
||||
InDataType,
|
||||
@@ -143,12 +123,12 @@ int run_grouped_conv_fwd_example_with_layouts(
|
||||
weight_dev_buf.ToDevice(weight.data());
|
||||
output_dev_buf.SetZero();
|
||||
|
||||
ck_tile::GroupedConvHostArgs args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
ck_tile::GroupedConvFwdHostArgs args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
|
||||
std::cout << "Run Grouped Conv Fwd kernel" << std::endl;
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_fused_moe.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_im2col.hpp"
|
||||
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
CK_TILE_HOST void
|
||||
reference_grouped_conv_bwd_weight(const HostTensor<InDataType>& input,
|
||||
HostTensor<WeiDataType>& weight,
|
||||
const HostTensor<OutDataType>& output,
|
||||
std::vector<ck_tile::long_index_t> conv_strides,
|
||||
std::vector<ck_tile::long_index_t> conv_dilations,
|
||||
std::vector<ck_tile::long_index_t> in_left_pads,
|
||||
std::vector<ck_tile::long_index_t>)
|
||||
{
|
||||
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
weight.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
output.get_num_of_dimension() == NDimSpatial + 3))
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
auto func = [&](auto g, auto k, auto c, auto x) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t wo = 0; wo < output.get_lengths()[3]; ++wo)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, wi);
|
||||
OutDataType v_out = output(g, n, k, wo);
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
|
||||
weight(g, k, c, x) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
weight.get_lengths()[0],
|
||||
weight.get_lengths()[1],
|
||||
weight.get_lengths()[2],
|
||||
weight.get_lengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
auto func = [&](auto g, auto k, auto c, auto y, auto x) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t ho = 0; ho < output.get_lengths()[3]; ++ho)
|
||||
{
|
||||
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
for(std::size_t wo = 0; wo < output.get_lengths()[4]; ++wo)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
|
||||
|
||||
if(hi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
|
||||
wi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, hi, wi);
|
||||
OutDataType v_out = output(g, n, k, ho, wo);
|
||||
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
|
||||
weight(g, k, c, y, x) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
weight.get_lengths()[0],
|
||||
weight.get_lengths()[1],
|
||||
weight.get_lengths()[2],
|
||||
weight.get_lengths()[3],
|
||||
weight.get_lengths()[4])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
auto func = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t do_ = 0; do_ < output.get_lengths()[3]; ++do_)
|
||||
{
|
||||
auto di = static_cast<ck_tile::long_index_t>(do_ * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
for(std::size_t ho = 0; ho < output.get_lengths()[4]; ++ho)
|
||||
{
|
||||
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
|
||||
for(std::size_t wo = 0; wo < output.get_lengths()[5]; ++wo)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
|
||||
if(di >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
|
||||
hi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
|
||||
wi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, di, hi, wi);
|
||||
OutDataType v_out = output(g, n, k, do_, ho, wo);
|
||||
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
|
||||
weight(g, k, c, z, y, x) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
weight.get_lengths()[0],
|
||||
weight.get_lengths()[1],
|
||||
weight.get_lengths()[2],
|
||||
weight.get_lengths()[3],
|
||||
weight.get_lengths()[4],
|
||||
weight.get_lengths()[5])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Ref_conv_bwd_weight: number of dimensions must be between 1 and 3.");
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -1,12 +1,16 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -0,0 +1,861 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The Grouped Convolution kernel device arguments.
|
||||
template <typename GroupedConvTraitsType>
|
||||
struct GroupedConvBwdWeightKernelArgs
|
||||
{
|
||||
|
||||
using ConvToGemmTransformer =
|
||||
TransformConvBwdWeightToGemm<GroupedConvTraitsType::NDimSpatial,
|
||||
GroupedConvTraitsType::ConvSpecialization>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType::WeiLayout,
|
||||
typename OutLay = typename GroupedConvTraitsType::OutLayout,
|
||||
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[0])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
ds_ptr[d] = args.ds_ptr[d];
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
// tuple
|
||||
auto grid_descs =
|
||||
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
|
||||
GroupedConvTraitsType::NDimSpatial>();
|
||||
|
||||
a_grid_desc_m_k = grid_descs.at(number<0>{});
|
||||
b_grid_desc_n_k = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
group_stride_a = args.K_; // A: Out NWGK
|
||||
group_stride_b = args.C_; // B: In NWGC
|
||||
group_stride_c = args.K_ * args.C_ * // C: Wei GKXC
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
|
||||
GemmM = a_grid_desc_m_k.get_length(number<0>{});
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmBatch = args.G_;
|
||||
}
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType::WeiLayout,
|
||||
typename OutLay = typename GroupedConvTraitsType::OutLayout,
|
||||
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
ds_ptr[d] = args.ds_ptr[d];
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
// tuple
|
||||
auto grid_descs =
|
||||
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
|
||||
GroupedConvTraitsType::NDimSpatial>();
|
||||
|
||||
a_grid_desc_m_k = grid_descs.at(number<0>{});
|
||||
b_grid_desc_n_k = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
group_stride_a = args.K_; // A: Out NHWGK
|
||||
group_stride_b = args.C_; // B: In NHWGC
|
||||
group_stride_c = args.K_ * args.C_ * // C: Wei GKYXC
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
|
||||
GemmM = a_grid_desc_m_k.get_length(number<0>{});
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmBatch = args.G_;
|
||||
}
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType::WeiLayout,
|
||||
typename OutLay = typename GroupedConvTraitsType::OutLayout,
|
||||
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[2])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[2])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
ds_ptr[d] = args.ds_ptr[d];
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
// tuple
|
||||
auto grid_descs =
|
||||
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
|
||||
GroupedConvTraitsType::NDimSpatial>();
|
||||
|
||||
a_grid_desc_m_k = grid_descs.at(number<0>{});
|
||||
b_grid_desc_n_k = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
group_stride_a = args.K_; // A: Out NDHWGK
|
||||
group_stride_b = args.C_; // B: In NDHWGC
|
||||
group_stride_c = args.K_ * args.C_ * // C: wEI GKZYXC
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
|
||||
GemmM = a_grid_desc_m_k.get_length(number<0>{});
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmBatch = args.G_;
|
||||
}
|
||||
|
||||
using ABCGridDescs = remove_cvref_t<decltype(
|
||||
ConvToGemmTransformer{}.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;
|
||||
|
||||
using AGridDescMK = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
|
||||
using BGridDescNK = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
|
||||
using CGridDescMN = remove_cvref_t<decltype(ABCGridDescs{}[number<2>{}])>;
|
||||
|
||||
static constexpr index_t NonSpatialDims = 3;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> in_g_n_c_wis_lengths;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> wei_g_k_c_xs_lengths;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> out_g_n_k_wos_lengths;
|
||||
|
||||
array<index_t, GroupedConvTraitsType::NDimSpatial> conv_filter_strides;
|
||||
array<index_t, GroupedConvTraitsType::NDimSpatial> conv_filter_dilations;
|
||||
array<index_t, GroupedConvTraitsType::NDimSpatial> input_left_pads;
|
||||
array<index_t, GroupedConvTraitsType::NDimSpatial> input_right_pads;
|
||||
|
||||
index_t k_batch;
|
||||
index_t GemmM;
|
||||
index_t GemmN;
|
||||
index_t GemmK;
|
||||
index_t GemmBatch;
|
||||
|
||||
const void* out_ptr;
|
||||
const void* in_ptr;
|
||||
std::array<const void*, NumDTensor> ds_ptr;
|
||||
void* wei_ptr;
|
||||
|
||||
AGridDescMK a_grid_desc_m_k;
|
||||
BGridDescNK b_grid_desc_n_k;
|
||||
CGridDescMN c_grid_desc_m_n;
|
||||
|
||||
long_index_t group_stride_a;
|
||||
long_index_t group_stride_b;
|
||||
long_index_t group_stride_c;
|
||||
};
|
||||
|
||||
/// @brief The Grouped Convolution Forward kernel template.
|
||||
///
|
||||
/// @paragraph Overview Overview
|
||||
/// This class provides the grouped convolution forward kernel template. By semantic
|
||||
/// division of Implicit GEMM algorithm into following parts we achieve flexible,
|
||||
/// versatile and robust kernel implementation.
|
||||
///
|
||||
/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator()
|
||||
/// function call operator" which determines the work scope of each workgroup.
|
||||
/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm.
|
||||
/// This is the place where each workgroup is loading data from global memory and
|
||||
/// carrying out dot products.
|
||||
/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation
|
||||
/// responsible for storing results to global memory. This is also the place where
|
||||
/// any additional operator fusion may take place.
|
||||
///
|
||||
/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
|
||||
/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all
|
||||
/// internal details of those functional parts. You can think of it like both gemm and
|
||||
/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover
|
||||
/// the policy is responsible for definition of all necessary data layouts and thread's
|
||||
/// work distribution.
|
||||
///
|
||||
/// tparam ConvSpecialization Tensor descriptors specialization.
|
||||
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into
|
||||
/// the
|
||||
/// output data tile to be calculated. It determines the
|
||||
/// workgroup to data relationship (or in other words - which
|
||||
/// data would be processed and calculated by which workgroup).
|
||||
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
|
||||
/// multiplication. This class should provide implementation of
|
||||
/// data loading from global memory and performing block-wise
|
||||
/// matrix multiplication. You can think of it as a work done by
|
||||
/// single workgroup point of view.
|
||||
/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix
|
||||
/// multiplication implementation. It is responsible for storing
|
||||
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
|
||||
/// the output C tensor in global memory.
|
||||
template <typename GroupedConvTraitsType,
|
||||
typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionBackwardWeightKernel
|
||||
{
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType::NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType::ConvSpecialization;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using GemmALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using GemmBLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using GemmCLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
using InLayout = remove_cvref_t<typename GroupedConvTraitsType::InLayout>;
|
||||
using WeiLayout = remove_cvref_t<typename GroupedConvTraitsType::WeiLayout>;
|
||||
using OutLayout = remove_cvref_t<typename GroupedConvTraitsType::OutLayout>;
|
||||
using DsLayout = remove_cvref_t<typename GroupedConvTraitsType::DsLayout>;
|
||||
|
||||
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using GroupedConvBwdWeightKernelArgsSpecialized =
|
||||
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType>;
|
||||
|
||||
// TODO: Enable this
|
||||
static constexpr bool IsSplitKSupported = true;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
|
||||
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
return dim3(
|
||||
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
|
||||
{
|
||||
return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
|
||||
const index_t KRead =
|
||||
__builtin_amdgcn_readfirstlane((kargs.GemmK + K_t - 1) / K_t * K1);
|
||||
|
||||
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
|
||||
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
{
|
||||
splitted_k = __builtin_amdgcn_readfirstlane(KRead);
|
||||
}
|
||||
else
|
||||
{
|
||||
splitted_k =
|
||||
__builtin_amdgcn_readfirstlane(kargs.GemmK - KRead * (kargs.k_batch - 1));
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static auto Preprocess(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const stream_config& s)
|
||||
{
|
||||
return [&]() {
|
||||
if(kargs.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(kargs.wei_ptr,
|
||||
0,
|
||||
kargs.GemmBatch * kargs.GemmM * kargs.GemmN *
|
||||
sizeof(WeiDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
|
||||
!IsSplitKSupported)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
|
||||
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
|
||||
|
||||
// check ConvSpecialization
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
|
||||
const index_t ConvStride = kargs.conv_filter_strides[i];
|
||||
const index_t LeftPad = kargs.input_left_pads[i];
|
||||
const index_t RightPad = kargs.input_right_pads[i];
|
||||
|
||||
if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
|
||||
const index_t LeftPad = kargs.input_left_pads[i];
|
||||
const index_t RightPad = kargs.input_right_pads[i];
|
||||
|
||||
if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
|
||||
{
|
||||
if(ConvC != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
|
||||
|
||||
if(filter_spatial_dim != I3)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
|
||||
std::is_same_v<InLayout, ctc::NDHWGC>)
|
||||
{
|
||||
// Check access per C
|
||||
if(ConvC % GemmPipeline::GetVectorSizeB() != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CK_TILE_ERROR("Not supported input layout!");
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of B
|
||||
// FIXME: layout
|
||||
if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKYXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKZYXC>)
|
||||
{
|
||||
if(ConvC % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CK_TILE_ERROR("Not supported weight layout!");
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of E
|
||||
if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NHWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NDHWGK>)
|
||||
{
|
||||
if(ConvK % GemmPipeline::GetVectorSizeA() != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CK_TILE_ERROR("Not supported output layout!");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const OutDataType* a_ptr,
|
||||
const InDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
WeiDataType* c_ptr,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr,
|
||||
kargs.a_grid_desc_m_k); // A: out
|
||||
}();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr,
|
||||
kargs.b_grid_desc_n_k); // B: in
|
||||
}();
|
||||
|
||||
const auto& c_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.GemmM, kargs.GemmN),
|
||||
make_tuple(kargs.GemmN, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
|
||||
"Not supported!");
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{} * k_batch),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{} * k_batch),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = views.at(I2);
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I3);
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
||||
const index_t i_m,
|
||||
const index_t i_n,
|
||||
const index_t i_k)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& c_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, i_k});
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, i_k});
|
||||
}();
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs Grouped Convolution Forward kernel arguments
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
|
||||
const InDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
WeiDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const index_t num_loop,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const index_t block_idx_k)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs Grouped Convolution Forward kernel arguments
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
|
||||
const InDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
WeiDataType* c_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
const index_t num_loop,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const index_t block_idx_k)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
|
||||
auto gemm_tile_windows =
|
||||
MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const
|
||||
{
|
||||
const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const auto blockIdZ = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
|
||||
const index_t i_k =
|
||||
__builtin_amdgcn_readfirstlane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
|
||||
|
||||
const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY);
|
||||
const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY);
|
||||
const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY);
|
||||
|
||||
// options
|
||||
// conv_bwd_weight = Out * In = Weight
|
||||
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
|
||||
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
|
||||
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
num_loop,
|
||||
i_m,
|
||||
i_n,
|
||||
i_k);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(
|
||||
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -34,7 +34,7 @@ struct GroupedConvFwdKernelArgs
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
@@ -56,9 +56,10 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0];
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0];
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0];
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0];
|
||||
GemmBatch = args.G_;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
@@ -103,7 +104,7 @@ struct GroupedConvFwdKernelArgs
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
@@ -132,9 +133,10 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
|
||||
GemmBatch = args.G_;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
@@ -179,7 +181,7 @@ struct GroupedConvFwdKernelArgs
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
@@ -220,6 +222,7 @@ struct GroupedConvFwdKernelArgs
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
|
||||
args.filter_spatial_lengths_[2];
|
||||
GemmBatch = args.G_;
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
@@ -280,6 +283,7 @@ struct GroupedConvFwdKernelArgs
|
||||
index_t GemmM;
|
||||
index_t GemmN;
|
||||
index_t GemmK;
|
||||
index_t GemmBatch;
|
||||
|
||||
const void* in_ptr;
|
||||
const void* wei_ptr;
|
||||
@@ -354,8 +358,7 @@ struct GroupedConvolutionForwardKernel
|
||||
using OutLayout = remove_cvref_t<typename GroupedConvTraitsType::OutLayout>;
|
||||
using DsLayout = remove_cvref_t<typename GroupedConvTraitsType::DsLayout>;
|
||||
|
||||
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
|
||||
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
@@ -389,20 +392,16 @@ struct GroupedConvolutionForwardKernel
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const GroupedConvHostArgs& args)
|
||||
CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
{
|
||||
const index_t GemmM = args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
|
||||
args.output_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
const index_t GemmN = args.K_;
|
||||
return dim3(TilePartitioner::GridSize(GemmM, GemmN), args.G_, args.k_batch);
|
||||
return dim3(
|
||||
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvHostArgs& hostArgs)
|
||||
MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)
|
||||
{
|
||||
return GroupedConvFwdKernelArgsSpecialized(hostArgs);
|
||||
}
|
||||
@@ -750,7 +749,7 @@ struct GroupedConvolutionForwardKernel
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0, smem_ptr_1);
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
|
||||
|
||||
@@ -14,14 +14,15 @@ namespace ck_tile {
|
||||
/// This structure is passed to Grouped Convolution Kernels when creating kernel
|
||||
/// arguments object. It contain all necessary information required to
|
||||
/// build proper kernel argument and launch kernel on GPU.
|
||||
template <typename InPtr, typename WeiPtr, typename OutPtr>
|
||||
struct GroupedConvHostArgs : public conv::ConvParam
|
||||
{
|
||||
CK_TILE_HOST GroupedConvHostArgs() = delete;
|
||||
CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param,
|
||||
const void* in_ptr_,
|
||||
const void* wei_ptr_,
|
||||
InPtr in_ptr_,
|
||||
WeiPtr wei_ptr_,
|
||||
const std::vector<const void*> ds_ptr_,
|
||||
void* out_ptr_,
|
||||
OutPtr out_ptr_,
|
||||
index_t k_batch_)
|
||||
: conv::ConvParam(conv_param),
|
||||
in_ptr(in_ptr_),
|
||||
@@ -32,13 +33,16 @@ struct GroupedConvHostArgs : public conv::ConvParam
|
||||
{
|
||||
}
|
||||
|
||||
const void* in_ptr;
|
||||
const void* wei_ptr;
|
||||
InPtr in_ptr;
|
||||
WeiPtr wei_ptr;
|
||||
const std::vector<const void*> ds_ptr;
|
||||
void* out_ptr;
|
||||
OutPtr out_ptr;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*>;
|
||||
using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs<const void*, void*, const void*>;
|
||||
|
||||
template <index_t NDimSpatial_,
|
||||
ConvolutionSpecialization ConvSpecialization_,
|
||||
typename InLayout_,
|
||||
@@ -55,6 +59,7 @@ struct GroupedConvTraits
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t NumGroupsToMerge = 1;
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
|
||||
using InLayout = InLayout_;
|
||||
|
||||
@@ -0,0 +1,659 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
ConvolutionSpecialization ConvolutionSpecialization,
|
||||
bool SplitN = false,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
index_t NumGroupsToMerge = 1,
|
||||
typename IndexType = index_t>
|
||||
struct TransformConvBwdWeightToGemm
|
||||
{
|
||||
private:
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto I3 = number<3>{};
|
||||
static constexpr auto I4 = number<4>{};
|
||||
static constexpr auto I5 = number<5>{};
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
template <typename ConvDimsType>
|
||||
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
|
||||
const ConvDimsType& strides,
|
||||
index_t i)
|
||||
{
|
||||
long_index_t acc = 1;
|
||||
for(; i < (NDimSpatial + 3); i++)
|
||||
{
|
||||
acc +=
|
||||
static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
|
||||
template <typename ConvDimsType>
|
||||
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
|
||||
const ConvDimsType& a_g_n_c_wis_strides,
|
||||
const ConvDimsType& c_g_n_k_wos_lengths,
|
||||
const ConvDimsType& c_g_n_k_wos_strides)
|
||||
{
|
||||
const long_index_t a_element_space_size =
|
||||
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
|
||||
const long_index_t c_element_space_size =
|
||||
calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
|
||||
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
|
||||
c_element_space_size * sizeof(CDataType));
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
const IndexType N = a_g_n_c_wis_lengths[I1];
|
||||
|
||||
if(element_space_size > TwoGB)
|
||||
{
|
||||
// Minimum divisor of N to not exceed 2GB
|
||||
const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
|
||||
|
||||
if(divisor <= static_cast<double>(N))
|
||||
{
|
||||
// Find least divisor of N larger than element_space_size / TwoGB
|
||||
// Iterate up to sqrt(N). There are no divisors above this value.
|
||||
for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
|
||||
least_divisor++)
|
||||
{
|
||||
if(N % least_divisor == 0)
|
||||
{
|
||||
return N / least_divisor;
|
||||
}
|
||||
}
|
||||
// Not found, process one Convolution N per block
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Split Convolution's N dimension into N workgroups. However
|
||||
// this still might not result in sufficiently small tensor,
|
||||
// but at least later on we could divide the image as well.
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Split N is not needed.
|
||||
return N;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
public:
|
||||
CK_TILE_HOST constexpr TransformConvBwdWeightToGemm() {}
|
||||
|
||||
template <typename TransformConvBwdWeightToGemmBase>
|
||||
CK_TILE_HOST TransformConvBwdWeightToGemm(
|
||||
const TransformConvBwdWeightToGemmBase& transform_conv_fwd_to_gemm_base)
|
||||
: G_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.G_)},
|
||||
N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
|
||||
Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
|
||||
Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
|
||||
Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
|
||||
Do_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Do_)},
|
||||
Ho_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Ho_)},
|
||||
Wo_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wo_)},
|
||||
Z_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Z_)},
|
||||
Y_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Y_)},
|
||||
X_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.X_)},
|
||||
K_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.K_)},
|
||||
C_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.C_)},
|
||||
ConvStrideD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideD_)},
|
||||
ConvStrideH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideH_)},
|
||||
ConvStrideW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideW_)},
|
||||
ConvDilationD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationD_)},
|
||||
ConvDilationH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationH_)},
|
||||
ConvDilationW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationW_)},
|
||||
InLeftPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadD_)},
|
||||
InLeftPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadH_)},
|
||||
InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
|
||||
InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
|
||||
InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
|
||||
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
|
||||
ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
|
||||
{
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
typename ConvSpatialDimsType,
|
||||
index_t NDim = NDimSpatial,
|
||||
typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
|
||||
const ConvDimsType& b_g_k_c_xs_lengths,
|
||||
const ConvDimsType& c_g_n_k_wos_lengths,
|
||||
const ConvSpatialDimsType& conv_filter_strides,
|
||||
const ConvSpatialDimsType& conv_filter_dilations,
|
||||
const ConvSpatialDimsType& input_left_pads,
|
||||
const ConvSpatialDimsType& input_right_pads)
|
||||
: G_{a_g_n_c_wis_lengths[I0]},
|
||||
Di_{I1},
|
||||
Hi_{I1},
|
||||
Wi_{a_g_n_c_wis_lengths[I3]},
|
||||
Do_{I1},
|
||||
Ho_{I1},
|
||||
Wo_{c_g_n_k_wos_lengths[I3]},
|
||||
Z_{I1},
|
||||
Y_{I1},
|
||||
X_{b_g_k_c_xs_lengths[I3]},
|
||||
K_{c_g_n_k_wos_lengths[I2]},
|
||||
C_{b_g_k_c_xs_lengths[I2]},
|
||||
ConvStrideD_{I1},
|
||||
ConvStrideH_{I1},
|
||||
ConvStrideW_{conv_filter_strides[I0]},
|
||||
ConvDilationD_{I1},
|
||||
ConvDilationH_{I1},
|
||||
ConvDilationW_{conv_filter_dilations[I0]},
|
||||
InLeftPadD_{I0},
|
||||
InLeftPadH_{I0},
|
||||
InLeftPadW_{input_left_pads[I0]},
|
||||
InRightPadD_{I0},
|
||||
InRightPadH_{I0},
|
||||
InRightPadW_{input_right_pads[I0]},
|
||||
ZYX_{X_}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
N_ = GetSplitedNSize(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
|
||||
}
|
||||
else
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
#endif
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
typename ConvSpatialDimsType,
|
||||
index_t NDim = NDimSpatial,
|
||||
typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
|
||||
const ConvDimsType& b_g_k_c_xs_lengths,
|
||||
const ConvDimsType& c_g_n_k_wos_lengths,
|
||||
const ConvSpatialDimsType& conv_filter_strides,
|
||||
const ConvSpatialDimsType& conv_filter_dilations,
|
||||
const ConvSpatialDimsType& input_left_pads,
|
||||
const ConvSpatialDimsType& input_right_pads)
|
||||
: G_{a_g_n_c_wis_lengths[I0]},
|
||||
Di_{I1},
|
||||
Hi_{a_g_n_c_wis_lengths[I3]},
|
||||
Wi_{a_g_n_c_wis_lengths[I4]},
|
||||
Do_{I1},
|
||||
Ho_{c_g_n_k_wos_lengths[I3]},
|
||||
Wo_{c_g_n_k_wos_lengths[I4]},
|
||||
Z_{I1},
|
||||
Y_{b_g_k_c_xs_lengths[I3]},
|
||||
X_{b_g_k_c_xs_lengths[I4]},
|
||||
K_{c_g_n_k_wos_lengths[I2]},
|
||||
C_{b_g_k_c_xs_lengths[I2]},
|
||||
ConvStrideD_{I1},
|
||||
ConvStrideH_{conv_filter_strides[I0]},
|
||||
ConvStrideW_{conv_filter_strides[I1]},
|
||||
ConvDilationD_{I1},
|
||||
ConvDilationH_{conv_filter_dilations[I0]},
|
||||
ConvDilationW_{conv_filter_dilations[I1]},
|
||||
InLeftPadD_{I0},
|
||||
InLeftPadH_{input_left_pads[I0]},
|
||||
InLeftPadW_{input_left_pads[I1]},
|
||||
InRightPadD_{I0},
|
||||
InRightPadH_{input_right_pads[I0]},
|
||||
InRightPadW_{input_right_pads[I1]},
|
||||
ZYX_{Y_ * X_}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
N_ = GetSplitedNSize(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
|
||||
}
|
||||
else
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
#endif
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
typename ConvSpatialDimsType,
|
||||
index_t NDim = NDimSpatial,
|
||||
typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
|
||||
const ConvDimsType& b_g_k_c_xs_lengths,
|
||||
const ConvDimsType& c_g_n_k_wos_lengths,
|
||||
const ConvSpatialDimsType& conv_filter_strides,
|
||||
const ConvSpatialDimsType& conv_filter_dilations,
|
||||
const ConvSpatialDimsType& input_left_pads,
|
||||
const ConvSpatialDimsType& input_right_pads)
|
||||
: G_{a_g_n_c_wis_lengths[I0]},
|
||||
Di_{a_g_n_c_wis_lengths[I3]},
|
||||
Hi_{a_g_n_c_wis_lengths[I4]},
|
||||
Wi_{a_g_n_c_wis_lengths[I5]},
|
||||
Do_{c_g_n_k_wos_lengths[I3]},
|
||||
Ho_{c_g_n_k_wos_lengths[I4]},
|
||||
Wo_{c_g_n_k_wos_lengths[I5]},
|
||||
Z_{b_g_k_c_xs_lengths[I3]},
|
||||
Y_{b_g_k_c_xs_lengths[I4]},
|
||||
X_{b_g_k_c_xs_lengths[I5]},
|
||||
K_{c_g_n_k_wos_lengths[I2]},
|
||||
C_{b_g_k_c_xs_lengths[I2]},
|
||||
ConvStrideD_{conv_filter_strides[I0]},
|
||||
ConvStrideH_{conv_filter_strides[I1]},
|
||||
ConvStrideW_{conv_filter_strides[I2]},
|
||||
ConvDilationD_{conv_filter_dilations[I0]},
|
||||
ConvDilationH_{conv_filter_dilations[I1]},
|
||||
ConvDilationW_{conv_filter_dilations[I2]},
|
||||
InLeftPadD_{input_left_pads[I0]},
|
||||
InLeftPadH_{input_left_pads[I1]},
|
||||
InLeftPadW_{input_left_pads[I2]},
|
||||
InRightPadD_{input_right_pads[I0]},
|
||||
InRightPadH_{input_right_pads[I1]},
|
||||
InRightPadW_{input_right_pads[I2]},
|
||||
ZYX_{Z_ * Y_ * X_}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
N_ = GetSplitedNSize(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
|
||||
}
|
||||
else
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
#endif
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
__host__ bool AreDescriptorsSmallerThan2GB() const
|
||||
{
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
const long_index_t in_desc_space_size =
|
||||
I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
|
||||
(Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_;
|
||||
const long_index_t out_desc_space_size =
|
||||
I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
|
||||
(Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_;
|
||||
|
||||
bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB;
|
||||
bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB;
|
||||
|
||||
return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
|
||||
}
|
||||
|
||||
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
|
||||
CDataType* c_grid_ptr_base) const
|
||||
{
|
||||
// Create copies
|
||||
auto conv_to_gemm_transformer_left = *this;
|
||||
auto conv_to_gemm_transformer_right = *this;
|
||||
IndexType a_right_offset = 0;
|
||||
IndexType c_right_offset = 0;
|
||||
// Calculate real filter size
|
||||
const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
|
||||
const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
|
||||
const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
|
||||
// Calculate start position in input for right tensor
|
||||
const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
|
||||
const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
|
||||
const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
|
||||
// Calculate last position in input for left tensor
|
||||
const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
|
||||
const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
|
||||
const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
|
||||
// Allow to split if whole left padding will be in left tensor and right padding in right
|
||||
// tensor
|
||||
const bool is_possible_to_split_d = Do_ != 1 &&
|
||||
di_right_transformer_start_idx > InLeftPadD_ &&
|
||||
di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
|
||||
const bool is_possible_to_split_h = Ho_ != 1 &&
|
||||
hi_right_transformer_start_idx > InLeftPadH_ &&
|
||||
hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
|
||||
const bool is_possible_to_split_w = Wo_ != 1 &&
|
||||
wi_right_transformer_start_idx > InLeftPadW_ &&
|
||||
wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
|
||||
|
||||
if(is_possible_to_split_d)
|
||||
{
|
||||
// Apply new sizes
|
||||
// Split output on half
|
||||
conv_to_gemm_transformer_left.Do_ = Do_ / 2;
|
||||
conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
|
||||
// Assign left padding to left convolution
|
||||
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
|
||||
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
|
||||
// Assign right padding to right convolution
|
||||
conv_to_gemm_transformer_left.InRightPadD_ = 0;
|
||||
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
|
||||
// Calculate new input size
|
||||
conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
|
||||
conv_to_gemm_transformer_right.Di_ =
|
||||
math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
|
||||
(conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
|
||||
;
|
||||
// Calcualte offsets
|
||||
a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
|
||||
c_right_offset = (Do_ / 2) * DoStride_;
|
||||
}
|
||||
else if(is_possible_to_split_h)
|
||||
{
|
||||
conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
|
||||
conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
|
||||
|
||||
conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
|
||||
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
|
||||
|
||||
conv_to_gemm_transformer_left.InRightPadH_ = 0;
|
||||
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
|
||||
|
||||
conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
|
||||
conv_to_gemm_transformer_right.Hi_ =
|
||||
math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
|
||||
(conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
|
||||
a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
|
||||
c_right_offset = (Ho_ / 2) * HoStride_;
|
||||
}
|
||||
else if(is_possible_to_split_w)
|
||||
{
|
||||
conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
|
||||
conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
|
||||
|
||||
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
|
||||
conv_to_gemm_transformer_right.InLeftPadW_ = 0;
|
||||
|
||||
conv_to_gemm_transformer_left.InRightPadW_ = 0;
|
||||
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
|
||||
|
||||
conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
|
||||
conv_to_gemm_transformer_right.Wi_ =
|
||||
math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
|
||||
(conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
|
||||
|
||||
a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
|
||||
c_right_offset = (Wo_ / 2) * WoStride_;
|
||||
}
|
||||
// Return left transform, right transformer, right offset to Input and right offset to
|
||||
// Output
|
||||
return ck_tile::make_tuple(conv_to_gemm_transformer_left,
|
||||
conv_to_gemm_transformer_right,
|
||||
a_grid_ptr_base + a_right_offset,
|
||||
c_grid_ptr_base + c_right_offset);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
CK_TILE_HOST auto make_out_grid_desc() const
|
||||
{
|
||||
// NWGK
|
||||
const index_t NDoHoWoStride = G_ * K_;
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride));
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
CK_TILE_HOST auto make_in_grid_desc() const
|
||||
{
|
||||
// NWGC
|
||||
const index_t NStride = Wi_ * G_ * C_;
|
||||
const index_t WiStride = G_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStride, WiStride, CStride));
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
CK_TILE_HOST auto make_wei_grid_desc() const
|
||||
{
|
||||
// GKXC
|
||||
const index_t KStride = X_ * C_;
|
||||
constexpr auto CXStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride));
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
CK_TILE_HOST auto make_out_grid_desc() const
|
||||
{
|
||||
// NHWGK
|
||||
const index_t NDoHoWoStride = G_ * K_;
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride));
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
CK_TILE_HOST auto make_in_grid_desc() const
|
||||
{
|
||||
// NHWGC
|
||||
const index_t NStride = Hi_ * Wi_ * G_ * C_;
|
||||
const index_t HiStride = Wi_ * G_ * C_;
|
||||
const index_t WiStride = G_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, HiStride, WiStride, CStride));
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
CK_TILE_HOST auto make_wei_grid_desc() const
|
||||
{
|
||||
// GKYXC
|
||||
const index_t KStride = Y_ * X_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_),
|
||||
make_tuple(KStride, CStride));
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
CK_TILE_HOST auto make_out_grid_desc() const
|
||||
{
|
||||
// NDHWGK
|
||||
const index_t NDoHoWoStride = G_ * K_;
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride));
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
CK_TILE_HOST auto make_in_grid_desc() const
|
||||
{
|
||||
const index_t NStride = Di_ * Hi_ * Wi_ * G_ * C_;
|
||||
const index_t DiStride = Hi_ * Wi_ * G_ * C_;
|
||||
const index_t HiStride = Wi_ * G_ * C_;
|
||||
const index_t WiStride = G_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
CK_TILE_HOST auto make_wei_grid_desc() const
|
||||
{
|
||||
// KZYXC
|
||||
const index_t KStride = Z_ * Y_ * X_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
|
||||
make_tuple(KStride, CStride));
|
||||
}
|
||||
|
||||
// TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
|
||||
// properties
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const
|
||||
{
|
||||
const auto out_grid_desc = make_out_grid_desc<NDimSpatial>();
|
||||
const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
|
||||
|
||||
// B: input tensor comes in K_N
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Wo_))),
|
||||
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const
|
||||
{
|
||||
const auto out_grid_desc = make_out_grid_desc<NDimSpatial>();
|
||||
const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
|
||||
|
||||
// B: input tensor comes in K_N
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const
|
||||
{
|
||||
const auto out_grid_desc = make_out_grid_desc<NDimSpatial>();
|
||||
const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
|
||||
|
||||
// B: input tensor comes in K_N
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Z_, Do_), make_tuple(ConvDilationD_, ConvStrideD_)),
|
||||
make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5, 6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
|
||||
IndexType G_, N_;
|
||||
IndexType Di_, Hi_, Wi_;
|
||||
IndexType Do_, Ho_, Wo_;
|
||||
IndexType Z_, Y_, X_;
|
||||
IndexType K_, C_;
|
||||
IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_;
|
||||
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
|
||||
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
|
||||
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
|
||||
IndexType ZYX_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user