[rocm-libraries] ROCm/rocm-libraries#5237 (commit ef10dc6)

[CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK
 Tile profiler (#5237)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

PR #4797 added CK Tile bwd weight kernels to the CK Profiler. The
two-stage kernels were not supported in the initial PR. This PR adds the
missing bwd weight two-stage kernels to the CK Profiler.

## Technical Details

Extended the CK Tile conv builder factory to build also the elementwise
ops required for the two-stage kernels. Extended the CK Builder for CK
Tile instance to accept the two-stage flag as part of the algorithm
configuration.

## Test Plan

Added unit tests for CK Builder that verify the two-stage kernel
construction.

## Test Result

If CI passes, the added unit tests are passing.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Ville Pietilä
2026-03-13 01:21:08 +00:00
committed by assistant-librarian[bot]
parent fc2f95620d
commit e2f5ab8000
16 changed files with 336 additions and 50 deletions

View File

@@ -155,6 +155,7 @@ concept TileOptimizationsDescriptor = requires(T t) {
{ t.num_groups_to_merge } -> std::convertible_to<int>;
{ t.split_image } -> std::convertible_to<bool>;
{ t.explicit_gemm } -> std::convertible_to<bool>;
{ t.two_stage } -> std::convertible_to<bool>;
};
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
@@ -295,6 +296,7 @@ concept SpecifiesTileOptimizations = requires {
{ T::optimizations.num_groups_to_merge } -> std::convertible_to<int>;
{ T::optimizations.split_image } -> std::convertible_to<bool>;
{ T::optimizations.explicit_gemm } -> std::convertible_to<bool>;
{ T::optimizations.two_stage } -> std::convertible_to<bool>;
};
template <typename T>

View File

@@ -8,6 +8,8 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/builder/versions.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
@@ -68,6 +70,10 @@ struct ConvTileFactory
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using ConvOutDataType = std::conditional_t<OPTIMIZATIONS.two_stage,
typename Types::AccDataType,
typename Types::EDataType>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
@@ -103,7 +109,7 @@ struct ConvTileFactory
typename Types::BDataType,
typename Types::DsDataTypes,
typename Types::AccDataType,
typename Types::EDataType,
ConvOutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
typename Ops::CDEElementwiseOp,
@@ -126,4 +132,33 @@ struct ConvTileFactory
ConvEpilogue>::Instance;
};
/// @brief Factory for the second-stage elementwise kernel used by two-stage
/// backward-weight convolutions.
///
/// Stage 1 accumulates the weight gradient into a workspace in the
/// accumulation type (AccDataType); the kernel built here converts that
/// workspace element-by-element (UnaryConvert) into the final output type
/// (EDataType).
///
/// @tparam SIGNATURE Convolution signature; supplies the tensor data types.
/// @tparam ALGORITHM Tile algorithm descriptor; supplies block/warp shapes.
/// @tparam VERSION   Builder API version string (defaults to the latest).
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          StringLiteral VERSION = LATEST_API_VERSION>
struct ElementwiseOpTileFactory
{
    // Thread-block and block-GEMM shape info derived from the algorithm descriptor.
    static constexpr auto BLOCK      = internal::SetTileThreadBlockInfo<ALGORITHM>();
    static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm<ALGORITHM>();

    using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;

    // Source (X) and workspace element type: the conv kernel's accumulator type.
    using XDataType             = Types::AccDataType;
    using WorkspaceDataType     = Types::AccDataType;
    using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
    // Destination (Y) element type: the final conv output type.
    using YDataType             = Types::EDataType;

    // The elementwise kernel works on a flattened 1-D view, so the 2-D
    // (m x n) block/warp shapes are collapsed into single extents here.
    using BlockTile  = ck_tile::sequence<BLOCK.per_block.m * BLOCK.per_block.n>;
    using BlockWarps = ck_tile::sequence<BLOCK_GEMM.warps.m * BLOCK_GEMM.warps.n>;
    using WarpTile   = ck_tile::sequence<BLOCK_GEMM.warp_tile.m * BLOCK_GEMM.warp_tile.n>;

    using ElementwiseShape =
        ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;

    // Conversion from X -> Y.
    using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
                                                        WorkspaceDataType,
                                                        YDataType,
                                                        ElementwiseShape,
                                                        XElementwiseOperation>;

    // The buildable kernel instance.
    using Instance = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
};
} // namespace ck_tile::builder::factory

