[rocm-libraries] ROCm/rocm-libraries#5237 (commit ef10dc6)

[CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK
 Tile profiler (#5237)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

PR #4797 added CK Tile bwd weight kernels to the CK Profiler. The
two-stage kernels were not supported in the initial PR. This PR adds the
missing bwd weight two-stage kernels to the CK Profiler.

## Technical Details

Extended the CK Tile conv builder factory to also build the elementwise
ops required for the two-stage kernels. Extended the CK Builder for CK
Tile instance to accept the two-stage flag as part of the algorithm
configuration.

## Test Plan

Added unit tests for CK Builder that verify the two-stage kernel
construction.

## Test Result

If CI passes, the added unit tests are passing.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Ville Pietilä
2026-03-13 01:21:08 +00:00
committed by assistant-librarian[bot]
parent fc2f95620d
commit e2f5ab8000
16 changed files with 336 additions and 50 deletions

View File

@@ -155,6 +155,7 @@ concept TileOptimizationsDescriptor = requires(T t) {
{ t.num_groups_to_merge } -> std::convertible_to<int>;
{ t.split_image } -> std::convertible_to<bool>;
{ t.explicit_gemm } -> std::convertible_to<bool>;
{ t.two_stage } -> std::convertible_to<bool>;
};
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
@@ -295,6 +296,7 @@ concept SpecifiesTileOptimizations = requires {
{ T::optimizations.num_groups_to_merge } -> std::convertible_to<int>;
{ T::optimizations.split_image } -> std::convertible_to<bool>;
{ T::optimizations.explicit_gemm } -> std::convertible_to<bool>;
{ T::optimizations.two_stage } -> std::convertible_to<bool>;
};
template <typename T>

View File

@@ -8,6 +8,8 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/builder/versions.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
@@ -68,6 +70,10 @@ struct ConvTileFactory
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using ConvOutDataType = std::conditional_t<OPTIMIZATIONS.two_stage,
typename Types::AccDataType,
typename Types::EDataType>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
@@ -103,7 +109,7 @@ struct ConvTileFactory
typename Types::BDataType,
typename Types::DsDataTypes,
typename Types::AccDataType,
typename Types::EDataType,
ConvOutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
typename Ops::CDEElementwiseOp,
@@ -126,4 +132,33 @@ struct ConvTileFactory
ConvEpilogue>::Instance;
};
/// @brief Factory that builds the second-stage elementwise kernel used by
///        two-stage bwd-weight convolutions: it converts the f32 (AccDataType)
///        workspace produced by stage one into the final output data type.
///
/// @tparam SIGNATURE Convolution signature (supplies the tensor data types).
/// @tparam ALGORITHM Algorithm descriptor (supplies block/warp tiling).
/// @tparam VERSION   Builder API version string (defaults to latest).
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION = LATEST_API_VERSION>
struct ElementwiseOpTileFactory
{
// Thread-block and block-GEMM tiling derived from the algorithm descriptor;
// reused here so the elementwise kernel matches the conv kernel's tiling budget.
static constexpr auto BLOCK = internal::SetTileThreadBlockInfo<ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm<ALGORITHM>();
using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;
// Input/workspace are in accumulator precision; output is the signature's E type.
using XDataType = Types::AccDataType;
using WorkspaceDataType = Types::AccDataType;
// Plain type conversion (X -> Y), no scaling or activation.
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
using YDataType = Types::EDataType;
// The 2D conv tile is flattened into a 1D elementwise tile (m*n elements).
using BlockTile = ck_tile::sequence<BLOCK.per_block.m * BLOCK.per_block.n>;
using BlockWarps = ck_tile::sequence<BLOCK_GEMM.warps.m * BLOCK_GEMM.warps.n>;
using WarpTile = ck_tile::sequence<BLOCK_GEMM.warp_tile.m * BLOCK_GEMM.warp_tile.n>;
using ElementwiseShape =
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;
// Conversion from X -> Y.
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
WorkspaceDataType,
YDataType,
ElementwiseShape,
XElementwiseOperation>;
// The concrete kernel type produced by this factory.
using Instance = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
};
} // namespace ck_tile::builder::factory

View File

