mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Add generate_identity_sequences helper and replace lambdas with named functors (#4828)
## Summary
- Add `generate_identity_sequences<N>()` helper that returns
`Tuple<Sequence<0>, Sequence<1>, ..., Sequence<N-1>>`
- Replace lambdas with named functors in `transform_tensor_descriptor`
- Add `unpack_and_merge_sequences` helper functor
- Reduces `transform_tensor_descriptor` instantiations from 388 to 32
(92% reduction)
## Motivation
Multiple call sites use `generate_tuple([](auto i) { return
Sequence<i>{}; }, Number<N>{})` pattern. A named helper reduces lambda
instantiations.
Additionally, each lambda in `transform_tensor_descriptor` creates a
unique closure type, causing the function to be instantiated separately
for every call site. Named functors share a single type, so the compiler
reuses the same instantiation.
## Changes
### Part 1: generate_identity_sequences helper
- Replaces common lambda pattern for generating identity sequences
- Each lambda expression creates a unique closure type, causing separate
template instantiations at every call site
- Named helper shares a single type across all uses
### Part 2: Named functors in transform_tensor_descriptor
- Add `unpack_and_merge_sequences` helper to replace lambda in
`GetNumOfHiddenDimension`
- Use `generate_identity_sequences` in `matrix_padder.hpp`
## Test Plan
- [x] Added 7 unit tests:
- 4 tests for `generate_identity_sequences`
- 3 tests for `unpack_and_merge_sequences`
- [ ] Waiting for full CI
## Related PRs
This PR merges the functionality from:
- ROCm/composable_kernel#3588 (generate_identity_sequences helper)
- ROCm/composable_kernel#3589 (Named functors in
transform_tensor_descriptor)
Part of PR stack for issue #4229 (Reduce CK/CKTile Build Times)
**Note:** This PR supersedes #4283, ROCm/composable_kernel#3588 and
ROCm/composable_kernel#3589, which can be closed once this is merged.
---
🔁 Imported from
[ROCm/composable_kernel#3628](https://github.com/ROCm/composable_kernel/pull/3628)
🧑💻 Originally authored by @tenpercent
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -739,8 +739,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
@@ -894,8 +894,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
@@ -944,8 +943,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
@@ -993,8 +991,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
@@ -833,8 +833,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
@@ -892,8 +891,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
@@ -692,8 +692,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
@@ -744,8 +743,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
@@ -514,8 +514,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
@@ -563,8 +562,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
@@ -657,8 +657,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
@@ -707,8 +706,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
@@ -548,8 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
@@ -598,8 +597,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user