[rocm-libraries] ROCm/rocm-libraries#5030 (commit 8e02a26)

[CK] Replace tuple value construction with tuple_element_t type extraction [1A] (#5030)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

### Rationale
CK's device operation instance registration uses
`add_device_operation_instances` at ~1,850 call sites to register GPU
kernel configurations. The existing implementation constructs
`std::tuple` values just to extract their types via `decltype`, then
copy-constructs each instance through `make_unique`. This is wasteful —
only the types matter, not the values — and forces the compiler to
instantiate the full `std::tuple` constructor and `std::get` machinery
at every call site.

### What changed
- Replace `remove_cvref_t<decltype(std::get<i>(tuple_obj))>` with
`std::tuple_element_t<i.value, TupleType>`, which extracts the type
directly without constructing any values
- Replace copy-from-default `make_unique<T>(value)` with direct default
construction `make_unique<T>()` — all CK device operation instances are
stateless structs with configuration encoded in template parameters
- Add `static_assert(std::is_default_constructible_v<NewOpInstance>)` to
enforce this contract at compile time with a clear error message
- Add Doxygen documentation for this high-traffic public API

### Value
- Eliminates unnecessary template instantiation of `std::tuple`
constructors and `std::get` across ~1,850 call sites
- Establishes a cleaner, more intention-revealing pattern for type-only
tuple usage
- The `static_assert` prevents silent breakage if a
non-default-constructible type is ever added
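- No runtime behavior change for stateless instance types, which is what
all current CK device operations are; values stored in the tuple are no
longer read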

### Files changed (9)
- `add_device_operation_instance.hpp`: Core pattern change
- 3 example files, 3 reduce instance headers, 1 convolution header, 1
profiler header

## Test plan
- [ ] Existing CI tests cover all ~1,850 call sites (GEMM, reduce,
softmax, convolution)
- [ ] `static_assert` rejects any non-default-constructible instance type
at compile time, before any runtime test is run
- [ ] No runtime behavior change — stateless struct default construction
is identical to copy-from-default
- [ ] Compatible with both `std::tuple` and `ck::type_list` containers

🤖 Generated with [Claude Code](https://claude.com/claude-code)

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Christopher Millette
2026-03-06 16:28:22 +00:00
committed by assistant-librarian[bot]
parent e2ce0cad54
commit e0d11b969b
9 changed files with 35 additions and 34 deletions

View File

@@ -129,13 +129,11 @@ bool reduce_blockwise_test(bool do_verification,
bool matched = false;
int result = 0;
const auto tuple_object = reduce_shape_instances{};
static_for<0, std::tuple_size<reduce_shape_instances>::value, 1>{}([&](auto i) {
if(matched)
return;
using ShapeType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;
using ShapeType = std::tuple_element_t<i.value, reduce_shape_instances>;
if(ShapeType::Rank_ != inLengths.size() || ShapeType::NumReduceDim_ != reduceDims.size())
return;

View File

@@ -127,13 +127,11 @@ bool reduce_multiblock_atomic_add_test(bool do_verification,
bool matched = false;
int result = 0;
const auto tuple_object = reduce_shape_instances{};
static_for<0, std::tuple_size<reduce_shape_instances>::value, 1>{}([&](auto i) {
if(matched)
return;
using ShapeType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;
using ShapeType = std::tuple_element_t<i.value, reduce_shape_instances>;
if(ShapeType::Rank_ != inLengths.size() || ShapeType::NumReduceDim_ != reduceDims.size())
return;

View File

@@ -129,13 +129,11 @@ bool reduce_threadwise_multi_d_test(bool do_verification,
bool matched = false;
int result = 0;
const auto tuple_object = reduce_shape_instances{};
static_for<0, std::tuple_size<reduce_shape_instances>::value, 1>{}([&](auto i) {
if(matched)
return;
using ShapeType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;
using ShapeType = std::tuple_element_t<i.value, reduce_shape_instances>;
if(ShapeType::Rank_ != inLengths.size() || ShapeType::NumReduceDim_ != reduceDims.size())
return;

View File

@@ -14,14 +14,18 @@ namespace tensor_operation {
namespace device {
namespace instance {
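/**
* @brief Register device operation instances from a type container.
*
* Appends one default-constructed instance of every element type of NewOpInstances to
* op_instances. Elements of type std::nullptr_t are skipped.
*
* @tparam BaseOp The base class that all operation instances must derive from.
* @tparam NewOpInstances A std::tuple (or ck::type_list) of device operation types.
* @param op_instances Vector that receives the newly created instances.
* @note The second (unnamed) argument is never read; it is retained only so that
* NewOpInstances can be deduced at the call site.
*/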
template <typename BaseOp, typename NewOpInstances>
void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_instances,
const NewOpInstances& new_op_instances)
const NewOpInstances& /*new_op_instances*/)
{
ck::static_for<0, std::tuple_size_v<NewOpInstances>, 1>{}([&](auto i) {
const auto new_op_instance = std::get<i>(new_op_instances);
using NewOpInstance = remove_cvref_t<decltype(new_op_instance)>;
using NewOpInstance = std::tuple_element_t<i.value, NewOpInstances>;
if constexpr(std::is_same_v<NewOpInstance, std::nullptr_t>)
{
return; // We can use nullptr_t to enable trailing comma
@@ -29,8 +33,13 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins
else
{
static_assert(std::is_base_of_v<BaseOp, NewOpInstance>,
"wrong! NewOpInstance should be derived from BaseOp");
op_instances.push_back(std::make_unique<NewOpInstance>(new_op_instance));
"add_device_operation_instances: NewOpInstance must derive from BaseOp");
static_assert(
std::is_default_constructible_v<NewOpInstance>,
"add_device_operation_instances: NewOpInstance must be default-constructible; "
"registration default-constructs instances and ignores tuple values, so store "
"configuration in template parameters instead of constructor arguments.");
op_instances.push_back(std::make_unique<NewOpInstance>());
}
});
}

View File

@@ -45,9 +45,11 @@ void add_explicit_gemm_device_operation_instances(
DeviceGemmOp>;
static_assert(std::is_base_of_v<BaseOp, NewOpInstance>,
"wrong! NewOpInstance should be derived from BaseOp");
"NewOpInstance must derive from BaseOp");
static_assert(std::is_default_constructible_v<NewOpInstance>,
"NewOpInstance must be default-constructible");
op_instances.push_back(std::make_unique<NewOpInstance>(NewOpInstance{}));
op_instances.push_back(std::make_unique<NewOpInstance>());
});
}

View File

@@ -89,13 +89,12 @@ void add_device_reduce_instance_blockwise(
{
static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
[&](auto i) {
using cfg1 = remove_cvref_t<decltype(std::get<i.value>(
reduce_configuration_1_instances_blockwise{}))>;
using cfg1 = std::tuple_element_t<i.value, reduce_configuration_1_instances_blockwise>;
static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
[&](auto j) {
using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
reduce_configuration_2_instances_blockwise{}))>;
using cfg2 =
std::tuple_element_t<j.value, reduce_configuration_2_instances_blockwise>;
using ReduceOpInstance =
DeviceReduceMultiBlock<InDataType,
@@ -119,8 +118,7 @@ void add_device_reduce_instance_blockwise(
cfg2::InSrcVectorSize_,
cfg2::OutDstVectorSize_>;
device_op_instances.push_back(
std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
device_op_instances.push_back(std::make_unique<ReduceOpInstance>());
});
});
};

View File

@@ -90,14 +90,15 @@ void add_device_reduce_instance_multiblock_atomic_add(
static_for<0,
std::tuple_size<reduce_configuration_1_instances_multiblock_atomic_add>::value,
1>{}([&](auto i) {
using cfg1 = remove_cvref_t<decltype(std::get<i.value>(
reduce_configuration_1_instances_multiblock_atomic_add{}))>;
using cfg1 =
std::tuple_element_t<i.value, reduce_configuration_1_instances_multiblock_atomic_add>;
static_for<0,
std::tuple_size<reduce_configuration_2_instances_multiblock_atomic_add>::value,
1>{}([&](auto j) {
using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
reduce_configuration_2_instances_multiblock_atomic_add{}))>;
using cfg2 =
std::tuple_element_t<j.value,
reduce_configuration_2_instances_multiblock_atomic_add>;
using ReduceOpInstance = DeviceReduceMultiBlock<InDataType,
AccDataType,
@@ -120,7 +121,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
cfg2::InSrcVectorSize_,
cfg2::OutDstVectorSize_>;
device_op_instances.push_back(std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
device_op_instances.push_back(std::make_unique<ReduceOpInstance>());
});
});
};

View File

@@ -77,8 +77,7 @@ void add_device_reduce_instance_threadwise(
static_for<0, std::tuple_size<reduce_configuration_2_instances_threadwise>::value, 1>{}(
[&](auto j) {
using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
reduce_configuration_2_instances_threadwise{}))>;
using cfg2 = std::tuple_element_t<j.value, reduce_configuration_2_instances_threadwise>;
using ReduceOpInstance = DeviceReduceThreadWise<InDataType,
AccDataType,
@@ -99,7 +98,7 @@ void add_device_reduce_instance_threadwise(
cfg2::InSrcVectorSize_,
cfg2::OutDstVectorSize_>;
device_op_instances.push_back(std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
device_op_instances.push_back(std::make_unique<ReduceOpInstance>());
});
};

View File

@@ -488,13 +488,11 @@ bool profile_reduce_impl(bool do_verification,
using tuple_of_description_instances =
tensor_operation::device::instance::reduce_description_instances;
const auto tuple_object = tuple_of_description_instances{};
static_for<0, std::tuple_size<tuple_of_description_instances>::value, 1>{}([&](auto i) {
if(matched)
return;
using descType = remove_cvref_t<decltype(std::get<i>(tuple_object))>;
using descType = std::tuple_element_t<i.value, tuple_of_description_instances>;
if(!description_match(
descType{}, inLengths.size(), reduceDims, ReduceOpId, PropagateNan, UseIndex))