[rocm-libraries] ROCm/rocm-libraries#5237 (commit ef10dc6)

[CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK
 Tile profiler (#5237)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

PR #4797 added CK Tile bwd weight kernels to the CK Profiler. The
two-stage kernels were not supported in the initial PR. This PR adds the
the missing bwd weight two-stage kernels to the CK Profiler.

## Technical Details

Extended the CK Tile conv builder factory to build also the elementwise
ops required for the two-stage kernels. Extended the CK Builder for CK
Tile instance to accept the two-stage flag as part of the algorithm
configuration.

## Test Plan

Added units tests for CK Builder that verify the two-stage kernel
construction.

## Test Result

If CI passes, the added unit tests are passing.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Ville Pietilä
2026-03-13 01:21:08 +00:00
committed by assistant-librarian[bot]
parent fc2f95620d
commit e2f5ab8000
16 changed files with 336 additions and 50 deletions

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
@@ -108,6 +109,19 @@ struct ElementWiseKernel
ignore = input_sizes;
return true;
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "elementwise_kernel",
Problem::GetName(),
"policy",
Policy::GetName()
);
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
};
} // namespace ck_tile

View File

@@ -24,6 +24,11 @@ struct ElementWiseDefaultPolicy
sequence<0, 3>>{} // Yield
);
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
return "ElementWiseDefaultPolicy";
}
};
} // namespace ck_tile

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -21,6 +22,19 @@ struct ElementWisePipelineProblem
using BlockShape = remove_cvref_t<BlockShape_>;
using ElementWiseOperation = remove_cvref_t<ElementWiseOperation_>;
static constexpr bool kPad = kPad_;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_',
BlockShape::GetName(),
"op",
ElementWiseOperation::name,
"kPad",
kPad
);
// clang-format on
}
};
} // namespace ck_tile

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -25,6 +26,15 @@ struct ElementWiseShape
static constexpr index_t kBlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "shape",
kBlockM, kWarpM, kVectorM, kWarpPerBlockM, kThreadPerWarpM, kRepeatM, kBlockSize
);
// clang-format on
}
};
} // namespace ck_tile