View File

@@ -34,6 +34,7 @@ struct TileOptimizations
int num_groups_to_merge = 1;
bool split_image = false;
bool explicit_gemm = false;
bool two_stage = false;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
@@ -181,7 +182,8 @@ consteval TileOptimizations SetTileOptimizations()
return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge,
.split_image = OPT.split_image,
.explicit_gemm = OPT.explicit_gemm};
.explicit_gemm = OPT.explicit_gemm,
.two_stage = OPT.two_stage};
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -91,6 +91,100 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
}
}
/// @brief Executes a two-stage backward-weight convolution: the conv kernel
/// writes its result into an accumulation-type workspace, then the
/// elementwise kernel converts the workspace into the caller's weight buffer.
///
/// @param conv           CK Tile convolution kernel instance (stage 1).
/// @param elementwise_op Elementwise conversion kernel instance (stage 2).
/// @param args           Convolution problem description.
/// @param input          Device pointer to the input tensor.
/// @param weight         Device pointer to the weight tensor (final result).
/// @param output         Device pointer to the output tensor.
/// @param s_conf         Stream configuration used for both launches.
/// @returns RunResult with support status and measured runtime.
template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename OutDataType>
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
                            auto& elementwise_op,
                            const Args<SIGNATURE>& args,
                            InDataType* input,
                            WeiDataType* weight,
                            OutDataType* output,
                            const ck_tile::stream_config s_conf)
{
    using Conv          = std::remove_reference_t<decltype(conv)>;
    using ElementwiseOp = std::remove_reference_t<decltype(elementwise_op)>;
    // Stage-1 accumulation type; stage 2 converts it to the final CDataType.
    using WorkspaceDataType = typename ElementwiseOp::ComputeDataType;
    using CDataType         = typename ElementwiseOp::YDataType;
    using BlockShape        = typename ElementwiseOp::Problem::BlockShape;

    const auto param = args.to_ck_tile_conv_param();
    ck_tile::GroupedConvHostArgs<InDataType*, WeiDataType*, OutDataType*, ck_tile::PassThrough>
        host_args(param, input, weight, {}, output, args.k_batch);

    // Set-up for elementwise op kernel.
    // Product of the filter spatial lengths (e.g. filter_x * filter_y in 2D).
    const ck_tile::index_t spatial_lengths_accum =
        std::accumulate(host_args.filter_spatial_lengths_.begin(),
                        host_args.filter_spatial_lengths_.end(),
                        1,
                        std::multiplies<ck_tile::index_t>());

    // Workspace sized for the whole weight tensor (G * K * C * filter taps) in
    // the accumulation type; released when this function returns.
    ck_tile::DeviceMem ws_m_n_dev_buf(host_args.G_ * host_args.K_ * host_args.C_ *
                                      spatial_lengths_accum * sizeof(WorkspaceDataType));

    // Redirect the conv kernel's weight output to the workspace, keeping the
    // caller's weight pointer (c_ptr) as the elementwise kernel's destination.
    ck_tile::GroupedConvBwdWeightHostArgs ws_args =
        ck_tile::GroupedConvBwdWeightHostArgs(host_args);
    auto c_ptr      = ws_args.wei_ptr;
    ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();

    auto kargs        = Conv::MakeKernelArgs(ws_args);
    const dim3 grids  = Conv::GridSize(kargs);
    const dim3 blocks = Conv::BlockSize();

    if(!Conv::IsSupportedArgument(kargs))
        return RunResult::not_supported("unsupported ck_tile arguments");

    // The elementwise kernel views the workspace as a 2-D tensor of shape
    // (G*K) x (C * filter taps).
    ck_tile::index_t total_elements = 1;
    std::vector<ck_tile::index_t> shape = {
        static_cast<ck_tile::index_t>(host_args.G_ * host_args.K_),
        static_cast<ck_tile::index_t>(host_args.C_ * spatial_lengths_accum)};
    for(auto d : shape)
        total_elements *= d;

    const ck_tile::index_t kBlockSize = ElementwiseOp::BlockSize();
    // Each block handles kBlockM elements; round the grid up to cover all of them.
    constexpr ck_tile::index_t elements_per_block = BlockShape::kBlockM;
    ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;

    auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
    auto input_size    = ck_tile::make_tuple(shape[0], shape[1]);

    // Check if the kernel configuration is supported
    if(!ElementwiseOp::IsSupportedArgument(input_size))
    {
        return RunResult::not_supported("unsupported ck_tile arguments for elementwise op");
    }

    // Zero the workspace before each timed launch when split-K is used --
    // presumably because k_batch > 1 makes the conv kernel accumulate partial
    // results into the buffer (TODO confirm against the kernel's split-K path).
    auto preprocess = [&]() {
        if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
        {
            if(args.k_batch > 1)
            {
                ck_tile::hip_check_error(
                    hipMemsetAsync(ws_args.wei_ptr,
                                   0,
                                   shape[0] * shape[1] * sizeof(WorkspaceDataType),
                                   s_conf.stream_id_));
            }
        }
    };

    // Occupancy hint: 1 for intrawave scheduling, 2 otherwise.
    constexpr index_t minimum_occupancy =
        Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2;

    // Launch both stages back-to-back on the same stream and time them together.
    return RunResult::from_runtime(ck_tile::launch_kernel_time_mask(
        s_conf,
        preprocess,
        ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs),
        ck_tile::make_kernel<minimum_occupancy>(elementwise_op,
                                                kGridSize,
                                                kBlockSize,
                                                0,
                                                input_size,
                                                ck_tile::make_tuple(shape[1], 1), // Input Stride
                                                ck_tile::make_tuple(shape[1], 1), // Output Stride
                                                input_tensors,
                                                static_cast<CDataType*>(c_ptr))));
}
} // namespace detail
/// @brief Concept for checking whether a convolution is invoked like CK Tile.
@@ -149,4 +243,28 @@ template <auto SIGNATURE>
s_conf);
}
/// @brief `run()` specialization for two-stage backwards weight convolution and CK Tile.
///
/// @tparam SIGNATURE Backwards weight convolution signature.
/// @returns RunResult about how the operation completed (or not).
///
/// @see run()
template <auto SIGNATURE>
    requires ConvDirectionIsBackwardWeight<SIGNATURE>
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
                            auto& elementwise_op,
                            const Args<SIGNATURE>& args,
                            const Inputs<SIGNATURE>& inputs,
                            const Outputs<SIGNATURE>& outputs,
                            const ck_tile::stream_config s_conf = {})
{
    // Backward weight: the activations (inputs.input) and the output tensor
    // (inputs.output) are the kernel's inputs; the weight tensor is the result.
    // The void* casts make detail::run deduce its pointer template parameters
    // as (const) void, i.e. the kernel is driven with type-erased device
    // pointers -- presumably mirroring the single-stage bwd-weight overload;
    // confirm against GroupedConvHostArgs' pointer-type expectations.
    return detail::run(conv,
                       elementwise_op,
                       args,
                       static_cast<const void*>(inputs.input),
                       static_cast<void*>(outputs.weight),
                       static_cast<const void*>(inputs.output),
                       s_conf);
}
} // namespace ck_tile::builder::test

