mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
[rocm-libraries] ROCm/rocm-libraries#5516 (commit ff3afda)
[CK_TILE, CK_BUILDER] Add bwd data to CK Tile profiler (#5516) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation We want close the performance gap between old CK and CK Tile for bwd data convolutions. To achieve this, we need tow things - Configurations for the old CK kernel instances such that we can map them into CK Tile instances. - Support in CK profiler to run the CK Tile instance with the same API as for old CK instances. ## Technical Details Extracted kernel configurations from old CK. The codegen python script for CK Tile convs is extended to support also bwd data. The generated instances are added to the CMake build (target `device_grouped_conv_bwd_data_tile_instances`). A new profiler op (`grouped_conv_bwd_data_tile`) has been added to the CK Profiler. The API is same as for old CK's profiler op `grouped_conv_bwd_data`.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
1834e318da
commit
ec2dbfbfde
@@ -296,45 +296,45 @@ struct InstanceTraits<
|
||||
oss << ","
|
||||
<< detail::conv_bwd_data_spec_name(
|
||||
kConvBwdDataSpecialization); // 14. ConvBackwardDataSpecialization
|
||||
oss << "," << kDoPadGemmM;
|
||||
oss << "," << kDoPadGemmN;
|
||||
oss << "," << kNumGemmKPrefetchStage;
|
||||
oss << "," << kBlockSize; // 15. BlockSize
|
||||
oss << "," << kMPerBlock; // 16. MPerBlock
|
||||
oss << "," << kNPerBlock; // 17. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 18. K0PerBlock
|
||||
oss << "," << kAK1; // 19. AK1
|
||||
oss << "," << kBK1; // 19. BK1
|
||||
oss << "," << kMPerXDL; // 20. MPerXDL
|
||||
oss << "," << kNPerXDL; // 21. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 22. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 23. NXdlPerWave
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 24.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 27.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 28.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 29.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 31.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 32.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 33.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 34.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 35.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 38.
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 39.
|
||||
oss << "," << kDoPadGemmM; // 15. GEMM padding for M dimension
|
||||
oss << "," << kDoPadGemmN; // 16. GEMM padding for N dimension
|
||||
oss << "," << kNumGemmKPrefetchStage; // 17. Number of GEMM K prefetch stages
|
||||
oss << "," << kBlockSize; // 18. BlockSize
|
||||
oss << "," << kMPerBlock; // 19. MPerBlock
|
||||
oss << "," << kNPerBlock; // 20. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 21. K0PerBlock
|
||||
oss << "," << kAK1; // 22. AK1
|
||||
oss << "," << kBK1; // 23. BK1
|
||||
oss << "," << kMPerXDL; // 24. MPerXDL
|
||||
oss << "," << kNPerXDL; // 25. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 26. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 27. NXdlPerWave
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 28.
|
||||
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 29.
|
||||
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 30.
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 31.
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 32.
|
||||
oss << "," << kABlockTransferDstScalarPerVectorK1; // 33.
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 34.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 35.
|
||||
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 36.
|
||||
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 37.
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 38.
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 39.
|
||||
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 40.
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 41.
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42.
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43.
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 40.
|
||||
oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 42.
|
||||
oss << "," << kNumGemmKPrefetchStage; // 41.
|
||||
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 43. LoopSched
|
||||
oss << "," << detail::type_name<ComputeTypeA>(); // 44.
|
||||
oss << "," << detail::type_name<ComputeTypeB>(); // 45.
|
||||
oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 46.
|
||||
oss << "," << kMaxTransposeTransferDstScalarPerVector; // 47.
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 44.
|
||||
oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 45.
|
||||
oss << "," << kNumGemmKPrefetchStage; // 46.
|
||||
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 47. LoopSched
|
||||
oss << "," << detail::type_name<ComputeTypeA>(); // 48.
|
||||
oss << "," << detail::type_name<ComputeTypeB>(); // 49.
|
||||
oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 50.
|
||||
oss << "," << kMaxTransposeTransferDstScalarPerVector; // 51.
|
||||
|
||||
oss << ">";
|
||||
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/tensor_initialization.hpp"
|
||||
#include "ck_tile/builder/testing/testing_reflect.hpp"
|
||||
#include "ck_tile/builder/testing/conv/args.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/error.hpp"
|
||||
|
||||
/// This file deals with the backward data-specific details of running grouped
|
||||
/// convolution backwards data operations. It mainly defines the data
|
||||
/// structures (`Input` and `Output`), initialization, and validation. Note
|
||||
/// that for this operation specifically, many of the operations are
|
||||
/// implemented automatically via testing_reflect.hpp.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
/// @brief `Inputs` specialization for backwards data convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Backwards data convolution signature.
|
||||
///
|
||||
/// @see Inputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardData<SIGNATURE>
|
||||
struct Inputs<SIGNATURE>
|
||||
{
|
||||
void* weight;
|
||||
void* output;
|
||||
|
||||
// See testing_reflect.hpp
|
||||
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("weight", args.make_weight_descriptor(), &Inputs<SIGNATURE>::weight);
|
||||
inspect("output", args.make_output_descriptor(), &Inputs<SIGNATURE>::output);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `Outputs` specialization for backwards data convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Backward data convolution signature.
|
||||
///
|
||||
/// @see Outputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardData<SIGNATURE>
|
||||
struct Outputs<SIGNATURE>
|
||||
{
|
||||
void* input;
|
||||
|
||||
// See testing_reflect.hpp
|
||||
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("input", args.make_input_descriptor(), &Outputs<SIGNATURE>::input);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `init_inputs()` specialization for backwards convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Backward data convolution signature.
|
||||
///
|
||||
/// @see init_inputs()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardData<SIGNATURE>
|
||||
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
|
||||
{
|
||||
init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f);
|
||||
init_tensor_buffer_uniform_fp(inputs.output, args.make_output_descriptor(), -2.0f, 2.0f);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
|
||||
#include "ck_tile/builder/testing/conv/bwd_data.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
@@ -35,6 +36,29 @@ concept CkTileConvInstance = requires(Conv&) {
|
||||
{ Conv::BlockSize() };
|
||||
};
|
||||
|
||||
template <auto SIGNATURE>
|
||||
std::size_t gemm_split_k_output_size(auto kargs)
|
||||
{
|
||||
std::size_t zeroing_size = 0;
|
||||
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
zeroing_size = std::accumulate(std::begin(kargs.wei_g_k_c_xs_lengths.data),
|
||||
std::end(kargs.wei_g_k_c_xs_lengths.data),
|
||||
1,
|
||||
std::multiplies<std::size_t>());
|
||||
}
|
||||
|
||||
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
zeroing_size = std::accumulate(std::begin(kargs.in_g_n_c_wis_lengths.data),
|
||||
std::end(kargs.in_g_n_c_wis_lengths.data),
|
||||
1,
|
||||
std::multiplies<std::size_t>());
|
||||
}
|
||||
|
||||
return zeroing_size;
|
||||
}
|
||||
|
||||
template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
@@ -58,10 +82,8 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
|
||||
return RunResult::not_supported("unsupported ck_tile arguments");
|
||||
|
||||
using Types = ck_tile::builder::factory::internal::TileConvTensorTypes<SIGNATURE.data_type>;
|
||||
const std::size_t zeroing_size = std::accumulate(std::begin(kargs.wei_g_k_c_xs_lengths.data),
|
||||
std::end(kargs.wei_g_k_c_xs_lengths.data),
|
||||
1,
|
||||
std::multiplies<std::size_t>());
|
||||
|
||||
const std::size_t zeroing_size = gemm_split_k_output_size<SIGNATURE>(kargs);
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
@@ -75,6 +97,18 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
|
||||
s_conf.stream_id_));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(kargs.in_ptr,
|
||||
0,
|
||||
zeroing_size * sizeof(typename Types::EDataType),
|
||||
s_conf.stream_id_));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
@@ -293,4 +327,26 @@ template <auto SIGNATURE>
|
||||
s_conf);
|
||||
}
|
||||
|
||||
/// @brief `run()` specialization for backwards data convolution and CK Tile.
|
||||
///
|
||||
/// @tparam SIGNATURE Backward data convolution signature.
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ConvDirectionIsBackwardData<SIGNATURE>
|
||||
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs,
|
||||
const ck_tile::stream_config s_conf = {})
|
||||
{
|
||||
return detail::run(conv,
|
||||
args,
|
||||
static_cast<void*>(outputs.input),
|
||||
static_cast<const void*>(inputs.weight),
|
||||
static_cast<const void*>(inputs.output),
|
||||
s_conf);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -134,4 +134,26 @@ template <auto SIGNATURE>
|
||||
return detail::run(conv, args, inputs.input, outputs.weight, inputs.output);
|
||||
}
|
||||
|
||||
/// @brief Concept for checking whether this is the reference convolution
|
||||
/// backward data implementation.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept RefConvBwdDataInstance =
|
||||
detail::RefConvInstance<Conv, SIGNATURE, void*, const void*, const void*> &&
|
||||
ConvDirectionIsBackwardData<SIGNATURE>;
|
||||
|
||||
/// @brief `run()` specialization for the reference backward data implementation.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature of the operation to perform. Must be backwards data.
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
[[nodiscard]] RunResult run(RefConvBwdDataInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
{
|
||||
return detail::run(conv, args, outputs.input, inputs.weight, inputs.output);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
Reference in New Issue
Block a user