mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5030 (commit 8e02a26)
[CK] Replace tuple value construction with tuple_element_t type extraction [1A] (#5030) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary ### Rationale CK's device operation instance registration uses `add_device_operation_instances` at ~1,850 call sites to register GPU kernel configurations. The existing implementation constructs `std::tuple` values just to extract their types via `decltype`, then copy-constructs each instance into `make_unique`. This is wasteful — only the types matter, not the values — and forces the compiler to instantiate the full `std::tuple` constructor and `std::get` machinery at every call site. ### What changed - Replace `remove_cvref_t<decltype(std::get<i>(tuple_obj))>` with `std::tuple_element_t<i.value, TupleType>`, which extracts the type directly without constructing any values - Replace copy-from-default `make_unique<T>(value)` with direct default construction `make_unique<T>()` — all CK device operation instances are stateless structs with configuration encoded in template parameters - Add `static_assert(std::is_default_constructible_v<NewOpInstance>)` to enforce this contract at compile time with a clear error message - Add Doxygen documentation for this high-traffic public API ### Value - Eliminates unnecessary template instantiation of `std::tuple` constructors and `std::get` across ~1,850 call sites - Establishes a cleaner, more intention-revealing pattern for type-only tuple usage - The `static_assert` prevents silent breakage if a non-default-constructible type is ever added - No runtime behavior change — zero risk ### Files changed (9) - `add_device_operation_instance.hpp`: Core pattern change - 3 example files, 3 reduce instance headers, 1 convolution header, 1 profiler header ## Test plan - [ ] Existing CI tests cover all ~1,850 call sites (GEMM, reduce, softmax, convolution) - [ ] `static_assert` provides compile-time validation stronger than runtime tests - [ ] No runtime behavior change — stateless struct default construction is identical to copy-from-default - [ ] Compatible with both `std::tuple` and `ck::type_list` containers 🤖 Generated with [Claude Code](https://claude.com/claude-code) ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
e2ce0cad54
commit
e0d11b969b
@@ -129,13 +129,11 @@ bool reduce_blockwise_test(bool do_verification,
|
||||
bool matched = false;
|
||||
int result = 0;
|
||||
|
||||
const auto tuple_object = reduce_shape_instances{};
|
||||
|
||||
static_for<0, std::tuple_size<reduce_shape_instances>::value, 1>{}([&](auto i) {
|
||||
if(matched)
|
||||
return;
|
||||
|
||||
using ShapeType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;
|
||||
using ShapeType = std::tuple_element_t<i.value, reduce_shape_instances>;
|
||||
|
||||
if(ShapeType::Rank_ != inLengths.size() || ShapeType::NumReduceDim_ != reduceDims.size())
|
||||
return;
|
||||
|
||||
@@ -127,13 +127,11 @@ bool reduce_multiblock_atomic_add_test(bool do_verification,
|
||||
bool matched = false;
|
||||
int result = 0;
|
||||
|
||||
const auto tuple_object = reduce_shape_instances{};
|
||||
|
||||
static_for<0, std::tuple_size<reduce_shape_instances>::value, 1>{}([&](auto i) {
|
||||
if(matched)
|
||||
return;
|
||||
|
||||
using ShapeType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;
|
||||
using ShapeType = std::tuple_element_t<i.value, reduce_shape_instances>;
|
||||
|
||||
if(ShapeType::Rank_ != inLengths.size() || ShapeType::NumReduceDim_ != reduceDims.size())
|
||||
return;
|
||||
|
||||
@@ -129,13 +129,11 @@ bool reduce_threadwise_multi_d_test(bool do_verification,
|
||||
bool matched = false;
|
||||
int result = 0;
|
||||
|
||||
const auto tuple_object = reduce_shape_instances{};
|
||||
|
||||
static_for<0, std::tuple_size<reduce_shape_instances>::value, 1>{}([&](auto i) {
|
||||
if(matched)
|
||||
return;
|
||||
|
||||
using ShapeType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;
|
||||
using ShapeType = std::tuple_element_t<i.value, reduce_shape_instances>;
|
||||
|
||||
if(ShapeType::Rank_ != inLengths.size() || ShapeType::NumReduceDim_ != reduceDims.size())
|
||||
return;
|
||||
|
||||
@@ -14,14 +14,18 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
/**
|
||||
* @brief Register device operation instances from a type container.
|
||||
* @tparam BaseOp The base class that all operation instances must derive from.
|
||||
* @tparam NewOpInstances A std::tuple (or ck::type_list) of device operation types.
|
||||
* Only the type is used; the parameter value is unused (retained for type deduction).
|
||||
*/
|
||||
template <typename BaseOp, typename NewOpInstances>
|
||||
void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_instances,
|
||||
const NewOpInstances& new_op_instances)
|
||||
const NewOpInstances& /*new_op_instances*/)
|
||||
{
|
||||
ck::static_for<0, std::tuple_size_v<NewOpInstances>, 1>{}([&](auto i) {
|
||||
const auto new_op_instance = std::get<i>(new_op_instances);
|
||||
|
||||
using NewOpInstance = remove_cvref_t<decltype(new_op_instance)>;
|
||||
using NewOpInstance = std::tuple_element_t<i.value, NewOpInstances>;
|
||||
if constexpr(std::is_same_v<NewOpInstance, std::nullptr_t>)
|
||||
{
|
||||
return; // We can use nullptr_t to enable trailing comma
|
||||
@@ -29,8 +33,13 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins
|
||||
else
|
||||
{
|
||||
static_assert(std::is_base_of_v<BaseOp, NewOpInstance>,
|
||||
"wrong! NewOpInstance should be derived from BaseOp");
|
||||
op_instances.push_back(std::make_unique<NewOpInstance>(new_op_instance));
|
||||
"add_device_operation_instances: NewOpInstance must derive from BaseOp");
|
||||
static_assert(
|
||||
std::is_default_constructible_v<NewOpInstance>,
|
||||
"add_device_operation_instances: NewOpInstance must be default-constructible; "
|
||||
"registration default-constructs instances and ignores tuple values, so store "
|
||||
"configuration in template parameters instead of constructor arguments.");
|
||||
op_instances.push_back(std::make_unique<NewOpInstance>());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -45,9 +45,11 @@ void add_explicit_gemm_device_operation_instances(
|
||||
DeviceGemmOp>;
|
||||
|
||||
static_assert(std::is_base_of_v<BaseOp, NewOpInstance>,
|
||||
"wrong! NewOpInstance should be derived from BaseOp");
|
||||
"NewOpInstance must derive from BaseOp");
|
||||
static_assert(std::is_default_constructible_v<NewOpInstance>,
|
||||
"NewOpInstance must be default-constructible");
|
||||
|
||||
op_instances.push_back(std::make_unique<NewOpInstance>(NewOpInstance{}));
|
||||
op_instances.push_back(std::make_unique<NewOpInstance>());
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -89,13 +89,12 @@ void add_device_reduce_instance_blockwise(
|
||||
{
|
||||
static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
|
||||
[&](auto i) {
|
||||
using cfg1 = remove_cvref_t<decltype(std::get<i.value>(
|
||||
reduce_configuration_1_instances_blockwise{}))>;
|
||||
using cfg1 = std::tuple_element_t<i.value, reduce_configuration_1_instances_blockwise>;
|
||||
|
||||
static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
|
||||
[&](auto j) {
|
||||
using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
|
||||
reduce_configuration_2_instances_blockwise{}))>;
|
||||
using cfg2 =
|
||||
std::tuple_element_t<j.value, reduce_configuration_2_instances_blockwise>;
|
||||
|
||||
using ReduceOpInstance =
|
||||
DeviceReduceMultiBlock<InDataType,
|
||||
@@ -119,8 +118,7 @@ void add_device_reduce_instance_blockwise(
|
||||
cfg2::InSrcVectorSize_,
|
||||
cfg2::OutDstVectorSize_>;
|
||||
|
||||
device_op_instances.push_back(
|
||||
std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
|
||||
device_op_instances.push_back(std::make_unique<ReduceOpInstance>());
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
@@ -90,14 +90,15 @@ void add_device_reduce_instance_multiblock_atomic_add(
|
||||
static_for<0,
|
||||
std::tuple_size<reduce_configuration_1_instances_multiblock_atomic_add>::value,
|
||||
1>{}([&](auto i) {
|
||||
using cfg1 = remove_cvref_t<decltype(std::get<i.value>(
|
||||
reduce_configuration_1_instances_multiblock_atomic_add{}))>;
|
||||
using cfg1 =
|
||||
std::tuple_element_t<i.value, reduce_configuration_1_instances_multiblock_atomic_add>;
|
||||
|
||||
static_for<0,
|
||||
std::tuple_size<reduce_configuration_2_instances_multiblock_atomic_add>::value,
|
||||
1>{}([&](auto j) {
|
||||
using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
|
||||
reduce_configuration_2_instances_multiblock_atomic_add{}))>;
|
||||
using cfg2 =
|
||||
std::tuple_element_t<j.value,
|
||||
reduce_configuration_2_instances_multiblock_atomic_add>;
|
||||
|
||||
using ReduceOpInstance = DeviceReduceMultiBlock<InDataType,
|
||||
AccDataType,
|
||||
@@ -120,7 +121,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
|
||||
cfg2::InSrcVectorSize_,
|
||||
cfg2::OutDstVectorSize_>;
|
||||
|
||||
device_op_instances.push_back(std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
|
||||
device_op_instances.push_back(std::make_unique<ReduceOpInstance>());
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
@@ -77,8 +77,7 @@ void add_device_reduce_instance_threadwise(
|
||||
|
||||
static_for<0, std::tuple_size<reduce_configuration_2_instances_threadwise>::value, 1>{}(
|
||||
[&](auto j) {
|
||||
using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
|
||||
reduce_configuration_2_instances_threadwise{}))>;
|
||||
using cfg2 = std::tuple_element_t<j.value, reduce_configuration_2_instances_threadwise>;
|
||||
|
||||
using ReduceOpInstance = DeviceReduceThreadWise<InDataType,
|
||||
AccDataType,
|
||||
@@ -99,7 +98,7 @@ void add_device_reduce_instance_threadwise(
|
||||
cfg2::InSrcVectorSize_,
|
||||
cfg2::OutDstVectorSize_>;
|
||||
|
||||
device_op_instances.push_back(std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
|
||||
device_op_instances.push_back(std::make_unique<ReduceOpInstance>());
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -488,13 +488,11 @@ bool profile_reduce_impl(bool do_verification,
|
||||
using tuple_of_description_instances =
|
||||
tensor_operation::device::instance::reduce_description_instances;
|
||||
|
||||
const auto tuple_object = tuple_of_description_instances{};
|
||||
|
||||
static_for<0, std::tuple_size<tuple_of_description_instances>::value, 1>{}([&](auto i) {
|
||||
if(matched)
|
||||
return;
|
||||
|
||||
using descType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;
|
||||
using descType = std::tuple_element_t<i.value, tuple_of_description_instances>;
|
||||
|
||||
if(!description_match(
|
||||
descType{}, inLengths.size(), reduceDims, ReduceOpId, PropagateNan, UseIndex))
|
||||
|
||||
Reference in New Issue
Block a user