View File

@@ -25,8 +25,10 @@ TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D
.with_tile_thread_block(TileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(TileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
.with_tile_optimizations(TileOptimizations{.num_groups_to_merge = 1,
.split_image = false,
.explicit_gemm = false,
.two_stage = false});
using Builder = ConvBuilder<BwdDataConvSignature, BwdDataConvAlgorithm>;
run_ck_tile_test<Builder>({

View File

@@ -12,6 +12,7 @@
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
namespace ckf = ck_tile::builder::factory;
using enum ck_tile::builder::TensorLayout;
using ck_tile::test::MatchesReference;
@@ -31,12 +32,49 @@ constexpr auto ALGORITHM =
.with_tile_thread_block(cku::TileThreadBlock_64x64x64)
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(cku::TileTransfer_4x4x4)
.with_tile_optimizations(ckt::TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
.with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1,
.split_image = false,
.explicit_gemm = false,
.two_stage = false});
constexpr auto TWO_STAGE_ALGORITHM =
cku::ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(ckb::TileConvSpecialization::DEFAULT)
.with_tile_thread_block(cku::TileThreadBlock_64x64x64)
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(cku::TileTransfer_4x4x4)
.with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1,
.split_image = false,
.explicit_gemm = false,
.two_stage = true});
constexpr ckt::Args<SIGNATURE> Args = {
.lengths =
{
.batch_size = 2,
.groups = 4,
.input_channels = 32,
.output_channels = 48,
.image = {.width = 32, .height = 56},
.filter = {.width = 3, .height = 3},
},
.filter_strides = {.width = 1, .height = 1},
.filter_dilation = {.width = 1, .height = 1},
.input_left_pad = {.width = 0, .height = 0},
.input_right_pad = {.width = 0, .height = 0},
.a_elementwise_op = {},
.b_elementwise_op = {},
.cde_elementwise_op = {},
};
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
using TwoStageBuilder = ckb::ConvBuilder<SIGNATURE, TWO_STAGE_ALGORITHM>;
using TwoStageInstance = TwoStageBuilder::Instance;
using ElementwiseOpBuilder = ckf::ElementwiseOpTileFactory<SIGNATURE, TWO_STAGE_ALGORITHM>;
using ElementwiseOpInstance = ElementwiseOpBuilder::Instance;
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
TEST(BwdWeight_2D_FP16_NHWGC, Create)
@@ -61,38 +99,47 @@ TEST(BwdWeight_2D_FP16_NHWGC, Create)
});
}
// Builds the two-stage elementwise-conversion kernel through the factory and
// passes the expected name components (shape string, elementwise op, padding
// flag, policy) to the test harness -- presumably matched against the
// instance's reported name; confirm run_ck_tile_test's matching semantics.
TEST(ElementWiseOp, CreateBwdWeightTwoStageElementwiseOp)
{
    cku::run_ck_tile_test<ElementwiseOpBuilder>({"elementwise_kernel",
                                                 "4096_256_4_4_64_4_256",
                                                 "UnaryConvert",
                                                 "kPad_1",
                                                 "ElementWiseDefaultPolicy"});
}
TEST(BwdWeight_2D_FP16_NHWGC, Execution)
{
ckt::Args<SIGNATURE> args = {
.lengths =
{
.batch_size = 2,
.groups = 4,
.input_channels = 32,
.output_channels = 48,
.image = {.width = 32, .height = 56},
.filter = {.width = 3, .height = 3},
},
.filter_strides = {.width = 1, .height = 1},
.filter_dilation = {.width = 1, .height = 1},
.input_left_pad = {.width = 0, .height = 0},
.input_right_pad = {.width = 0, .height = 0},
.a_elementwise_op = {},
.b_elementwise_op = {},
.cde_elementwise_op = {},
};
auto inputs = ckt::alloc_inputs(Args);
auto outputs = ckt::alloc_outputs(Args);
auto reference = ckt::alloc_outputs(Args);
auto inputs = ckt::alloc_inputs(args);
auto outputs = ckt::alloc_outputs(args);
auto reference = ckt::alloc_outputs(args);
ckt::init_inputs(args, inputs.get());
ckt::init_inputs(Args, inputs.get());
auto conv = Instance{};
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
EXPECT_THAT(ckt::run(conv, Args, inputs.get(), outputs.get()), SuccessfulRun());
auto ref_conv = Reference{};
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
EXPECT_THAT(ckt::run(ref_conv, Args, inputs.get(), reference.get()), SuccessfulRun());
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
EXPECT_THAT(outputs.get(), MatchesReference(Args, reference.get()));
}
// Runs the two-stage bwd-weight kernel pair on initialized inputs and checks
// the produced weights against the reference convolution on the same data.
TEST(BwdWeight_TwoStage_2D_FP16_NHWGC, Execution)
{
    auto input_buffers     = ckt::alloc_inputs(Args);
    auto result_buffers    = ckt::alloc_outputs(Args);
    auto reference_buffers = ckt::alloc_outputs(Args);
    ckt::init_inputs(Args, input_buffers.get());

    // Stage 1 (conv into workspace) followed by stage 2 (elementwise conversion).
    auto two_stage_conv = TwoStageInstance{};
    auto conversion_op  = ElementwiseOpInstance{};
    EXPECT_THAT(
        ckt::run(two_stage_conv, conversion_op, Args, input_buffers.get(), result_buffers.get()),
        SuccessfulRun());

    // Reference result computed from the same inputs.
    auto reference_conv = Reference{};
    EXPECT_THAT(ckt::run(reference_conv, Args, input_buffers.get(), reference_buffers.get()),
                SuccessfulRun());

    EXPECT_THAT(result_buffers.get(), MatchesReference(Args, reference_buffers.get()));
}

