[rocm-libraries] ROCm/rocm-libraries#5516 (commit ff3afda)

[CK_TILE, CK_BUILDER] Add bwd data to CK Tile profiler
 (#5516)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

We want close the performance gap between old CK and CK Tile for bwd
data convolutions. To achieve this, we need tow things

- Configurations for the old CK kernel instances such that we can map
them into CK Tile instances.
- Support in CK profiler to run the CK Tile instance with the same API
as for old CK instances.

## Technical Details

Extracted kernel configurations from old CK. The codegen python script
for CK Tile convs is extended to support also bwd data. The generated
instances are added to the CMake build (target
`device_grouped_conv_bwd_data_tile_instances`).
A new profiler op (`grouped_conv_bwd_data_tile`) has been added to the
CK Profiler. The API is same as for old CK's profiler op
`grouped_conv_bwd_data`.
This commit is contained in:
Ville Pietilä
2026-03-25 14:36:11 +00:00
committed by assistant-librarian[bot]
parent 1834e318da
commit ec2dbfbfde
29 changed files with 1588 additions and 956 deletions

View File

@@ -296,45 +296,45 @@ struct InstanceTraits<
oss << ","
<< detail::conv_bwd_data_spec_name(
kConvBwdDataSpecialization); // 14. ConvBackwardDataSpecialization
oss << "," << kDoPadGemmM;
oss << "," << kDoPadGemmN;
oss << "," << kNumGemmKPrefetchStage;
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kK0PerBlock; // 18. K0PerBlock
oss << "," << kAK1; // 19. AK1
oss << "," << kBK1; // 19. BK1
oss << "," << kMPerXDL; // 20. MPerXDL
oss << "," << kNPerXDL; // 21. NPerXDL
oss << "," << kMXdlPerWave; // 22. MXdlPerWave
oss << "," << kNXdlPerWave; // 23. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 24.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
oss << "," << kABlockTransferSrcVectorDim; // 27.
oss << "," << kABlockTransferSrcScalarPerVector; // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 29.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 31.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 32.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 33.
oss << "," << kBBlockTransferSrcVectorDim; // 34.
oss << "," << kBBlockTransferSrcScalarPerVector; // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 38.
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 39.
oss << "," << kDoPadGemmM; // 15. GEMM padding for M dimension
oss << "," << kDoPadGemmN; // 16. GEMM padding for N dimension
oss << "," << kNumGemmKPrefetchStage; // 17. Number of GEMM K prefetch stages
oss << "," << kBlockSize; // 18. BlockSize
oss << "," << kMPerBlock; // 19. MPerBlock
oss << "," << kNPerBlock; // 20. NPerBlock
oss << "," << kK0PerBlock; // 21. K0PerBlock
oss << "," << kAK1; // 22. AK1
oss << "," << kBK1; // 23. BK1
oss << "," << kMPerXDL; // 24. MPerXDL
oss << "," << kNPerXDL; // 25. NPerXDL
oss << "," << kMXdlPerWave; // 26. MXdlPerWave
oss << "," << kNXdlPerWave; // 27. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 28.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 29.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 30.
oss << "," << kABlockTransferSrcVectorDim; // 31.
oss << "," << kABlockTransferSrcScalarPerVector; // 32.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 33.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 34.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 35.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 36.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 37.
oss << "," << kBBlockTransferSrcVectorDim; // 38.
oss << "," << kBBlockTransferSrcScalarPerVector; // 39.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 40.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 41.
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42.
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43.
oss << ","
<< detail::sequence_name<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40.
oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 42.
oss << "," << kNumGemmKPrefetchStage; // 41.
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 43. LoopSched
oss << "," << detail::type_name<ComputeTypeA>(); // 44.
oss << "," << detail::type_name<ComputeTypeB>(); // 45.
oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 46.
oss << "," << kMaxTransposeTransferDstScalarPerVector; // 47.
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 44.
oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 45.
oss << "," << kNumGemmKPrefetchStage; // 46.
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 47. LoopSched
oss << "," << detail::type_name<ComputeTypeA>(); // 48.
oss << "," << detail::type_name<ComputeTypeB>(); // 49.
oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 50.
oss << "," << kMaxTransposeTransferDstScalarPerVector; // 51.
oss << ">";

View File

@@ -0,0 +1,71 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/testing/tensor_initialization.hpp"
#include "ck_tile/builder/testing/testing_reflect.hpp"
#include "ck_tile/builder/testing/conv/args.hpp"
#include "ck_tile/builder/testing/conv/fwd.hpp"
#include "ck_tile/builder/testing/error.hpp"
/// This file deals with the backward data-specific details of running grouped
/// convolution backwards data operations. It mainly defines the data
/// structures (`Input` and `Output`), initialization, and validation. Note
/// that for this operation specifically, many of the operations are
/// implemented automatically via testing_reflect.hpp.
namespace ck_tile::builder::test {
/// @brief `Inputs` specialization for backwards data convolution.
///
/// @tparam SIGNATURE Backwards data convolution signature.
///
/// @see Inputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardData<SIGNATURE>
struct Inputs<SIGNATURE>
{
void* weight;
void* output;
// See testing_reflect.hpp
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
{
inspect("weight", args.make_weight_descriptor(), &Inputs<SIGNATURE>::weight);
inspect("output", args.make_output_descriptor(), &Inputs<SIGNATURE>::output);
}
};
/// @brief `Outputs` specialization for backwards data convolution.
///
/// @tparam SIGNATURE Backward data convolution signature.
///
/// @see Outputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardData<SIGNATURE>
struct Outputs<SIGNATURE>
{
void* input;
// See testing_reflect.hpp
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
{
inspect("input", args.make_input_descriptor(), &Outputs<SIGNATURE>::input);
}
};
/// @brief `init_inputs()` specialization for backwards convolution.
///
/// @tparam SIGNATURE Backward data convolution signature.
///
/// @see init_inputs()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardData<SIGNATURE>
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
{
init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f);
init_tensor_buffer_uniform_fp(inputs.output, args.make_output_descriptor(), -2.0f, 2.0f);
}
} // namespace ck_tile::builder::test