@@ -34,6 +34,7 @@ struct TileOptimizations
int num_groups_to_merge = 1;
bool split_image = false;
bool explicit_gemm = false;
bool two_stage = false;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
@@ -181,7 +182,8 @@ consteval TileOptimizations SetTileOptimizations()
return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge,
.split_image = OPT.split_image,
.explicit_gemm = OPT.explicit_gemm};
.explicit_gemm = OPT.explicit_gemm,
.two_stage = OPT.two_stage};
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -91,6 +91,100 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
}
}
/// @brief Two-stage bwd-weight run: stage one writes the conv result into an
///        f32 workspace, stage two converts the workspace into the final
///        weight-gradient buffer via the elementwise kernel.
///
/// @param conv           First-stage convolution kernel instance.
/// @param elementwise_op Second-stage conversion kernel instance.
/// @param args           Host-side convolution arguments (incl. k_batch).
/// @param input/weight/output  Device pointers; for bwd weight, `weight` is
///                       the destination of the final (converted) result.
/// @param s_conf         Stream configuration used for both launches.
/// @returns RunResult with timing, or not_supported if either kernel rejects
///          the arguments.
template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename OutDataType>
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
auto& elementwise_op,
const Args<SIGNATURE>& args,
InDataType* input,
WeiDataType* weight,
OutDataType* output,
const ck_tile::stream_config s_conf)
{
using Conv = std::remove_reference_t<decltype(conv)>;
using ElementwiseOp = std::remove_reference_t<decltype(elementwise_op)>;
// Workspace is in the elementwise op's compute precision (AccDataType per the
// factory); CDataType is the final output precision.
using WorkspaceDataType = typename ElementwiseOp::ComputeDataType;
using CDataType = typename ElementwiseOp::YDataType;
using BlockShape = typename ElementwiseOp::Problem::BlockShape;
const auto param = args.to_ck_tile_conv_param();
ck_tile::GroupedConvHostArgs<InDataType*, WeiDataType*, OutDataType*, ck_tile::PassThrough>
host_args(param, input, weight, {}, output, args.k_batch);
// Set-up for elementwise op kernel.
// Product of the filter's spatial dims (e.g. Y*X for 2D conv).
const ck_tile::index_t spatial_lengths_accum =
std::accumulate(host_args.filter_spatial_lengths_.begin(),
host_args.filter_spatial_lengths_.end(),
1,
std::multiplies<ck_tile::index_t>());
// Device workspace sized for the full weight-gradient tensor (G*K*C*spatial)
// in workspace precision.
ck_tile::DeviceMem ws_m_n_dev_buf(host_args.G_ * host_args.K_ * host_args.C_ *
spatial_lengths_accum * sizeof(WorkspaceDataType));
ck_tile::GroupedConvBwdWeightHostArgs ws_args =
ck_tile::GroupedConvBwdWeightHostArgs(host_args);
// Redirect stage one to write into the workspace; keep the caller's weight
// pointer (c_ptr) as the destination of the stage-two conversion.
auto c_ptr = ws_args.wei_ptr;
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
auto kargs = Conv::MakeKernelArgs(ws_args);
const dim3 grids = Conv::GridSize(kargs);
const dim3 blocks = Conv::BlockSize();
if(!Conv::IsSupportedArgument(kargs))
return RunResult::not_supported("unsupported ck_tile arguments");
// The elementwise kernel views the workspace as a 2D (G*K) x (C*spatial) tensor.
ck_tile::index_t total_elements = 1;
std::vector<ck_tile::index_t> shape = {
static_cast<ck_tile::index_t>(host_args.G_ * host_args.K_),
static_cast<ck_tile::index_t>(host_args.C_ * spatial_lengths_accum)};
for(auto d : shape)
total_elements *= d;
const ck_tile::index_t kBlockSize = ElementwiseOp::BlockSize();
constexpr ck_tile::index_t elements_per_block = BlockShape::kBlockM;
// Ceiling division: one grid block per elements_per_block chunk.
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
auto input_size = ck_tile::make_tuple(shape[0], shape[1]);
// Check if the kernel configuration is supported
if(!ElementwiseOp::IsSupportedArgument(input_size))
{
return RunResult::not_supported("unsupported ck_tile arguments for elementwise op");
}
// Zero the workspace before the conv launch when k_batch > 1 — presumably the
// split-K path accumulates into it (atomics/partial sums); TODO confirm against
// the conv kernel implementation.
auto preprocess = [&]() {
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
if(args.k_batch > 1)
{
ck_tile::hip_check_error(
hipMemsetAsync(ws_args.wei_ptr,
0,
shape[0] * shape[1] * sizeof(WorkspaceDataType),
s_conf.stream_id_));
}
}
};
// Occupancy hint depends on the pipeline scheduler (Intrawave -> 1, else 2).
constexpr index_t minimum_occupancy =
Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2;
// Launch both stages on the same stream and time them together:
// conv -> workspace, then workspace -> final weight buffer (c_ptr).
return RunResult::from_runtime(ck_tile::launch_kernel_time_mask(
s_conf,
preprocess,
ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs),
ck_tile::make_kernel<minimum_occupancy>(elementwise_op,
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<CDataType*>(c_ptr))));
}
} // namespace detail
/// @brief Concept for checking whether a convolution is invoked like CK Tile.
@@ -149,4 +243,28 @@ template <auto SIGNATURE>
s_conf);
}
/// @brief `run()` specialization for two-stage backwards weight convolution and CK Tile.
///
/// @tparam SIGNATURE Backwards weight convolution signature.
/// @param conv           First-stage convolution kernel instance.
/// @param elementwise_op Second-stage conversion kernel instance.
/// @param args           Host-side convolution arguments.
/// @param inputs         Input tensors; for bwd weight, `input` and `output`
///                       are both read (activation and output gradient).
/// @param outputs        Output tensors; `weight` receives the weight gradient.
/// @param s_conf         Stream configuration (defaults to a fresh config).
/// @returns RunResult about how the operation completed (or not).
///
/// @see run()
template <auto SIGNATURE>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
auto& elementwise_op,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs,
const ck_tile::stream_config s_conf = {})
{
// Forward as type-erased pointers; the detail overload deduces the pointee
// types (const void / void) from these casts.
return detail::run(conv,
elementwise_op,
args,
static_cast<const void*>(inputs.input),
static_cast<void*>(outputs.weight),
static_cast<const void*>(inputs.output),
s_conf);
}
} // namespace ck_tile::builder::test