[rocm-libraries] ROCm/rocm-libraries#4295 (commit fa2cfc8)

[CK_TILE] Refactor `UniversalGemm::MakeA/B/C/DBlockViews` to
 allow the caller to pass descriptors directly (#4295)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Currently `UniversalGemmKernel::MakeA/B/C/DBlockViews` directly create
tensor views from strides and sizes. This change refactors the descriptor
creation out and adds overloaded definitions, allowing descriptors to be
created separately by the caller instead of passing explicit strides,
with no functional changes.

This will enable further refactoring of `RunGemm` to do likewise,
enabling derived kernels like BatchedContractionKernel to avoid creating
separate versions (PR
[#3457](https://github.com/ROCm/composable_kernel/pull/3457)).

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [x] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which helps the maintainers
understand the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

Since the logic within the MakeXBlockWindows functions chains together
operations on tuples, and the descriptors are therefore also passed as
tuples, adding a template parameter for the type of the input tuple was the
simplest option to enable the overload without too much verbiage. However,
for `MakeCBlockWindows` this adds a complication, as the templated
definitions are prone to overlap. For now this is worked around by
reordering the arguments in the descriptor version, which avoids the
collision. It's not a great solution, so feel free to suggest a better one.
This commit is contained in:
Matti Eskelinen
2026-02-24 20:44:27 +00:00
committed by assistant-librarian[bot]
parent 4c626aeaa6
commit cd12e8e31f

View File

@@ -689,39 +689,150 @@ struct UniversalGemmKernel
return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
}
template <typename ALayout>
CK_TILE_DEVICE static auto
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t i_m)
MakeDefaultATensorDescriptor(const index_t M, const index_t stride, const index_t k_size)
{
// Step 1: Create tensor views for A tensors (from MakeGemmTensorViews)
const auto& as_tensor_view = generate_tuple(
[&](auto i) {
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(make_tuple(M, k_size),
make_tuple(stride, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_descriptor(make_tuple(k_size, M),
make_tuple(stride, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}
template <typename BLayout>
CK_TILE_DEVICE static auto MakeDefaultBTensorDescriptor(const index_t N,
const index_t K,
const index_t stride,
const index_t k_size)
{
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1),
make_tuple(N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
return transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_descriptor(make_tuple(k_size, N),
make_tuple(stride, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
else
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1),
make_tuple(N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
return transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
}
else
{
if constexpr(GemmPipeline::Preshuffle)
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const AiDataType*>(as_ptr[i]),
make_tuple(kargs.M, k_size),
make_tuple(kargs.stride_As[i], 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
index_t kFlatK =
GemmPipeline::BlockGemmShape::flatKPerWarp *
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = N * K / kFlatK;
return make_naive_tensor_descriptor(make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const AiDataType*>(as_ptr[i]),
make_tuple(k_size, kargs.M),
make_tuple(kargs.stride_As[i], 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
return make_naive_tensor_descriptor(make_tuple(N, k_size),
make_tuple(stride, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}
}
/// @brief Builds the default naive tensor descriptor for a D (auxiliary
///        input) tensor.
///
/// Both layouts share the same stride pattern; only the order of the two
/// lengths differs: (M, N) for row-major, (N, M) for column-major.
///
/// @tparam DLayout     gemm tensor layout tag of this D tensor
/// @tparam VectorSizeD compile-time vector access size for this D tensor
/// @param M      number of rows of the GEMM output
/// @param N      number of columns of the GEMM output
/// @param stride leading-dimension stride of the D tensor
template <typename DLayout, index_t VectorSizeD>
CK_TILE_DEVICE static auto
MakeDefaultDTensorDescriptor(const index_t M, const index_t N, const index_t stride)
{
    constexpr bool kIsRowMajor = std::is_same_v<DLayout, tensor_layout::gemm::RowMajor>;
    // The descriptor type is identical for both layouts, so the lengths can
    // be selected at run time instead of duplicating the call per branch.
    const index_t len0 = kIsRowMajor ? M : N;
    const index_t len1 = kIsRowMajor ? N : M;
    return make_naive_tensor_descriptor(
        make_tuple(len0, len1), make_tuple(stride, 1), number<VectorSizeD>{}, number<1>{});
}
/// @brief Builds the default naive tensor descriptor for the E (output)
///        tensor.
///
/// Note that for the column-major case the lengths stay (M, N) and only the
/// strides are swapped; the vector access size also drops to 1 there.
///
/// @param M      number of rows of the GEMM output
/// @param N      number of columns of the GEMM output
/// @param stride leading-dimension stride of the E tensor
CK_TILE_DEVICE static auto
MakeDefaultETensorDescriptor(const index_t M, const index_t N, const index_t stride)
{
    if constexpr(!std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
    {
        // Column-major: unit stride on the row dimension, scalar accesses.
        return make_naive_tensor_descriptor(
            make_tuple(M, N), make_tuple(1, stride), number<1>{}, number<1>{});
    }
    else
    {
        // Row-major: unit stride on the column dimension, vectorized accesses.
        return make_naive_tensor_descriptor(make_tuple(M, N),
                                            make_tuple(stride, 1),
                                            number<EpiloguePipeline::GetVectorSizeC()>{},
                                            number<1>{});
    }
}
template <typename AsTensorDesc>
CK_TILE_DEVICE static auto
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
const AsTensorDesc& as_desc,
const index_t i_m)
{
// Step 1: Create tensor views
const auto& as_tensor_view = generate_tuple(
[&](auto i) {
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
return make_tensor_view<address_space_enum::global>(
static_cast<const AiDataType*>(as_ptr[i]), as_desc[i]);
},
number<NumATensor>{});
// Step 2: Create padded views (from MakeGemmPadViews)
// Step 2: Create padded views
const auto& as_pad_view = generate_tuple(
[&](auto i) {
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
@@ -742,7 +853,7 @@ struct UniversalGemmKernel
},
number<NumATensor>{});
// Step 3: Create tile windows (from MakeGemmTileWindows)
// Step 3: Create tile windows
const auto& as_block_window = generate_tuple(
[&](auto i) {
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
@@ -767,101 +878,38 @@ struct UniversalGemmKernel
}
CK_TILE_DEVICE static auto
MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t i_m)
{
// Step 1: Create tensor descriptors for A tensors
const auto& as_tensor_desc = generate_tuple(
[&](auto i) {
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
return MakeDefaultATensorDescriptor<AiLayout>(kargs.M, kargs.stride_As[i], k_size);
},
number<NumATensor>{});
return MakeABlockWindows(as_ptr, as_tensor_desc, i_m);
}
template <typename BsTensorDesc>
CK_TILE_DEVICE static auto
MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
const BsTensorDesc& bs_desc,
const index_t i_n)
{
// Step 1: Create tensor views for B tensors (from MakeGemmTensorViews)
// Step 1: Create tensor views
const auto& bs_tensor_view = generate_tuple(
[&](auto i) {
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
constexpr index_t VectorSizeB =
std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
bs_ptr[i],
make_tuple(k_size, kargs.N),
make_tuple(kargs.stride_Bs[i], 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
else
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
constexpr index_t VectorSizeB =
std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tensor_view<address_space_enum::global>(
static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
}
else
{
if constexpr(GemmPipeline::Preshuffle)
{
index_t kFlatK =
GemmPipeline::BlockGemmShape::flatKPerWarp *
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
bs_ptr[i],
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
bs_ptr[i],
make_tuple(kargs.N, k_size),
make_tuple(kargs.stride_Bs[i], 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}
return make_tensor_view<address_space_enum::global>(
static_cast<const BiDataType*>(bs_ptr[i]), bs_desc[i]);
},
number<NumBTensor>{});
// Step 2: Create padded views (from MakeGemmPadViews)
// Step 2: Create padded views
const auto& bs_pad_view = generate_tuple(
[&](auto i) {
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
@@ -882,7 +930,7 @@ struct UniversalGemmKernel
},
number<NumBTensor>{});
// Step 3: Create tile windows (from MakeGemmTileWindows)
// Step 3: Create tile windows
const auto& bs_block_window = generate_tuple(
[&](auto i) {
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
@@ -918,38 +966,39 @@ struct UniversalGemmKernel
return bs_block_window;
}
/// @brief Convenience overload that derives the default B tensor descriptors
///        from the kernel arguments and then delegates to the
///        descriptor-based MakeBBlockWindows overload.
///
/// @param bs_ptr  global pointers to the B tensors
/// @param kargs   kernel arguments providing N, K, and the B strides
/// @param k_size  size of the K dimension covered by this call
/// @param i_n     block offset along the N dimension
CK_TILE_DEVICE static auto
MakeBBlockWindows(const std::array<const BDataType*, NumBTensor>& bs_ptr,
                  const KernelArgs& kargs,
                  const index_t k_size,
                  const index_t i_n)
{
    // One default descriptor per B tensor, each typed by its own layout.
    auto make_desc = [&](auto idx) {
        using BiLayout = remove_cvref_t<std::tuple_element_t<idx.value, BsLayout>>;
        return MakeDefaultBTensorDescriptor<BiLayout>(
            kargs.N, kargs.K, kargs.stride_Bs[idx], k_size);
    };
    const auto& default_bs_desc = generate_tuple(make_desc, number<NumBTensor>{});
    return MakeBBlockWindows(bs_ptr, default_bs_desc, i_n);
}
template <typename DsTensorDesc>
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const KernelArgs& kargs,
const DsTensorDesc& ds_desc,
const index_t i_m,
const index_t i_n)
{
// Step 1: Create tensor views for D tensors (from MakeGemmTensorViews)
// Step 1: Create tensor views
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
return make_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]), ds_desc[i]);
},
number<NumDTensor>{});
// Step 2: Create padded views (from MakeGemmPadViews)
// Step 2: Create padded views
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
@@ -970,7 +1019,7 @@ struct UniversalGemmKernel
},
number<NumDTensor>{});
// Step 3: Create tile windows (from MakeGemmTileWindows)
// Step 3: Create tile windows
const auto& ds_block_window = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
@@ -994,35 +1043,34 @@ struct UniversalGemmKernel
return ds_block_window;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const KernelArgs& kargs,
const index_t i_m,
const index_t i_n)
{
// Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews)
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_E, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_E),
number<1>{},
number<1>{});
}
}();
const auto& ds_tensor_desc = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
return MakeDefaultDTensorDescriptor<DiLayout, EpiloguePipeline::GetVectorSizeD(i)>(
kargs.M, kargs.N, kargs.stride_Ds[i]);
},
number<NumDTensor>{});
// Step 2: Create padded view (from MakeGemmPadViews)
return MakeDBlockWindows(ds_ptr, ds_tensor_desc, i_m, i_n);
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename ETensorDesc>
CK_TILE_DEVICE static auto MakeCBlockWindows(
EDataType* e_ptr,
const index_t i_m,
const index_t i_n,
const ETensorDesc& e_desc) // Argument order differs from A,B,D to disambiguate overloads
{
// Step 1: Create tensor view for E/C tensor
const auto& e_tensor_view =
make_tensor_view<address_space_enum::global, DstInMemOp>(e_ptr, e_desc);
// Step 2: Create padded view
const auto& e_pad_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
@@ -1040,7 +1088,7 @@ struct UniversalGemmKernel
}
}();
// Step 3: Create tile window (from MakeGemmTileWindows)
// Step 3: Create tile window
auto e_block_window = make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
@@ -1049,6 +1097,17 @@ struct UniversalGemmKernel
return e_block_window;
}
/// @brief Convenience overload that builds the default E tensor descriptor
///        from the kernel arguments and forwards to the descriptor-based
///        MakeCBlockWindows overload (whose descriptor argument is trailing
///        to disambiguate the two overloads).
///
/// @tparam DstInMemOp memory operation used when writing the output
/// @param e_ptr global pointer to the E (output) tensor
/// @param kargs kernel arguments providing M, N, and the E stride
/// @param i_m   block offset along the M dimension
/// @param i_n   block offset along the N dimension
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
                                             const KernelArgs& kargs,
                                             const index_t i_m,
                                             const index_t i_n)
{
    return MakeCBlockWindows<DstInMemOp>(
        e_ptr, i_m, i_n, MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E));
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*