mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK Tile profiler (#5237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation PR #4797 added CK Tile bwd weight kernels to the CK Profiler. The two-stage kernels were not supported in the initial PR. This PR adds the missing bwd weight two-stage kernels to the CK Profiler. ## Technical Details Extended the CK Tile conv builder factory to also build the elementwise ops required for the two-stage kernels. Extended the CK Builder for CK Tile instance to accept the two-stage flag as part of the algorithm configuration. ## Test Plan Added unit tests for CK Builder that verify the two-stage kernel construction. ## Test Result If CI passes, the added unit tests are passing. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
22 lines
780 B
C++
22 lines
780 B
C++
|
|
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
|
using ConvInstance = Builder::Instance;
|
|
|
|
auto conv = ConvInstance{};
|
|
|
|
auto result = [&]<auto Sig, auto Alg>() {
|
|
if constexpr(ConvDirectionIsBackwardWeight<Sig> && Alg.optimizations.two_stage)
|
|
{
|
|
using ElementwiseOpBuilder = ckf::ElementwiseOpTileFactory<Sig, Alg>;
|
|
using ElementwiseOpInstance = ElementwiseOpBuilder::Instance;
|
|
auto elementwise_op = ElementwiseOpInstance{};
|
|
return ckt::run(conv, elementwise_op, args, inputs, outputs, s_conf);
|
|
}
|
|
else
|
|
{
|
|
return ckt::run(conv, args, inputs, outputs, s_conf);
|
|
}
|
|
}.template operator()<SIGNATURE, ALGORITHM>();
|
|
|
|
return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString());
|