mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
[rocm-libraries] ROCm/rocm-libraries#5237 (commit ef10dc6)
[CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK Tile profiler (#5237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation PR #4797 added CK Tile bwd weight kernels to the CK Profiler. The two-stage kernels were not supported in the initial PR. This PR adds the the missing bwd weight two-stage kernels to the CK Profiler. ## Technical Details Extended the CK Tile conv builder factory to build also the elementwise ops required for the two-stage kernels. Extended the CK Builder for CK Tile instance to accept the two-stage flag as part of the algorithm configuration. ## Test Plan Added units tests for CK Builder that verify the two-stage kernel construction. ## Test Result If CI passes, the added unit tests are passing. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fc2f95620d
commit
e2f5ab8000
@@ -155,6 +155,7 @@ concept TileOptimizationsDescriptor = requires(T t) {
|
||||
{ t.num_groups_to_merge } -> std::convertible_to<int>;
|
||||
{ t.split_image } -> std::convertible_to<bool>;
|
||||
{ t.explicit_gemm } -> std::convertible_to<bool>;
|
||||
{ t.two_stage } -> std::convertible_to<bool>;
|
||||
};
|
||||
|
||||
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
|
||||
@@ -295,6 +296,7 @@ concept SpecifiesTileOptimizations = requires {
|
||||
{ T::optimizations.num_groups_to_merge } -> std::convertible_to<int>;
|
||||
{ T::optimizations.split_image } -> std::convertible_to<bool>;
|
||||
{ T::optimizations.explicit_gemm } -> std::convertible_to<bool>;
|
||||
{ T::optimizations.two_stage } -> std::convertible_to<bool>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
@@ -68,6 +70,10 @@ struct ConvTileFactory
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using ConvOutDataType = std::conditional_t<OPTIMIZATIONS.two_stage,
|
||||
typename Types::AccDataType,
|
||||
typename Types::EDataType>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
@@ -103,7 +109,7 @@ struct ConvTileFactory
|
||||
typename Types::BDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::AccDataType,
|
||||
typename Types::EDataType,
|
||||
ConvOutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
@@ -126,4 +132,33 @@ struct ConvTileFactory
|
||||
ConvEpilogue>::Instance;
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION = LATEST_API_VERSION>
|
||||
struct ElementwiseOpTileFactory
|
||||
{
|
||||
static constexpr auto BLOCK = internal::SetTileThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm<ALGORITHM>();
|
||||
|
||||
using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;
|
||||
using XDataType = Types::AccDataType;
|
||||
using WorkspaceDataType = Types::AccDataType;
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using YDataType = Types::EDataType;
|
||||
using BlockTile = ck_tile::sequence<BLOCK.per_block.m * BLOCK.per_block.n>;
|
||||
using BlockWarps = ck_tile::sequence<BLOCK_GEMM.warps.m * BLOCK_GEMM.warps.n>;
|
||||
using WarpTile = ck_tile::sequence<BLOCK_GEMM.warp_tile.m * BLOCK_GEMM.warp_tile.n>;
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceDataType>;
|
||||
|
||||
// Conversion from X -> Y.
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
|
||||
WorkspaceDataType,
|
||||
YDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
|
||||
using Instance = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
|
||||
@@ -34,6 +34,7 @@ struct TileOptimizations
|
||||
int num_groups_to_merge = 1;
|
||||
bool split_image = false;
|
||||
bool explicit_gemm = false;
|
||||
bool two_stage = false;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
@@ -181,7 +182,8 @@ consteval TileOptimizations SetTileOptimizations()
|
||||
|
||||
return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge,
|
||||
.split_image = OPT.split_image,
|
||||
.explicit_gemm = OPT.explicit_gemm};
|
||||
.explicit_gemm = OPT.explicit_gemm,
|
||||
.two_stage = OPT.two_stage};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -91,6 +91,100 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
|
||||
}
|
||||
}
|
||||
|
||||
template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
auto& elementwise_op,
|
||||
const Args<SIGNATURE>& args,
|
||||
InDataType* input,
|
||||
WeiDataType* weight,
|
||||
OutDataType* output,
|
||||
const ck_tile::stream_config s_conf)
|
||||
{
|
||||
using Conv = std::remove_reference_t<decltype(conv)>;
|
||||
using ElementwiseOp = std::remove_reference_t<decltype(elementwise_op)>;
|
||||
using WorkspaceDataType = typename ElementwiseOp::ComputeDataType;
|
||||
using CDataType = typename ElementwiseOp::YDataType;
|
||||
using BlockShape = typename ElementwiseOp::Problem::BlockShape;
|
||||
|
||||
const auto param = args.to_ck_tile_conv_param();
|
||||
|
||||
ck_tile::GroupedConvHostArgs<InDataType*, WeiDataType*, OutDataType*, ck_tile::PassThrough>
|
||||
host_args(param, input, weight, {}, output, args.k_batch);
|
||||
|
||||
// Set-up for elementwise op kernel.
|
||||
const ck_tile::index_t spatial_lengths_accum =
|
||||
std::accumulate(host_args.filter_spatial_lengths_.begin(),
|
||||
host_args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(host_args.G_ * host_args.K_ * host_args.C_ *
|
||||
spatial_lengths_accum * sizeof(WorkspaceDataType));
|
||||
ck_tile::GroupedConvBwdWeightHostArgs ws_args =
|
||||
ck_tile::GroupedConvBwdWeightHostArgs(host_args);
|
||||
auto c_ptr = ws_args.wei_ptr;
|
||||
ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
|
||||
auto kargs = Conv::MakeKernelArgs(ws_args);
|
||||
const dim3 grids = Conv::GridSize(kargs);
|
||||
const dim3 blocks = Conv::BlockSize();
|
||||
|
||||
if(!Conv::IsSupportedArgument(kargs))
|
||||
return RunResult::not_supported("unsupported ck_tile arguments");
|
||||
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {
|
||||
static_cast<ck_tile::index_t>(host_args.G_ * host_args.K_),
|
||||
static_cast<ck_tile::index_t>(host_args.C_ * spatial_lengths_accum)};
|
||||
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
|
||||
const ck_tile::index_t kBlockSize = ElementwiseOp::BlockSize();
|
||||
|
||||
constexpr ck_tile::index_t elements_per_block = BlockShape::kBlockM;
|
||||
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
|
||||
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceDataType*>(ws_args.wei_ptr));
|
||||
auto input_size = ck_tile::make_tuple(shape[0], shape[1]);
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseOp::IsSupportedArgument(input_size))
|
||||
{
|
||||
return RunResult::not_supported("unsupported ck_tile arguments for elementwise op");
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
if(args.k_batch > 1)
|
||||
{
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(ws_args.wei_ptr,
|
||||
0,
|
||||
shape[0] * shape[1] * sizeof(WorkspaceDataType),
|
||||
s_conf.stream_id_));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
|
||||
return RunResult::from_runtime(ck_tile::launch_kernel_time_mask(
|
||||
s_conf,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs),
|
||||
ck_tile::make_kernel<minimum_occupancy>(elementwise_op,
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(shape[1], 1), // Input Stride
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<CDataType*>(c_ptr))));
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Concept for checking whether a convolution is invoked like CK Tile.
|
||||
@@ -149,4 +243,28 @@ template <auto SIGNATURE>
|
||||
s_conf);
|
||||
}
|
||||
|
||||
/// @brief `run()` specialization for two-stage backwards weight convolution and CK Tile.
|
||||
///
|
||||
/// @tparam SIGNATURE Backwards weight convolution signature.
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
auto& elementwise_op,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs,
|
||||
const ck_tile::stream_config s_conf = {})
|
||||
{
|
||||
return detail::run(conv,
|
||||
elementwise_op,
|
||||
args,
|
||||
static_cast<const void*>(inputs.input),
|
||||
static_cast<void*>(outputs.weight),
|
||||
static_cast<const void*>(inputs.output),
|
||||
s_conf);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
Reference in New Issue
Block a user