View File

@@ -24,8 +24,10 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1
.with_tile_thread_block(TileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(TileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
.with_tile_optimizations(TileOptimizations{.num_groups_to_merge = 1,
.split_image = false,
.explicit_gemm = false,
.two_stage = false});
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_ck_tile_test<Builder>({

View File

@@ -31,8 +31,10 @@ constexpr auto ALGORITHM =
.with_tile_thread_block(cku::FwdTileThreadBlock_64x64x64)
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(cku::FwdTileTransfer_4x4x4)
.with_tile_optimizations(ckt::TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
.with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1,
.split_image = false,
.explicit_gemm = false,
.two_stage = false});
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;

View File

@@ -377,6 +377,8 @@ struct TileOptimizations
bool split_image;
// Explicit gemm for 1x1, stride=0, pad=0 cases
bool explicit_gemm;
// Two-stage kernels
bool two_stage;
};
static_assert(ckb::TileOptimizationsDescriptor<TileOptimizations>);

View File

@@ -13,6 +13,7 @@ class ConvInstanceTemplateParams:
warp_tile,
double_smem_buffer,
num_wave_groups,
is_two_stage_instance,
pipeline_version,
scheduler,
scalar_per_vector,
@@ -27,6 +28,7 @@ class ConvInstanceTemplateParams:
self.warp_tile = warp_tile
self.double_smem_buffer = double_smem_buffer
self.num_wave_groups = num_wave_groups
self.is_two_stage_instance = is_two_stage_instance
self.pipeline_version = pipeline_version
self.scheduler = scheduler
self.scalar_per_vector = scalar_per_vector
@@ -39,7 +41,8 @@ class ConvInstanceTemplateParams:
explicit_gemm = "true" if self.explicit_gemm else "false"
split_image = "true" if self.split_image else "false"
num_groups_to_merge = str(self.num_groups_to_merge)
return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}}}"
two_stage_instance = "true" if self.is_two_stage_instance else "false"
return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}, .two_stage = {two_stage_instance}}}"
def get_specialization(self):
namespace = "ckb::TileConvSpecialization::"
@@ -270,6 +273,8 @@ def parse_fwd_instances(instances, problem_name):
print(f"Skipping instance {instance_id} with ASYNC_V4 since it's not supported yet.")
continue
is_two_stage = False
conv = ConvInstanceTemplateParams(
spec,
[m_per_block, n_per_block, k_per_block],
@@ -277,6 +282,7 @@ def parse_fwd_instances(instances, problem_name):
[m_per_xdl, n_per_xdl, k_per_xdl],
double_smem_buffer,
num_wave_groups,
is_two_stage,
pipeline_version,
scheduler,
[a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],
@@ -343,7 +349,7 @@ def parse_bwd_weight_instances(instances, problem_name):
num_groups_to_merge = 1
# Block GEMM pipeline parameters
blk_gemm_pipeline_schduler = args[6]
block_gemm_pipeline_scheduler = args[6]
blk_gemm_pipeline_version = args[7]
else:
spec = args[11]
@@ -372,20 +378,29 @@ def parse_bwd_weight_instances(instances, problem_name):
num_groups_to_merge = int(args[44])
# Block GEMM pipeline parameters
blk_gemm_pipeline_schduler = args[39]
block_gemm_pipeline_scheduler = args[39]
blk_gemm_pipeline_version = args[40]
elif is_two_stage_instance:
print(f"Skipping instance {instance_id} with device op {device_op_name} since it's not supported yet.")
continue
if len(args) != 46:
raise RuntimeError(f"Wrong number of parameters in the TwoStage instance string: {instance}\n" +
f"Expected 46 parameters for TwoStage instance. Found {len(args)} parameters.")
num_groups_to_merge = args[41]
# Block GEMM pipeline parameters
block_gemm_pipeline_scheduler = args[39]
blk_gemm_pipeline_version = args[40]
else:
# Regular V1 XDL CShuffle instance
if len(args) != 43:
raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}")
raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}\n" +
f"Expected 43 parameters for V1 instance. Found {len(args)} parameters.")
num_groups_to_merge = 1
# Block GEMM pipeline parameters
blk_gemm_pipeline_schduler = "Intrawave"
block_gemm_pipeline_scheduler = "Intrawave"
blk_gemm_pipeline_version = "v1"
# Common part to all solvers.
@@ -393,15 +408,15 @@ def parse_bwd_weight_instances(instances, problem_name):
# Sanity check for Block GEMM pipeline parameters
# Scheduler must be either Intrawave or Interwave.
# Version must be from v1 to v5
if blk_gemm_pipeline_schduler not in ["Intrawave", "Interwave"]:
raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {blk_gemm_pipeline_schduler} in instance: {instance}")
if block_gemm_pipeline_scheduler not in ["Intrawave", "Interwave"]:
raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {block_gemm_pipeline_scheduler} in instance: {instance}")
if blk_gemm_pipeline_version not in ["v1", "v2", "v3", "v4", "v5"]:
raise RuntimeError(f"Invalid Block GEMM pipeline version: {blk_gemm_pipeline_version} in instance: {instance}")
split_image = instance.find("Large") != -1
double_smem_buffer = blk_gemm_pipeline_version == "v4"
num_wave_groups = 1
scheduler = blk_gemm_pipeline_schduler
scheduler = block_gemm_pipeline_scheduler
pipeline_version = blk_gemm_pipeline_version.upper()
Old CK pipeline version V5 maps to V6 for CK Tile
@@ -428,6 +443,7 @@ def parse_bwd_weight_instances(instances, problem_name):
[m_per_xdl, n_per_xdl, k_per_xdl],
double_smem_buffer,
num_wave_groups,
is_two_stage_instance,
pipeline_version,
scheduler,
[a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector],

