mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Merge commit '6c2ca1211ae29802281049843d284ba1bd6511f8' into develop
This commit is contained in:
@@ -4,6 +4,9 @@ list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion
|
||||
add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_fwd_bias_clamp EXCLUDE_FROM_ALL grouped_convolution_forward_bias_clamp.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_fwd_bias_clamp PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
#include "grouped_convolution_forward_invoker.hpp"
|
||||
#include "run_grouped_convolution_fwd_bias_clamp_example.inc"
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_conv_fwd_bias_clamp_example(int argc, char* argv[])
|
||||
{
|
||||
using Invoker = GroupedConvolutionForwardInvoker;
|
||||
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string in_layout = arg_parser.get_str("in_layout");
|
||||
std::string wei_layout = arg_parser.get_str("wei_layout");
|
||||
std::string out_layout = arg_parser.get_str("out_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
|
||||
GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
|
||||
GemmConfig<ck_tile::bf16_t>,
|
||||
ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_fwd_bias_clamp_example<GemmConfigComputeV3_WMMA>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_conv_fwd_bias_clamp_example<GemmConfigComputeV3>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
@@ -15,10 +15,10 @@ struct GroupedConvolutionForwardInvoker
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDElementWise = ck_tile::element_wise::PassThrough>
|
||||
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDElementWise>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
@@ -49,7 +49,8 @@ struct GroupedConvolutionForwardInvoker
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
VectorSizeC,
|
||||
CDElementWise>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GemmConfig::kPadM,
|
||||
@@ -128,7 +129,7 @@ struct GroupedConvolutionForwardInvoker
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
|
||||
@@ -0,0 +1,301 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
using BiasAndClamp = ck_tile::element_wise::
|
||||
Compose<ck_tile::element_wise::MultiDAdd, ck_tile::element_wise::Clamp, true>;
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_fwd_bias_clamp(const ck_tile::GroupedConvFwdHostArgs<BiasAndClamp>& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = Invoker::template grouped_conv_fwd<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
ck_tile::tuple<OutDataType>,
|
||||
ck_tile::tuple<OutLayout>,
|
||||
BiasAndClamp>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = args.GetFlops();
|
||||
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
int run_grouped_conv_fwd_bias_clamp_example_with_layouts(
|
||||
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using AccDataType = float;
|
||||
|
||||
std::vector<ck_tile::index_t> filter_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> image_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> strides;
|
||||
std::vector<ck_tile::index_t> dilations;
|
||||
std::vector<ck_tile::index_t> lpads;
|
||||
std::vector<ck_tile::index_t> rpads;
|
||||
|
||||
const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads,
|
||||
arg_parser);
|
||||
|
||||
ck_tile::conv::ConvParam conv_param{num_dim_sp,
|
||||
arg_parser.get_int("g"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("c"),
|
||||
filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads};
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
const float floor = -100.f;
|
||||
const float ceil = 100.f;
|
||||
|
||||
const ck_tile::element_wise::MultiDAdd bias_op{};
|
||||
const ck_tile::element_wise::Clamp clamp_op{floor, ceil};
|
||||
const BiasAndClamp bias_clamp_op{bias_op, clamp_op};
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
|
||||
|
||||
ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
|
||||
ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
|
||||
ck_tile::HostTensor<OutDataType> bias(out_g_n_k_wos_desc);
|
||||
|
||||
std::string bias_str = "";
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(input);
|
||||
ck_tile::FillUniformDistribution<WeiDataType>{-5.f, 5.f}(weight);
|
||||
ck_tile::FillUniformDistribution<OutDataType>{-5.f, 5.f}(bias);
|
||||
bias_str = " (Uniform(-5,5))";
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<InDataType>{}(input);
|
||||
ck_tile::FillMonotonicSeq<WeiDataType>{}(weight);
|
||||
ck_tile::FillMonotonicSeq<OutDataType>{}(bias);
|
||||
bias_str = " (Monotonic)";
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InDataType>{1.f, 1.f}(input);
|
||||
ck_tile::FillUniformDistribution<WeiDataType>{1.f, 1.f}(weight);
|
||||
ck_tile::FillUniformDistribution<OutDataType>{1.f, 1.f}(bias);
|
||||
bias_str = " (Constant 1)";
|
||||
}
|
||||
else
|
||||
{
|
||||
input.SetZero();
|
||||
weight.SetZero();
|
||||
bias.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_dev_buf(bias.get_element_space_size_in_bytes());
|
||||
|
||||
input_dev_buf.ToDevice(input.data());
|
||||
weight_dev_buf.ToDevice(weight.data());
|
||||
output_dev_buf.SetZero();
|
||||
bias_dev_buf.ToDevice(bias.data());
|
||||
|
||||
ck_tile::GroupedConvFwdHostArgs<BiasAndClamp> args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{bias_dev_buf.GetDeviceBuffer()},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
bias_clamp_op);
|
||||
|
||||
std::cout << "Run Grouped Conv Fwd kernel with bias" << bias_str << " and clamp (" << floor
|
||||
<< ", " << ceil << ")." << std::endl;
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_fwd_bias_clamp<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
Invoker,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
|
||||
output_dev_buf.FromDevice(output.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
// FIXME: Address this issue
|
||||
if(arg_parser.get_int("g") > 1 && init_method == 0)
|
||||
std::cerr << "Adding different bias to different groups yield incorrect results"
|
||||
<< std::endl;
|
||||
|
||||
ck_tile::HostTensor<OutDataType> output_host_ref(out_g_n_k_wos_desc);
|
||||
output_host_ref.SetZero();
|
||||
|
||||
auto bias_clamp_host = [floor,
|
||||
ceil](float& y, const float& x, const OutDataType& element_bias) {
|
||||
float x_float = ck_tile::type_convert<float>(x);
|
||||
x_float += ck_tile::type_convert<float>(element_bias);
|
||||
if(x_float < floor)
|
||||
x_float = floor;
|
||||
else if(x_float > ceil)
|
||||
x_float = ceil;
|
||||
y = x_float;
|
||||
};
|
||||
auto bias_tuple = ck_tile::make_tuple(bias);
|
||||
ck_tile::reference_grouped_conv_fwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
decltype(bias_clamp_host)>(
|
||||
input,
|
||||
weight,
|
||||
output_host_ref,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
bias_clamp_host,
|
||||
bias_tuple);
|
||||
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(output_host_ref.mData.begin(), output_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
|
||||
GemmK, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(output,
|
||||
output_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
throw std::runtime_error("Unsupported gpu verification !!!");
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename Invoker,
|
||||
typename GemmWarpConfig,
|
||||
typename InPrecType,
|
||||
typename WeiPrecType = InPrecType,
|
||||
typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_fwd_bias_clamp_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
// using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
|
||||
// using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
// using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
// FIXME: Fix crash in 1D convolution whem using Ds tensor.
|
||||
throw std::runtime_error("1D Convolution does not support bias.");
|
||||
// return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<1>{},
|
||||
// GemmWarpConfig,
|
||||
// Invoker,
|
||||
// InPrecType,
|
||||
// WeiPrecType,
|
||||
// OutPrecType>(
|
||||
// argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
}
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<2>{},
|
||||
GemmWarpConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
|
||||
}
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<3>{},
|
||||
GemmWarpConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout!");
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,7 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
|
||||
float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<>& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
@@ -128,12 +128,12 @@ int run_grouped_conv_fwd_example_with_layouts(
|
||||
weight_dev_buf.ToDevice(weight.data());
|
||||
output_dev_buf.SetZero();
|
||||
|
||||
ck_tile::GroupedConvFwdHostArgs args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
ck_tile::GroupedConvFwdHostArgs<> args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
|
||||
std::cout << "Run Grouped Conv Fwd kernel" << std::endl;
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
|
||||
@@ -23,9 +23,18 @@ This project is a prototype for a more general builder pattern for all of compos
|
||||
|
||||
To enable the experimental builder, configure your build with:
|
||||
|
||||
```sh
|
||||
cmake -DCK_EXPERIMENTAL_BUILDER=ON -DCMAKE_CXX_STANDARD=20 ...
|
||||
```bash
|
||||
cmake \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx942;gfx950" \
|
||||
-D CK_EXPERIMENTAL_BUILDER=ON \
|
||||
-D CMAKE_CXX_STANDARD=20 \
|
||||
-G Ninja \
|
||||
..
|
||||
```
|
||||
|
||||
## Building and testing
|
||||
|
||||
During development, build and test from the CK build directory with
|
||||
|
||||
143
experimental/builder/include/ck_tile/builder/builder_utils.hpp
Normal file
143
experimental/builder/include/ck_tile/builder/builder_utils.hpp
Normal file
@@ -0,0 +1,143 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Convert a static array to a sequence
|
||||
// Usage example:
|
||||
// static constexpr std::vector arr {1, 2, 3};
|
||||
// using seq = to_sequence_v<arr>; // seq is ck::Sequence<1, 2, 3>
|
||||
template <typename T, const T& Arr>
|
||||
struct to_sequence_t
|
||||
{
|
||||
private:
|
||||
template <std::size_t... Is>
|
||||
static auto get_sequence_type(std::index_sequence<Is...>) -> ck::Sequence<Arr[Is]...>;
|
||||
|
||||
// Helper method to handler the unusual .Size() method name in ck::Array.
|
||||
static constexpr auto get_size(const auto& arr)
|
||||
{
|
||||
if constexpr(requires { arr.size(); })
|
||||
{
|
||||
return arr.size();
|
||||
}
|
||||
else
|
||||
{
|
||||
return arr.Size();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
using value = decltype(get_sequence_type(std::make_index_sequence<get_size(Arr)>{}));
|
||||
};
|
||||
|
||||
template <auto& Arr>
|
||||
using to_sequence_v = typename to_sequence_t<std::remove_cvref_t<decltype(Arr)>, Arr>::value;
|
||||
|
||||
// Wrapper function to make constexpr strings a structural type for NTTP.
|
||||
template <size_t N>
|
||||
struct StringLiteral
|
||||
{
|
||||
char data[N];
|
||||
constexpr StringLiteral(const char (&str)[N])
|
||||
{
|
||||
for(size_t i = 0; i < N; ++i)
|
||||
data[i] = str[i];
|
||||
}
|
||||
|
||||
constexpr bool operator==(const StringLiteral<N>& other) const
|
||||
{
|
||||
for(size_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(data[i] != other.data[i])
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// This is a C++17 deduction guide. It allows the compiler to automatically
|
||||
// deduce the template argument `N` for `StringLiteral` from a string literal
|
||||
// constructor argument. For example, you can write `StringLiteral s{"foo"};`
|
||||
// instead of `StringLiteral<4> s{"foo"};`.
|
||||
template <size_t N>
|
||||
StringLiteral(const char (&)[N]) -> StringLiteral<N>;
|
||||
|
||||
// Helper to provide a readable error for unsupported enum values.
|
||||
// The compiler will print the name of this struct in the error message, so
|
||||
// the name of the enum value will appear instead of just its integer value.
|
||||
template <auto T>
|
||||
struct UnsupportedEnumValue
|
||||
{
|
||||
};
|
||||
|
||||
// Helper functions to convert enums to strings
|
||||
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
|
||||
{
|
||||
switch(dir)
|
||||
{
|
||||
case ConvDirection::FORWARD: return "Forward";
|
||||
case ConvDirection::BACKWARD_DATA: return "Backward Data";
|
||||
case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view DataTypeToString(DataType dt)
|
||||
{
|
||||
switch(dt)
|
||||
{
|
||||
case DataType::FP16: return "FP16";
|
||||
case DataType::FP32: return "FP32";
|
||||
case DataType::BF16: return "BF16";
|
||||
case DataType::FP8: return "FP8";
|
||||
case DataType::I8: return "I8";
|
||||
case DataType::U8: return "U8";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK";
|
||||
case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK";
|
||||
case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW";
|
||||
case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK";
|
||||
case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK";
|
||||
case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW";
|
||||
case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK";
|
||||
case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK";
|
||||
case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW";
|
||||
case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,141 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <concepts>
|
||||
#include <array>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
/********************************************************************/
|
||||
/* Descriptors for individual elements of the algorithm description */
|
||||
/********************************************************************/
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem.
|
||||
template <typename T>
|
||||
concept ThreadBlockDescriptor = requires(T t) {
|
||||
{ t.block_size } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.m } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.n } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.k } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for parameters that describe a gridwise GEMM problem.
|
||||
template <typename T>
|
||||
concept GridwiseGemmDescriptor = requires(T t) {
|
||||
{ t.ak1 } -> std::convertible_to<size_t>;
|
||||
{ t.bk1 } -> std::convertible_to<size_t>;
|
||||
{ t.m_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.n_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.m_xdl_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.n_xdl_per_wave } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for vectorized data transfer for convolution input tensors.
|
||||
template <typename T>
|
||||
concept BlockTransferDescriptor = requires(T t) {
|
||||
{ t.k0 } -> std::convertible_to<size_t>;
|
||||
{ t.m_n } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for thread cluster dimensions for GEMM output tensor.
|
||||
template <typename T>
|
||||
concept ThreadClusterDescriptor = requires(T t) {
|
||||
{ t.m_block } -> std::convertible_to<size_t>;
|
||||
{ t.m_wave_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.n_block } -> std::convertible_to<size_t>;
|
||||
{ t.n_wave_per_xdl } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for the LDS transfer for the convolution input tensors.
|
||||
template <typename T>
|
||||
concept LdsTransferDescriptor = requires(T t) {
|
||||
{ t.src_vector_dim } -> std::convertible_to<size_t>;
|
||||
{ t.src_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.lds_dst_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.is_direct_load } -> std::convertible_to<bool>;
|
||||
{ t.lds_padding } -> std::convertible_to<bool>;
|
||||
};
|
||||
|
||||
// Concept for the convolution output tensor epilogue (copy from registers to global memory via
|
||||
// LDS).
|
||||
template <typename T>
|
||||
concept EpilogueDescriptor = requires(T t) {
|
||||
{ t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.n_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for the thread cluster access order
|
||||
template <typename T>
|
||||
concept AccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
|
||||
};
|
||||
|
||||
// No requirements yet for a ConvAlogorithm concept.
|
||||
template <typename T>
|
||||
concept ConvAlgorithmDescriptor = std::is_class_v<T>;
|
||||
|
||||
/******************************************** */
|
||||
/* Requirements for the algorithm description */
|
||||
/******************************************** */
|
||||
|
||||
// Concept to check if struct specifies thread block info.
|
||||
template <typename T>
|
||||
concept SpecifiesThreadBlock = requires {
|
||||
{ T::thread_block } -> ThreadBlockDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseGemm = requires {
|
||||
{ T::gridwise_gemm } -> GridwiseGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution input and output block transfer info.
|
||||
template <typename T>
|
||||
concept SpecifiesBlockTransfer = requires(T t) {
|
||||
{ T::block_transfer.block_transfer_a } -> BlockTransferDescriptor;
|
||||
{ T::block_transfer.block_transfer_b } -> BlockTransferDescriptor;
|
||||
{ T::block_transfer.thread_cluster_dims_c } -> ThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
|
||||
template <typename T>
|
||||
concept SpecifiesLdsTransfer = requires(T t) {
|
||||
{ T::block_transfer.lds_transfer_a } -> LdsTransferDescriptor;
|
||||
{ T::block_transfer.lds_transfer_b } -> LdsTransferDescriptor;
|
||||
{ T::block_transfer.epilogue_c } -> EpilogueDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies thread cluster access order info.
|
||||
template <typename T>
|
||||
concept SpecifiesThreadClusterAccessOrder = requires(T t) {
|
||||
{ T::block_transfer.block_transfer_access_order_a } -> AccessOrderDescriptor;
|
||||
{ T::block_transfer.block_transfer_access_order_b } -> AccessOrderDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies source access order info.
|
||||
template <typename T>
|
||||
concept SpecifiesSourceAccessOrder = requires(T t) {
|
||||
{ T::block_transfer.src_access_order_a } -> AccessOrderDescriptor;
|
||||
{ T::block_transfer.src_access_order_b } -> AccessOrderDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block_gemm_pipeline_version.
|
||||
template <typename T>
|
||||
concept SpecifiesGemmPipelineVersion = requires {
|
||||
{ T::pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesFwdConcSpecialization = requires {
|
||||
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <concepts>
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Limits for input vector transfer.
|
||||
template <auto Value>
|
||||
concept InputVectorTransferLimits = requires {
|
||||
requires Value.src_vector_dim > 0 && Value.src_scalar_per_vector > 0 &&
|
||||
Value.lds_dst_scalar_per_vector > 0;
|
||||
};
|
||||
|
||||
// Limits for output vector transfer.
|
||||
template <auto Value>
|
||||
concept OutputVectorTransferLimits = requires {
|
||||
requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 &&
|
||||
Value.n_xdl_per_wave_per_shuffle > 0;
|
||||
};
|
||||
|
||||
// Limits for access order. Must be a permutation of {0, 1, 2}.
|
||||
template <auto Value>
|
||||
concept AccessOrderLimits = requires {
|
||||
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) &&
|
||||
(Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) &&
|
||||
(Value[2] >= 0 && Value[2] < 3));
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/conv_factory.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
/**
|
||||
* @brief Top-level builder for creating convolution kernel instances.
|
||||
*
|
||||
* This struct serves as the main entry point for generating a convolution kernel.
|
||||
* It uses a factory pattern based on the provided signature, algorithm, and version
|
||||
* to construct the appropriate kernel instance.
|
||||
*
|
||||
* @tparam SIGNATURE The convolution signature, which describes the mathematical functionality of
|
||||
* the algorithm (e.g., data types, layouts, direction).
|
||||
* @tparam ALGORITHM The specific convolution algorithm to be used for the implementation.
|
||||
* @tparam VERSION The version of the builder implementation.
|
||||
*/
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION = LATEST_API_VERSION>
|
||||
requires SupportedVersion<VERSION> && ValidConvSignature<SIGNATURE>
|
||||
struct ConvBuilder
|
||||
{
|
||||
static constexpr auto kVersion = VERSION;
|
||||
using Factory = ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
|
||||
// Output: The kernel class.
|
||||
using Instance = Factory::Instance;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
539
experimental/builder/include/ck_tile/builder/conv_factory.hpp
Normal file
539
experimental/builder/include/ck_tile/builder/conv_factory.hpp
Normal file
@@ -0,0 +1,539 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// A factory for instantiating CK convolution kernels.
|
||||
//
|
||||
// This file translates a semantic description of a convolution operation
|
||||
// (`ConvSignatureDescriptor` and `ConvAlgorithmDescriptor`) into specific,
|
||||
// low-level template arguments required by the underlying CK device-level
|
||||
// kernel implementations. This abstraction enables more complex build
|
||||
// time logic and simplifies the kernel specification.
|
||||
//
|
||||
// Key Components:
|
||||
//
|
||||
// Template Metaprogram:
|
||||
// - ConvFactory: The main factory, with specializations for different
|
||||
// convolution directions (currently only forward).
|
||||
//
|
||||
// Template Metaprogram Helpers:
|
||||
// - ConvTensorLayouts: Maps layout enums to CK layout types for different
|
||||
// spatial dimensions (2D/3D) and directions.
|
||||
// - ConvTensorTypes: Maps data type enums (FP16, BF16, FP32) to C++ types used by CK.
|
||||
// - ConvPassThroughOps: Hard-coded pass-through element-wise operations.
|
||||
// - ConvSpec: Encapsulates convolution and GEMM specialization enums.
|
||||
//
|
||||
// `constexpr` Helper Functions:
|
||||
// - SetThreadBlockInfo: Determines thread block dimensions and tile sizes.
|
||||
// - SetConvTuningInfo: Sets XDL and AK1/BK1 tuning parameters.
|
||||
// - SetFwdConvABlockTransfer: Configures A tensor block transfer parameters.
|
||||
// - SetFwdConvBBlockTransfer: Configures B tensor block transfer parameters.
|
||||
// - SetCBlockTransfer: Configures C tensor block transfer parameters.
|
||||
// - SetBlockGemmPipelineVersion: Maps pipeline version enum to CK types.
|
||||
//
|
||||
// The primary entry point is the `ConvFactory` struct, which is currently
|
||||
// specialized for forward convolutions and produces instances of
|
||||
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory_internal {
|
||||
|
||||
// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types.
|
||||
// Primary template: intentionally unusable. Each supported (layout, spatial
// dimension, direction) combination is provided by an explicit specialization
// below that exposes ALayout/BLayout/DsLayout/ELayout aliases.
template <auto LayoutValue, size_t SPATIAL_DIM, ConvDirection DIR>
    requires(ConvSpatialDim<SPATIAL_DIM> && ValidConvLayoutForSpatialDim<LayoutValue, SPATIAL_DIM>)
struct ConvTensorLayouts
{
    // This will trigger if a specialization for the given layout is not found.
    // We should always catch this in an earlier validation check.
    // The condition depends on `LayoutValue`, so the assertion only fires if
    // this unspecialized primary template is actually instantiated.
    using Layout = decltype(LayoutValue);
    static_assert(sizeof(Layout) == 0,
                  "Internal error. Unsupported layout for convolution factory.");
};
|
||||
|
||||
// 1D Forward Convolution Layout Specializations
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::NWGC_GKXC_NWGK, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NWGC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKXC_NGKW, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::GNWC_GKXC_GNWK, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::GNWC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::GNWK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKCX_NGKW, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKCX;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKYXC_NGKHW, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::NHWGC_GKYXC_NHWGK, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NHWGC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NHWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::GNHWC_GKYXC_GNHWK, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::GNHWC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::GNHWK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKCYX_NGKHW, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKCYX;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCDHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKCZYX;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKDHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::GNDHWC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::GNDHWK;
|
||||
};
|
||||
|
||||
// Type mappings from builder convolution data type to CK tensor types.
|
||||
template <DataType T>
|
||||
struct ConvTensorTypes
|
||||
{
|
||||
// This will trigger if a specialization for the given DataType is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
|
||||
"Internal error. Unsupported data type for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::FP16>
|
||||
{
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::BF16>
|
||||
{
|
||||
using ADataType = ck::bhalf_t;
|
||||
using BDataType = ck::bhalf_t;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck::bhalf_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::FP32>
|
||||
{
|
||||
using ADataType = float;
|
||||
using BDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = float;
|
||||
};
|
||||
|
||||
template <ElementwiseOperation T>
|
||||
struct ElementwiseOps
|
||||
{
|
||||
// This will trigger if a specialization for the given DataType is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
|
||||
"Internal error. Unsupported elementwise operation for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOps<ElementwiseOperation::PASS_THROUGH>
|
||||
{
|
||||
using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
};
|
||||
|
||||
// The algorithm specializations for the convolution and GEMM.
|
||||
template <typename CONV_ENUM>
|
||||
requires(
|
||||
std::is_same_v<CONV_ENUM, ck::tensor_operation::device::ConvolutionForwardSpecialization>)
|
||||
struct ConvSpec
|
||||
{
|
||||
CONV_ENUM conv_spec;
|
||||
ck::tensor_operation::device::GemmSpecialization gemm_spec;
|
||||
};
|
||||
|
||||
// Deduction guide for ConvSpec to simplify brace initialization.
|
||||
template <typename CONV_ENUM, typename GEMM_ENUM>
|
||||
ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec<CONV_ENUM>;
|
||||
|
||||
// Block info for a convolution.
|
||||
struct MNK
|
||||
{
|
||||
size_t m{};
|
||||
size_t n{};
|
||||
size_t k{};
|
||||
};
|
||||
struct ConvBlock
|
||||
{
|
||||
size_t block_size = 0;
|
||||
MNK per_block = {};
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr ConvBlock SetThreadBlockInfo()
|
||||
{
|
||||
constexpr auto& TB = ALGORITHM.thread_block;
|
||||
return ConvBlock{.block_size = TB.block_size,
|
||||
.per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}};
|
||||
}
|
||||
|
||||
// Convolution tuning parameters.
|
||||
struct GridwiseGemm
|
||||
{
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
size_t m_per_xdl = 0;
|
||||
size_t n_per_xdl = 0;
|
||||
size_t m_xdl_per_wave = 0;
|
||||
size_t n_xdl_per_wave = 0;
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr GridwiseGemm SetGridwiseGemmInfo()
|
||||
{
|
||||
constexpr auto& TP = ALGORITHM.gridwise_gemm;
|
||||
return GridwiseGemm{
|
||||
.ak1 = TP.ak1,
|
||||
.bk1 = TP.bk1,
|
||||
.m_per_xdl = TP.m_per_xdl,
|
||||
.n_per_xdl = TP.n_per_xdl,
|
||||
.m_xdl_per_wave = TP.m_xdl_per_wave,
|
||||
.n_xdl_per_wave = TP.n_xdl_per_wave,
|
||||
};
|
||||
}
|
||||
|
||||
// Block transfer parameters for A or B tensor.
|
||||
struct BlockTransfer
|
||||
{
|
||||
ck::Array<size_t, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
|
||||
ck::Array<size_t, 3> thread_cluster_order = {0, 0, 0};
|
||||
ck::Array<size_t, 3> src_access_order = {0, 0, 0};
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
bool is_direct_load = false;
|
||||
bool lds_padding = false;
|
||||
};
|
||||
|
||||
// Builds the BlockTransfer parameters for the A tensor of a forward
// convolution from the algorithm descriptor's block-transfer section.
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr BlockTransfer SetFwdConvABlockTransfer()
{
    // Thread-cluster lengths (k0, m_n, k1), cluster/source access orders, and
    // LDS transfer parameters for the A tensor.
    constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_a;
    constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_a;
    constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_a;
    constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_a;

    BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1},
                                 .thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]},
                                 .src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]},
                                 .src_vector_dim = LDS.src_vector_dim,
                                 .src_scalar_per_vector = LDS.src_scalar_per_vector,
                                 .lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector,
                                 .is_direct_load = LDS.is_direct_load,
                                 .lds_padding = LDS.lds_padding};
    return block_transfer;
}
|
||||
|
||||
// Builds the BlockTransfer parameters for the B tensor of a forward
// convolution from the algorithm descriptor's block-transfer section.
// Mirrors SetFwdConvABlockTransfer but reads the *_b descriptor fields.
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr BlockTransfer SetFwdConvBBlockTransfer()
{
    // Thread-cluster lengths (k0, m_n, k1), cluster/source access orders, and
    // LDS transfer parameters for the B tensor.
    constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_b;
    constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_b;
    constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_b;
    constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_b;

    BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1},
                                 .thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]},
                                 .src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]},
                                 .src_vector_dim = LDS.src_vector_dim,
                                 .src_scalar_per_vector = LDS.src_scalar_per_vector,
                                 .lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector,
                                 .is_direct_load = LDS.is_direct_load,
                                 .lds_padding = LDS.lds_padding};
    return block_transfer;
}
|
||||
|
||||
// Block transfer parameters for C tensor.
|
||||
struct CBlockTransfer
|
||||
{
|
||||
size_t m_xdl_per_wave_per_shuffle = 0;
|
||||
size_t n_xdl_per_wave_per_shuffle = 0;
|
||||
ck::Array<size_t, 4> thread_cluster_dims = {0, 0, 0, 0};
|
||||
size_t scalar_per_vector = 0;
|
||||
};
|
||||
|
||||
// Builds the C-tensor (epilogue/CShuffle) block-transfer parameters from the
// algorithm descriptor.
// NOTE(review): SIGNATURE is currently unused here; presumably kept for
// interface symmetry with the other Set* helpers -- confirm before removing.
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
constexpr CBlockTransfer SetCBlockTransfer()
{
    // Thread-cluster dims for C and the epilogue (CShuffle) parameters.
    constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_c;
    constexpr auto& EPC = ALGORITHM.block_transfer.epilogue_c;
    CBlockTransfer block_transfer{.m_xdl_per_wave_per_shuffle = EPC.m_xdl_per_wave_per_shuffle,
                                  .n_xdl_per_wave_per_shuffle = EPC.n_xdl_per_wave_per_shuffle,
                                  .thread_cluster_dims =
                                      {
                                          TCL.m_block,
                                          TCL.m_wave_per_xdl,
                                          TCL.n_block,
                                          TCL.n_wave_per_xdl,
                                      },
                                  .scalar_per_vector = EPC.scalar_per_vector};
    return block_transfer;
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.pipeline_version;
|
||||
|
||||
if constexpr(version == BlockGemmPipelineVersion::V1)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V3)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V4)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V5)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown BlockGemmPipelineVersion");
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.fwd_specialization;
|
||||
|
||||
if constexpr(specialization == ConvFwdSpecialization::DEFAULT)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown ConvFwdSpecialization");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory_internal
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Primary template for the convolution factory.
// Declared but not defined: only the direction-specific specializations below
// (currently forward-only) are instantiable.
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          auto VERSION>
struct ConvFactory;
|
||||
|
||||
// Factory specialization for an instance of a grouped forward convolution kernel.
// Translates the semantic SIGNATURE/ALGORITHM descriptors into the full
// template-argument list of CK's DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// and exposes the result as `Instance`.
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          StringLiteral VERSION>
    requires ConvDirectionIsForward<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;

    // Compile-time mappings from the semantic descriptors to CK types.
    using Layouts =
        factory_internal::ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
    using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
    using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
    using AlgorithmType = decltype(ALGORITHM);

    // Check preconditions for the algorithm description.
    // NOTE(review): 1D layouts are mapped in ConvTensorLayouts but rejected
    // here -- presumably 1D support is planned; confirm.
    static_assert(SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
                  "Only 2D and 3D convolutions are supported in this factory.");
    static_assert(SpecifiesThreadBlock<AlgorithmType>,
                  "The convolution algorithm descriptor must specify thread block info.");
    static_assert(SpecifiesGridwiseGemm<AlgorithmType>,
                  "The convolution algorithm descriptor must specify gridwise GEMM info.");
    static_assert(SpecifiesBlockTransfer<AlgorithmType>,
                  "The convolution algorithm descriptor must specify block transfer info.");
    static_assert(SpecifiesLdsTransfer<AlgorithmType>,
                  "The convolution algorithm descriptor must specify LDS transfer info.");
    static_assert(
        SpecifiesThreadClusterAccessOrder<AlgorithmType>,
        "The convolution algorithm descriptor must specify thread cluster access order info.");
    static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
                  "The convolution algorithm descriptor must specify source access order info.");
    static_assert(SpecifiesGemmPipelineVersion<AlgorithmType>,
                  "The convolution algorithm descriptor must specify block gemm pipeline version.");
    // NOTE(review): "Conc" in the concept name below looks like a typo for
    // "Conv"; it is defined elsewhere, so renaming must be coordinated.
    static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
                  "The convolution algorithm descriptor must specify forward convolution "
                  "specialization.");

    // Derived compile-time kernel configuration.
    static constexpr auto FWD_CONV_SPECIALIZATION =
        factory_internal::SetFwdConvSpecialization<ALGORITHM>();
    // GEMM specialization is fixed to MNKPadding (safe for arbitrary sizes).
    static constexpr factory_internal::ConvSpec SPECIALIZATION{
        .conv_spec = FWD_CONV_SPECIALIZATION,
        .gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
    };
    static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
    static constexpr auto GRIDWISE_GEMM =
        factory_internal::SetGridwiseGemmInfo<SIGNATURE, ALGORITHM>();
    static constexpr auto A_BLOCK_TRANSFER =
        factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
    static constexpr auto B_BLOCK_TRANSFER =
        factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
    static constexpr auto C_BLOCK_TRANSFER =
        factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
    // Scheduler is currently hard-coded to Intrawave (not descriptor-driven).
    static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
    static constexpr auto PIPELINE_VERSION =
        factory_internal::SetBlockGemmPipelineVersion<ALGORITHM>();

    // Check limits for the algorithm parameters.
    // TODO: Add more limits checks as needed.
    static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
    static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
    static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
    static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
    static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
    static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
    static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);

    // The forward convolution kernel class instance.
    using Instance =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
            SPATIAL_DIM,
            typename Layouts::ALayout,
            typename Layouts::BLayout,
            typename Layouts::DsLayout,
            typename Layouts::ELayout,
            typename Types::ADataType,
            typename Types::BDataType,
            typename Types::AccDataType,
            typename Types::CShuffleDataType,
            typename Types::DsDataTypes,
            typename Types::EDataType,
            typename Ops::AElementwiseOp,
            typename Ops::BElementwiseOp,
            typename Ops::CDEElementwiseOp,
            SPECIALIZATION.conv_spec,
            SPECIALIZATION.gemm_spec,
            BLOCK.block_size,
            BLOCK.per_block.m,
            BLOCK.per_block.n,
            BLOCK.per_block.k,
            GRIDWISE_GEMM.ak1,
            GRIDWISE_GEMM.bk1,
            GRIDWISE_GEMM.m_per_xdl,
            GRIDWISE_GEMM.n_per_xdl,
            GRIDWISE_GEMM.m_xdl_per_wave,
            GRIDWISE_GEMM.n_xdl_per_wave,
            to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
            to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
            to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
            A_BLOCK_TRANSFER.src_vector_dim,
            A_BLOCK_TRANSFER.src_scalar_per_vector,
            A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
            A_BLOCK_TRANSFER.lds_padding,
            to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
            to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
            to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
            B_BLOCK_TRANSFER.src_vector_dim,
            B_BLOCK_TRANSFER.src_scalar_per_vector,
            B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
            B_BLOCK_TRANSFER.lds_padding,
            C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
            C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
            to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
            C_BLOCK_TRANSFER.scalar_per_vector,
            PIPELINE_SCHEDULER,
            PIPELINE_VERSION>;
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,74 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// This file defines the compile-time "signature" for grouped convolution operations.
|
||||
// A signature is a collection of properties that fully describe a convolution kernel's
|
||||
// mathematical characteristics. It uses C++20 concepts and enums to specify these
|
||||
// properties, enabling compile-time validation and specialization.
|
||||
//
|
||||
// The core components of a signature are:
|
||||
// - Spatial dimensionality (1D, 2D, 3D)
|
||||
// - Operational direction (Forward, Backward Data, Backward Weight)
|
||||
// - Tensor memory layout (Channels First/Last)
|
||||
// - Data type (FP32, FP16, BF16)
|
||||
// - Fused element-wise operation (e.g., Bias, Clamp)
|
||||
//
|
||||
// The file also provides predicate concepts to query the properties of a given
|
||||
// signature at compile time.
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Constrains convolution to 1D, 2D, or 3D spatial dimensions.
|
||||
template <auto N>
|
||||
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);
|
||||
|
||||
// Constraints for forward convolution layouts.
|
||||
template <auto LayoutValue, size_t SpatialDim>
|
||||
concept ValidConvLayoutForSpatialDim =
|
||||
(SpatialDim == 1 && std::same_as<decltype(LayoutValue), GroupConvLayout1D>) ||
|
||||
(SpatialDim == 2 && std::same_as<decltype(LayoutValue), GroupConvLayout2D>) ||
|
||||
(SpatialDim == 3 && std::same_as<decltype(LayoutValue), GroupConvLayout3D>);
|
||||
|
||||
// Constrains convolution data types to common floating-point types.
|
||||
template <DataType T>
|
||||
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
|
||||
(T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
|
||||
|
||||
// Concept for a type that defines a convolution's operational signature.
|
||||
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
    // Number of spatial dimensions (validated to 1/2/3 by ValidConvSignature).
    { t.spatial_dim } -> std::convertible_to<unsigned int>;
    // Operational direction (forward / backward-data / backward-weight).
    { t.direction } -> std::convertible_to<ConvDirection>;
    // Tensor layout: one of the dimension-specific layout enums.
    requires std::convertible_to<decltype(t.layout), GroupConvLayout1D> ||
                 std::convertible_to<decltype(t.layout), GroupConvLayout2D> ||
                 std::convertible_to<decltype(t.layout), GroupConvLayout3D>;
    // Element data type of the tensors.
    { t.data_type } -> std::convertible_to<DataType>;
    // Fused element-wise epilogue operation.
    { t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
};
|
||||
|
||||
// Concept to validate a convolution signature's values.
|
||||
template <auto Sig>
|
||||
concept ValidConvSignature = requires {
|
||||
requires ConvSpatialDim<Sig.spatial_dim>;
|
||||
requires ConvDataType<Sig.data_type>;
|
||||
};
|
||||
|
||||
// Predicate for forward convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
|
||||
|
||||
// Predicate for backward data convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
|
||||
|
||||
// Predicate for backward weight convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -16,6 +16,7 @@
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck_tile/ops/common/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
@@ -62,7 +63,9 @@ consteval std::string_view type_name()
|
||||
template <typename T>
|
||||
constexpr std::string_view layout_name()
|
||||
{
|
||||
if constexpr(std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> && requires {
|
||||
if constexpr((std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> ||
|
||||
std::is_base_of_v<ck::tensor_layout::BaseTensorLayout, T>) &&
|
||||
requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
})
|
||||
return T::name;
|
||||
|
||||
90
experimental/builder/include/ck_tile/builder/types.hpp
Normal file
90
experimental/builder/include/ck_tile/builder/types.hpp
Normal file
@@ -0,0 +1,90 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
enum class DataType
|
||||
{
|
||||
FP32,
|
||||
FP16,
|
||||
BF16,
|
||||
FP8,
|
||||
I8,
|
||||
U8
|
||||
};
|
||||
|
||||
// Memory layouts for 1D convolution tensors.
|
||||
// G: Group, N: Batch, K: Output Channel, C: Input Channel, W: Width
|
||||
// Enum defines Input, Weight, and Output tensor layouts respectively.
|
||||
enum class GroupConvLayout1D
|
||||
{
|
||||
GNWC_GKXC_GNWK,
|
||||
NWGC_GKXC_NWGK,
|
||||
NGCW_GKXC_NGKW,
|
||||
NGCW_GKCX_NGKW
|
||||
};
|
||||
|
||||
// Memory layouts for 2D convolution tensors.
|
||||
// G: Group, N: Batch, K: Output Channel, C: Input Channel, Y: Height, X: Width, H: Height
|
||||
// Enum defines Input, Weight, and Output tensor layouts respectively.
|
||||
enum class GroupConvLayout2D
|
||||
{
|
||||
GNHWC_GKYXC_GNHWK,
|
||||
NHWGC_GKYXC_NHWGK,
|
||||
NGCHW_GKYXC_NGKHW,
|
||||
NGCHW_GKCYX_NGKHW
|
||||
};
|
||||
|
||||
// Memory layouts for 3D convolution tensors.
// G: Group, N: Batch, K: Output Channel, C: Input Channel,
// Z/D: Depth, Y/H: Height, X/W: Width.
// Enum defines Input, Weight, and Output tensor layouts respectively.
enum class GroupConvLayout3D
{
    GNDHWC_GKZYXC_GNDHWK,
    NDHWGC_GKZYXC_NDHWGK,
    NGCDHW_GKZYXC_NGKDHW,
    NGCDHW_GKCZYX_NGKDHW,
};
|
||||
|
||||
// Direction of the convolution operation.
|
||||
enum class ConvDirection
|
||||
{
|
||||
FORWARD,
|
||||
BACKWARD_DATA,
|
||||
BACKWARD_WEIGHT
|
||||
};
|
||||
|
||||
// Fused element-wise operations.
|
||||
enum class ElementwiseOperation
|
||||
{
|
||||
BIAS,
|
||||
BIAS_CLAMP,
|
||||
BIAS_BNORM_CLAMP,
|
||||
BILINEAR,
|
||||
CLAMP,
|
||||
SCALE,
|
||||
PASS_THROUGH
|
||||
};
|
||||
|
||||
// Enums for the current block GEMM pipeline versions.
|
||||
enum class BlockGemmPipelineVersion
|
||||
{
|
||||
V1,
|
||||
V2,
|
||||
V3,
|
||||
V4,
|
||||
V5
|
||||
};
|
||||
|
||||
// Enums for the forward convolution specialization (filter-shape fast paths).
enum class ConvFwdSpecialization
{
    DEFAULT,
    FILTER_1X1_PAD0,
    FILTER_1X1_STRIDE1_PAD0,
    // NOTE(review): lower-case 'x' is inconsistent with FILTER_1X1_* above;
    // renaming requires coordinated updates at all use sites.
    FILTER_3x3
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
18
experimental/builder/include/ck_tile/builder/versions.hpp
Normal file
18
experimental/builder/include/ck_tile/builder/versions.hpp
Normal file
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
static constexpr StringLiteral V0_0_0 = "0.0.0";
|
||||
static constexpr StringLiteral V0_1_0 = "0.1.0";
|
||||
|
||||
static constexpr StringLiteral LATEST_API_VERSION = V0_1_0;
|
||||
|
||||
template <StringLiteral V>
|
||||
concept SupportedVersion = (V == V0_0_0) || (V == V0_1_0);
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -7,6 +7,7 @@ function(add_ck_builder_test test_name)
|
||||
target_include_directories(${test_name} PRIVATE
|
||||
"${PROJECT_SOURCE_DIR}/experimental/builder/include"
|
||||
"${PROJECT_SOURCE_DIR}/include"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
)
|
||||
target_compile_options(${test_name} PRIVATE
|
||||
-Wno-global-constructors
|
||||
@@ -24,3 +25,11 @@ add_ck_builder_test(test_get_instance_string
|
||||
test_get_instance_string.cpp)
|
||||
|
||||
add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp)
|
||||
|
||||
add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp)
|
||||
47
experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp
Normal file
47
experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp
Normal file
@@ -0,0 +1,47 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DBF16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT
|
||||
TEST_F(FwdConv2DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
|
||||
TEST_F(FwdConv2DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V5,
|
||||
ConvFwdSpecialization::FILTER_3x3>();
|
||||
}
|
||||
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp
Normal file
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DFP16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(FwdConv2DFP16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp
Normal file
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DFP32Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(FwdConv2DFP32Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
}
|
||||
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp
Normal file
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DBF16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT
|
||||
TEST_F(FwdConv3DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp
Normal file
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DFP16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0
|
||||
TEST_F(FwdConv3DFP16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp
Normal file
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DFP32Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
|
||||
TEST_F(FwdConv3DFP32Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
119
experimental/builder/test/impl/conv_algorithm_types.hpp
Normal file
119
experimental/builder/test/impl/conv_algorithm_types.hpp
Normal file
@@ -0,0 +1,119 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
|
||||
// Convenience struct for a tuple of m, n, and k values.
|
||||
template <typename T>
|
||||
struct MNK
|
||||
{
|
||||
T m{};
|
||||
T n{};
|
||||
T k{};
|
||||
};
|
||||
|
||||
// Specify thread block dimensions for a GEMM.
|
||||
struct ThreadBlock
|
||||
{
|
||||
// Thread block size.
|
||||
size_t block_size;
|
||||
// Size of the submatrix problem in a thread block.
|
||||
MNK<size_t> tile_size;
|
||||
};
|
||||
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
|
||||
|
||||
// Describe gridwise GEMM parameters.
|
||||
struct GridwiseGemm
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
size_t m_per_xdl = 0;
|
||||
size_t n_per_xdl = 0;
|
||||
size_t m_xdl_per_wave = 0;
|
||||
size_t n_xdl_per_wave = 0;
|
||||
};
|
||||
static_assert(ckb::GridwiseGemmDescriptor<GridwiseGemm>);
|
||||
|
||||
// Describe Aand B block transfer thread cluster lengths.
|
||||
struct BlockTransfer
|
||||
{
|
||||
size_t k0;
|
||||
size_t m_n;
|
||||
size_t k1;
|
||||
};
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer>);
|
||||
|
||||
// Describe C block transfer thread cluster lengths.
|
||||
struct ThreadCluster
|
||||
{
|
||||
size_t m_block;
|
||||
size_t m_wave_per_xdl;
|
||||
size_t n_block;
|
||||
size_t n_wave_per_xdl;
|
||||
};
|
||||
static_assert(ThreadClusterDescriptor<ThreadCluster>);
|
||||
|
||||
struct LdsTransfer
|
||||
{
|
||||
size_t src_vector_dim;
|
||||
size_t src_scalar_per_vector;
|
||||
size_t lds_dst_scalar_per_vector;
|
||||
bool is_direct_load;
|
||||
bool lds_padding;
|
||||
};
|
||||
static_assert(LdsTransferDescriptor<LdsTransfer>);
|
||||
|
||||
struct Epilogue
|
||||
{
|
||||
size_t m_xdl_per_wave_per_shuffle;
|
||||
size_t n_xdl_per_wave_per_shuffle;
|
||||
size_t scalar_per_vector;
|
||||
};
|
||||
static_assert(EpilogueDescriptor<Epilogue>);
|
||||
|
||||
struct AccessOrder
|
||||
{
|
||||
std::array<size_t, 3> order;
|
||||
};
|
||||
static_assert(AccessOrderDescriptor<AccessOrder>);
|
||||
|
||||
struct BlockTransferABC
|
||||
{
|
||||
BlockTransfer block_transfer_a;
|
||||
BlockTransfer block_transfer_b;
|
||||
ThreadCluster thread_cluster_dims_c;
|
||||
LdsTransfer lds_transfer_a;
|
||||
LdsTransfer lds_transfer_b;
|
||||
Epilogue epilogue_c;
|
||||
AccessOrder block_transfer_access_order_a;
|
||||
AccessOrder block_transfer_access_order_b;
|
||||
AccessOrder src_access_order_a;
|
||||
AccessOrder src_access_order_b;
|
||||
};
|
||||
|
||||
struct ConvAlgorithm
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
BlockGemmPipelineVersion pipeline_version;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesGridwiseGemm<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesSourceAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesGemmPipelineVersion<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm>);
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
23
experimental/builder/test/impl/conv_signature_types.hpp
Normal file
23
experimental/builder/test/impl/conv_signature_types.hpp
Normal file
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
template <typename GroupConvLayout>
|
||||
struct ConvSignature
|
||||
{
|
||||
int spatial_dim;
|
||||
ConvDirection direction;
|
||||
GroupConvLayout layout;
|
||||
DataType data_type;
|
||||
ElementwiseOperation elementwise_operation;
|
||||
};
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout1D>>);
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout2D>>);
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout3D>>);
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
103
experimental/builder/test/utils/ckb_conv_test_common.hpp
Normal file
103
experimental/builder/test/utils/ckb_conv_test_common.hpp
Normal file
@@ -0,0 +1,103 @@
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
// Common test base class
|
||||
class FwdConvBuilderTestBase : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
// Common test implementation
|
||||
template <auto FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
BlockGemmPipelineVersion FwdPipelineVersion,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test()
|
||||
{
|
||||
constexpr GridwiseGemm FwdGemmParams{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr ConvAlgorithm FwdConvAlgorithm{.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.pipeline_version = FwdPipelineVersion,
|
||||
.fwd_specialization = FwdConvSpecialization};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
|
||||
|
||||
// Verify pipeline version is correct
|
||||
if(FwdPipelineVersion == BlockGemmPipelineVersion::V1)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
// Common thread block configurations
|
||||
constexpr ThreadBlock DefaultThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock SmallThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
@@ -552,6 +552,8 @@ struct PassThrough
|
||||
{
|
||||
y = type_convert<bf8_t>(x);
|
||||
}
|
||||
|
||||
static constexpr const char* name = "PassThrough";
|
||||
};
|
||||
|
||||
struct UnaryConvert
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -14,14 +15,18 @@ namespace ck_tile {
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
typename OutDataType,
|
||||
typename Elfunc = ck_tile::element_wise::PassThrough,
|
||||
typename Tuple = ck_tile::tuple<>>
|
||||
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input,
|
||||
const HostTensor<WeiDataType>& weight,
|
||||
HostTensor<OutDataType>& output,
|
||||
std::vector<ck_tile::long_index_t> conv_strides,
|
||||
std::vector<ck_tile::long_index_t> conv_dilations,
|
||||
std::vector<ck_tile::long_index_t> in_left_pads,
|
||||
std::vector<ck_tile::long_index_t>)
|
||||
std::vector<ck_tile::long_index_t>,
|
||||
Elfunc elfunc = Elfunc{},
|
||||
Tuple ds = {})
|
||||
{
|
||||
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
weight.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
@@ -52,8 +57,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, wo) = v_acc_converted;
|
||||
if constexpr(Tuple::size() > 0)
|
||||
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, wo));
|
||||
else
|
||||
elfunc(v_acc, v_acc);
|
||||
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, wo) = v_acc_out;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
@@ -95,8 +104,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, ho, wo) = v_acc_converted;
|
||||
if constexpr(Tuple::size() > 0)
|
||||
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, ho, wo));
|
||||
else
|
||||
elfunc(v_acc, v_acc);
|
||||
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, ho, wo) = v_acc_out;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
@@ -145,8 +158,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, d_o, ho, wo) = v_acc_converted;
|
||||
if constexpr(Tuple::size() > 0)
|
||||
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, d_o, ho, wo));
|
||||
else
|
||||
elfunc(v_acc, v_acc);
|
||||
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, d_o, ho, wo) = v_acc_out;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
|
||||
@@ -1540,6 +1540,23 @@ struct Logistic
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
struct Clamp
|
||||
{
|
||||
CK_TILE_HOST_DEVICE Clamp(float lower = std::numeric_limits<float>::lowest(),
|
||||
float upper = std::numeric_limits<float>::max())
|
||||
: lower_(lower), upper_(upper) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(T& y, const T& x) const
|
||||
{
|
||||
T lower = ck_tile::type_convert<T>(lower_);
|
||||
T upper = ck_tile::type_convert<T>(upper_);
|
||||
y = ck_tile::clamp(x, lower, upper);
|
||||
}
|
||||
|
||||
float lower_, upper_;
|
||||
};
|
||||
|
||||
struct ConvInvscale
|
||||
{
|
||||
static constexpr const char* name = "ConvInvscale";
|
||||
@@ -1629,6 +1646,55 @@ struct Cast
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compose two unary element-wise functions into one.
|
||||
*
|
||||
*
|
||||
* @note The Ds tensor can be used by at most one of the composed functions.
|
||||
* This holds even if compositions are chained:
|
||||
* In `Compose<FA, Compose<FB, FC>>`, only one of `FA`, `FB`, or `FC` can use
|
||||
* the Ds tensor.
|
||||
*
|
||||
* @tparam FuncA The first function to be applied.
|
||||
* @tparam FuncB The second function to be applied.
|
||||
* @tparam FuncADs Whether `FuncA` uses the Ds tensor.
|
||||
* @tparam FuncBDs Whether `FuncB` uses the Ds tensor.
|
||||
*/
|
||||
template <typename FuncA, typename FuncB, bool FuncADs = false, bool FuncBDs = false>
|
||||
struct Compose
|
||||
{
|
||||
static_assert(!(FuncADs && FuncBDs), "Only one composed function may use the Ds tensor.");
|
||||
|
||||
CK_TILE_HOST_DEVICE Compose(FuncA func_a_ = FuncA{}, FuncB func_b_ = FuncB{})
|
||||
: func_a(func_a_), func_b(func_b_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename AIn, typename BOut, typename AOut = AIn, typename... ADs>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(BOut& y, const AIn& x, const ADs&... ds) const
|
||||
{
|
||||
AOut tmp;
|
||||
if constexpr(FuncADs)
|
||||
{
|
||||
func_a(tmp, x, ds...);
|
||||
func_b(y, tmp);
|
||||
}
|
||||
else if constexpr(FuncBDs)
|
||||
{
|
||||
func_a(tmp, x);
|
||||
func_b(y, tmp, ds...);
|
||||
}
|
||||
else
|
||||
{
|
||||
func_a(tmp, x);
|
||||
func_b(y, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
const FuncA func_a;
|
||||
const FuncB func_b;
|
||||
};
|
||||
|
||||
// support fastconvert of int8 to fp16
|
||||
#if 0
|
||||
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -117,6 +117,10 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
|
||||
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
|
||||
|
||||
CDElementwise elfunc_;
|
||||
|
||||
CK_TILE_DEVICE CShuffleEpilogue(CDElementwise elfunc = CDElementwise{}) : elfunc_(elfunc) {};
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
/**
|
||||
@@ -385,7 +389,7 @@ struct CShuffleEpilogue
|
||||
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
|
||||
number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
tile_elementwise_inout_unpack(elfunc_, c_ds_tiles);
|
||||
}
|
||||
|
||||
template <typename OutDramWindow, typename COutTensor>
|
||||
@@ -450,7 +454,7 @@ struct CShuffleEpilogue
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* /*p_smem*/,
|
||||
void* /* p_smem */,
|
||||
const ScaleM& scale_m = {},
|
||||
const ScaleN& scale_n = {})
|
||||
{
|
||||
|
||||
@@ -7,10 +7,12 @@
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
|
||||
@@ -28,6 +30,7 @@ struct GroupedConvFwdKernelArgs
|
||||
GroupedConvTraitsType_::VectorSizeB,
|
||||
GroupedConvTraitsType_::VectorSizeC,
|
||||
true>; // Split N enabled
|
||||
using CDElementwise = typename GroupedConvTraitsType_::CDElementwise;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
template <
|
||||
@@ -38,7 +41,8 @@ struct GroupedConvFwdKernelArgs
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
|
||||
: elfunc(args.elfunc)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
@@ -121,7 +125,8 @@ struct GroupedConvFwdKernelArgs
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
|
||||
: elfunc(args.elfunc)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
@@ -213,7 +218,8 @@ struct GroupedConvFwdKernelArgs
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
|
||||
: elfunc(args.elfunc)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
@@ -335,6 +341,7 @@ struct GroupedConvFwdKernelArgs
|
||||
const void* in_ptr;
|
||||
const void* wei_ptr;
|
||||
std::array<const void*, NumDTensor> ds_ptr;
|
||||
const CDElementwise elfunc;
|
||||
void* out_ptr;
|
||||
|
||||
AGridDescMK a_grid_desc_m_k;
|
||||
@@ -423,6 +430,8 @@ struct GroupedConvolutionForwardKernel
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using CDElementwise = typename EpiloguePipeline::CDElementwise;
|
||||
|
||||
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType_>;
|
||||
|
||||
// TODO: Enable this
|
||||
@@ -458,7 +467,7 @@ struct GroupedConvolutionForwardKernel
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)
|
||||
MakeKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& hostArgs)
|
||||
{
|
||||
return GroupedConvFwdKernelArgsSpecialized(hostArgs);
|
||||
}
|
||||
@@ -636,7 +645,7 @@ struct GroupedConvolutionForwardKernel
|
||||
"Not supported!");
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
|
||||
static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
@@ -765,8 +774,9 @@ struct GroupedConvolutionForwardKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{kargs.elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -14,7 +15,7 @@ namespace ck_tile {
|
||||
/// This structure is passed to Grouped Convolution Kernels when creating kernel
|
||||
/// arguments object. It contain all necessary information required to
|
||||
/// build proper kernel argument and launch kernel on GPU.
|
||||
template <typename InPtr, typename WeiPtr, typename OutPtr>
|
||||
template <typename InPtr, typename WeiPtr, typename OutPtr, typename CDElementwise>
|
||||
struct GroupedConvHostArgs : public conv::ConvParam
|
||||
{
|
||||
CK_TILE_HOST GroupedConvHostArgs() = delete;
|
||||
@@ -23,13 +24,15 @@ struct GroupedConvHostArgs : public conv::ConvParam
|
||||
WeiPtr wei_ptr_,
|
||||
const std::vector<const void*> ds_ptr_,
|
||||
OutPtr out_ptr_,
|
||||
index_t k_batch_)
|
||||
index_t k_batch_,
|
||||
CDElementwise elfunc_ = CDElementwise{})
|
||||
: conv::ConvParam(conv_param),
|
||||
in_ptr(in_ptr_),
|
||||
wei_ptr(wei_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
out_ptr(out_ptr_),
|
||||
k_batch(k_batch_)
|
||||
k_batch(k_batch_),
|
||||
elfunc(elfunc_)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -38,11 +41,17 @@ struct GroupedConvHostArgs : public conv::ConvParam
|
||||
const std::vector<const void*> ds_ptr;
|
||||
OutPtr out_ptr;
|
||||
index_t k_batch;
|
||||
const CDElementwise elfunc;
|
||||
};
|
||||
|
||||
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*>;
|
||||
using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs<const void*, void*, const void*>;
|
||||
using GroupedConvBwdDataHostArgs = GroupedConvHostArgs<void*, const void*, const void*>;
|
||||
using PassThrough = ck_tile::element_wise::PassThrough;
|
||||
|
||||
template <typename CDElementwise = PassThrough>
|
||||
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*, CDElementwise>;
|
||||
using GroupedConvBwdWeightHostArgs =
|
||||
GroupedConvHostArgs<const void*, void*, const void*, PassThrough>;
|
||||
using GroupedConvBwdDataHostArgs =
|
||||
GroupedConvHostArgs<void*, const void*, const void*, PassThrough>;
|
||||
|
||||
template <index_t NDimSpatial_,
|
||||
ConvolutionSpecialization ConvSpecialization_,
|
||||
@@ -50,9 +59,10 @@ template <index_t NDimSpatial_,
|
||||
typename WeiLayout_,
|
||||
typename DsLayout_,
|
||||
typename OutLayout_,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1>
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1,
|
||||
typename CDElementwise_ = PassThrough>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
@@ -70,6 +80,7 @@ struct GroupedConvTraits
|
||||
using WeiLayout = WeiLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
using CDElementwise = CDElementwise_;
|
||||
using GroupedConvImplicitGemmTraitsFwd =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
|
||||
@@ -33,6 +33,14 @@ struct elementwise_op_traits<ck_tile::element_wise::Relu>
|
||||
static constexpr int num_inputs = 1;
|
||||
};
|
||||
|
||||
using NegRelu =
|
||||
ck_tile::element_wise::Compose<ck_tile::element_wise::Relu, ck_tile::element_wise::Neg>;
|
||||
template <>
|
||||
struct elementwise_op_traits<NegRelu>
|
||||
{
|
||||
static constexpr int num_inputs = 1;
|
||||
};
|
||||
|
||||
template <std::size_t D, typename F>
|
||||
auto make_uniform_array_with_factory(F&& factory)
|
||||
{
|
||||
@@ -194,7 +202,11 @@ using TestConfig_F16_Add = std::tuple<ck_tile::half_t,
|
||||
Shape1_BlockTile,
|
||||
Shape1_WarpTile>;
|
||||
|
||||
using TestTypes = ::testing::Types<TestConfig_F32_Add, TestConfig_F32_Relu, TestConfig_F16_Add>;
|
||||
using TestConfig_F32_Neg_Relu =
|
||||
std::tuple<float, float, float, NegRelu, Shape1_BlockWarps, Shape1_BlockTile, Shape1_WarpTile>;
|
||||
|
||||
using TestTypes = ::testing::
|
||||
Types<TestConfig_F32_Add, TestConfig_F32_Relu, TestConfig_F16_Add, TestConfig_F32_Neg_Relu>;
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileElementwise, TestTypes);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user