mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5237 (commit ef10dc6)
[CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK Tile profiler (#5237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation PR #4797 added CK Tile bwd weight kernels to the CK Profiler. The two-stage kernels were not supported in the initial PR. This PR adds the the missing bwd weight two-stage kernels to the CK Profiler. ## Technical Details Extended the CK Tile conv builder factory to build also the elementwise ops required for the two-stage kernels. Extended the CK Builder for CK Tile instance to accept the two-stage flag as part of the algorithm configuration. ## Test Plan Added units tests for CK Builder that verify the two-stage kernel construction. ## Test Result If CI passes, the added unit tests are passing. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fc2f95620d
commit
e2f5ab8000
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
|
||||
@@ -108,6 +109,19 @@ struct ElementWiseKernel
|
||||
ignore = input_sizes;
|
||||
return true;
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "elementwise_kernel",
|
||||
Problem::GetName(),
|
||||
"policy",
|
||||
Policy::GetName()
|
||||
);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -24,6 +24,11 @@ struct ElementWiseDefaultPolicy
|
||||
sequence<0, 3>>{} // Yield
|
||||
);
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
return "ElementWiseDefaultPolicy";
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -21,6 +22,19 @@ struct ElementWisePipelineProblem
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using ElementWiseOperation = remove_cvref_t<ElementWiseOperation_>;
|
||||
static constexpr bool kPad = kPad_;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_',
|
||||
BlockShape::GetName(),
|
||||
"op",
|
||||
ElementWiseOperation::name,
|
||||
"kPad",
|
||||
kPad
|
||||
);
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -25,6 +26,15 @@ struct ElementWiseShape
|
||||
|
||||
static constexpr index_t kBlockSize =
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "shape",
|
||||
kBlockM, kWarpM, kVectorM, kWarpPerBlockM, kThreadPerWarpM, kRepeatM, kBlockSize
|
||||
);
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user