View File

@@ -6,6 +6,7 @@
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
namespace ckf = ck_tile::builder::factory;
namespace ck_tile::builder::profiling {

View File

@@ -1,7 +1,21 @@
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using ConvInstance = Builder::Instance;
auto conv = ConvInstance{};
auto result = [&]<auto Sig, auto Alg>() {
if constexpr(ConvDirectionIsBackwardWeight<Sig> && Alg.optimizations.two_stage)
{
using ElementwiseOpBuilder = ckf::ElementwiseOpTileFactory<Sig, Alg>;
using ElementwiseOpInstance = ElementwiseOpBuilder::Instance;
auto elementwise_op = ElementwiseOpInstance{};
return ckt::run(conv, elementwise_op, args, inputs, outputs, s_conf);
}
else
{
return ckt::run(conv, args, inputs, outputs, s_conf);
}
}.template operator()<SIGNATURE, ALGORITHM>();
auto conv = Instance{};
ckt::RunResult result = ckt::run(conv, args, inputs, outputs, s_conf);
return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString());

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
@@ -108,6 +109,19 @@ struct ElementWiseKernel
ignore = input_sizes;
return true;
}
/// @brief Builds the kernel's identifier string: "elementwise_kernel", the
/// problem description, and the policy name, joined by '_'.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    // clang-format off
    return concat('_', "elementwise_kernel",
                  Problem::GetName(),
                  "policy",
                  Policy::GetName()
                  );
    // clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
};
} // namespace ck_tile

