Separate tensor descriptor creation from the tensor view creation

This adds utility functions to construct default tensor descriptors for
A, B, C and D tensors and refactors the Make{A,B,C,D}BlockWindows to call make_tensor_view
using the utility functions instead of directly calling
make_naive_tensor_view, allowing for further refactors later.
This commit is contained in:
Matti Eskelinen
2026-01-15 11:32:50 +00:00
parent 993d3e2f0e
commit 4e0fd5241a

View File

@@ -661,6 +661,136 @@ struct UniversalGemmKernel
return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
}
template <typename ALayout>
CK_TILE_DEVICE static auto
MakeDefaultATensorDescriptor(const index_t M, const index_t stride, const index_t k_size)
{
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(make_tuple(M, k_size),
make_tuple(stride, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_descriptor(make_tuple(k_size, M),
make_tuple(stride, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}
template <typename BLayout>
CK_TILE_DEVICE static auto MakeDefaultBTensorDescriptor(const index_t N,
const index_t K,
const index_t stride,
const index_t k_size)
{
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1),
make_tuple(N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
return transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_descriptor(make_tuple(k_size, N),
make_tuple(stride, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
else
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1),
make_tuple(N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
return transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
}
else
{
if constexpr(GemmPipeline::Preshuffle)
{
index_t kFlatK =
GemmPipeline::BlockGemmShape::flatKPerWarp *
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = N * K / kFlatK;
return make_naive_tensor_descriptor(make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
else
{
return make_naive_tensor_descriptor(make_tuple(N, k_size),
make_tuple(stride, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}
}
template <typename DLayout, index_t VectorSizeD>
CK_TILE_DEVICE static auto
MakeDefaultDTensorDescriptor(const index_t M, const index_t N, const index_t stride)
{
if constexpr(std::is_same_v<DLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(stride, 1), number<VectorSizeD>{}, number<1>{});
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, M), make_tuple(stride, 1), number<VectorSizeD>{}, number<1>{});
}
}
CK_TILE_DEVICE static auto
MakeDefaultETensorDescriptor(const index_t M, const index_t N, const index_t stride)
{
// TODO: enable vector write for C in ColMajor
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(M, N), // arguments not matching with flatmm.
make_tuple(stride, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(1, stride), number<1>{}, number<1>{});
}
}
CK_TILE_DEVICE static auto
MakeABlockWindows(const std::array<const ADataType*, NumATensor>& as_ptr,
const KernelArgs& kargs,
@@ -672,24 +802,10 @@ struct UniversalGemmKernel
[&](auto i) {
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const AiDataType*>(as_ptr[i]),
make_tuple(kargs.M, k_size),
make_tuple(kargs.stride_As[i], 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const AiDataType*>(as_ptr[i]),
make_tuple(k_size, kargs.M),
make_tuple(kargs.stride_As[i], 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
return make_tensor_view<address_space_enum::global>(
static_cast<const AiDataType*>(as_ptr[i]),
MakeDefaultATensorDescriptor<AiLayout>(kargs.M, kargs.stride_As[i], k_size));
},
number<NumATensor>{});
@@ -749,87 +865,10 @@ struct UniversalGemmKernel
[&](auto i) {
using BiLayout = remove_cvref_t<std::tuple_element_t<i.value, BsLayout>>;
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
constexpr index_t VectorSizeB =
std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
bs_ptr[i],
make_tuple(k_size, kargs.N),
make_tuple(kargs.stride_Bs[i], 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
else
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = k_size / K1;
constexpr index_t VectorSizeB =
std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tensor_view<address_space_enum::global>(
static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
}
else
{
if constexpr(GemmPipeline::Preshuffle)
{
index_t kFlatK =
GemmPipeline::BlockGemmShape::flatKPerWarp *
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
bs_ptr[i],
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
bs_ptr[i],
make_tuple(kargs.N, k_size),
make_tuple(kargs.stride_Bs[i], 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}
return make_tensor_view<address_space_enum::global>(
static_cast<const BiDataType*>(bs_ptr[i]),
MakeDefaultBTensorDescriptor<BiLayout>(
kargs.N, kargs.K, kargs.stride_Bs[i], k_size));
},
number<NumBTensor>{});
@@ -900,24 +939,10 @@ struct UniversalGemmKernel
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
return make_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
MakeDefaultDTensorDescriptor<DiLayout, EpiloguePipeline::GetVectorSizeD(i)>(
kargs.M, kargs.N, kargs.stride_Ds[i]));
},
number<NumDTensor>{});
@@ -973,26 +998,8 @@ struct UniversalGemmKernel
const index_t i_n)
{
// Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews)
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_E, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_E),
number<1>{},
number<1>{});
}
}();
const auto& e_tensor_view = make_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr, MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E));
// Step 2: Create padded view (from MakeGemmPadViews)
const auto& e_pad_view = [&]() {