View File

@@ -6,6 +6,7 @@
#include "ck_tile/builder/testing/testing.hpp"
#include "ck_tile/builder/testing/conv/fwd.hpp"
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
#include "ck_tile/builder/testing/conv/bwd_data.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
@@ -35,6 +36,29 @@ concept CkTileConvInstance = requires(Conv&) {
{ Conv::BlockSize() };
};
template <auto SIGNATURE>
std::size_t gemm_split_k_output_size(auto kargs)
{
std::size_t zeroing_size = 0;
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
zeroing_size = std::accumulate(std::begin(kargs.wei_g_k_c_xs_lengths.data),
std::end(kargs.wei_g_k_c_xs_lengths.data),
1,
std::multiplies<std::size_t>());
}
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
{
zeroing_size = std::accumulate(std::begin(kargs.in_g_n_c_wis_lengths.data),
std::end(kargs.in_g_n_c_wis_lengths.data),
1,
std::multiplies<std::size_t>());
}
return zeroing_size;
}
template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename OutDataType>
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
const Args<SIGNATURE>& args,
@@ -58,10 +82,8 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
return RunResult::not_supported("unsupported ck_tile arguments");
using Types = ck_tile::builder::factory::internal::TileConvTensorTypes<SIGNATURE.data_type>;
const std::size_t zeroing_size = std::accumulate(std::begin(kargs.wei_g_k_c_xs_lengths.data),
std::end(kargs.wei_g_k_c_xs_lengths.data),
1,
std::multiplies<std::size_t>());
const std::size_t zeroing_size = gemm_split_k_output_size<SIGNATURE>(kargs);
auto preprocess = [&]() {
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
@@ -75,6 +97,18 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
s_conf.stream_id_));
}
}
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
{
if(kargs.k_batch > 1)
{
ck_tile::hip_check_error(
hipMemsetAsync(kargs.in_ptr,
0,
zeroing_size * sizeof(typename Types::EDataType),
s_conf.stream_id_));
}
}
};
constexpr index_t minimum_occupancy =
@@ -293,4 +327,26 @@ template <auto SIGNATURE>
s_conf);
}
/// @brief `run()` specialization for backwards data convolution and CK Tile.
///
/// @tparam SIGNATURE Backward data convolution signature.
/// @returns RunResult about how the operation completed (or not).
///
/// @see run()
template <auto SIGNATURE>
requires ConvDirectionIsBackwardData<SIGNATURE>
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs,
const ck_tile::stream_config s_conf = {})
{
return detail::run(conv,
args,
static_cast<void*>(outputs.input),
static_cast<const void*>(inputs.weight),
static_cast<const void*>(inputs.output),
s_conf);
}
} // namespace ck_tile::builder::test

View File

@@ -134,4 +134,26 @@ template <auto SIGNATURE>
return detail::run(conv, args, inputs.input, outputs.weight, inputs.output);
}
/// @brief Concept for checking whether this is the reference convolution
/// backward data implementation.
template <typename Conv, auto SIGNATURE>
concept RefConvBwdDataInstance =
detail::RefConvInstance<Conv, SIGNATURE, void*, const void*, const void*> &&
ConvDirectionIsBackwardData<SIGNATURE>;
/// @brief `run()` specialization for the reference backward data implementation.
///
/// @tparam SIGNATURE The signature of the operation to perform. Must be backwards data.
/// @returns RunResult about how the operation completed (or not).
///
/// @see run()
template <auto SIGNATURE>
[[nodiscard]] RunResult run(RefConvBwdDataInstance<SIGNATURE> auto& conv,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs)
{
return detail::run(conv, args, outputs.input, inputs.weight, inputs.output);
}
} // namespace ck_tile::builder::test