View File

@@ -24,6 +24,11 @@ struct ElementWiseDefaultPolicy
sequence<0, 3>>{} // Yield
);
}
/// @brief Policy identifier used as a component of the kernel name.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    return "ElementWiseDefaultPolicy";
}
};
} // namespace ck_tile

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -21,6 +22,19 @@ struct ElementWisePipelineProblem
using BlockShape = remove_cvref_t<BlockShape_>;
using ElementWiseOperation = remove_cvref_t<ElementWiseOperation_>;
static constexpr bool kPad = kPad_;
/// @brief Problem identifier: the block-shape name, the elementwise
/// operation's name, and the kPad flag, joined by '_'. Used as a component
/// of the kernel name.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    // clang-format off
    return concat('_',
                  BlockShape::GetName(),
                  "op",
                  ElementWiseOperation::name,
                  "kPad",
                  kPad
                  );
    // clang-format on
}
};
} // namespace ck_tile

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -25,6 +26,15 @@ struct ElementWiseShape
static constexpr index_t kBlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
/// @brief Shape identifier: "shape" followed by the tile/warp/vector extents
/// and the derived block size, joined by '_'.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    // clang-format off
    return concat('_', "shape",
                  kBlockM, kWarpM, kVectorM, kWarpPerBlockM, kThreadPerWarpM, kRepeatM, kBlockSize
                  );
    // clang-format on
}
};
} // namespace ck_tile