From 4e0fd5241ae68487f8d40fb4453bda8f57c603e1 Mon Sep 17 00:00:00 2001 From: Matti Eskelinen Date: Thu, 15 Jan 2026 11:32:50 +0000 Subject: [PATCH] Separate tensor descriptor creation from the tensor view creation This adds utility functions to construct default tensor descriptors for A, B, C and D tensors and refactors the Make{A,B,C,D}BlockWindows to call make_tensor_view using the utility functions instead of directly calling make_naive_tensor_view, allowing for further refactors later. --- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 281 +++++++++--------- 1 file changed, 144 insertions(+), 137 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 9583ac8a3f..a028989c82 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -661,6 +661,136 @@ struct UniversalGemmKernel return AsTensorIsValid && BsTensorIsValid && DTensorIsValid; } + template + CK_TILE_DEVICE static auto + MakeDefaultATensorDescriptor(const index_t M, const index_t stride, const index_t k_size) + { + if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, k_size), + make_tuple(stride, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_descriptor(make_tuple(k_size, M), + make_tuple(stride, 1), + number{}, + number<1>{}); + } + } + + template + CK_TILE_DEVICE static auto MakeDefaultBTensorDescriptor(const index_t N, + const index_t K, + const index_t stride, + const index_t k_size) + { + if constexpr(std::is_same_v) + { + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = k_size / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1), + make_tuple(N * K1, K1, I1), + number{}, + number<1>{}); + return transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + return make_naive_tensor_descriptor(make_tuple(k_size, N), + make_tuple(stride, 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = k_size / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1), + make_tuple(N * K1, K1, I1), + number{}, + number<1>{}); + return transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + } + else + { + if constexpr(GemmPipeline::Preshuffle) + { + index_t kFlatK = + GemmPipeline::BlockGemmShape::flatKPerWarp * + (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = N * K / kFlatK; + + return make_naive_tensor_descriptor(make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N, k_size), + make_tuple(stride, 1), + number{}, + number<1>{}); + } + } + } + } + + template + CK_TILE_DEVICE static auto + MakeDefaultDTensorDescriptor(const index_t M, const index_t N, const index_t stride) + { + if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(stride, 1), number{}, number<1>{}); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, M), make_tuple(stride, 1), number{}, number<1>{}); + } + } + + CK_TILE_DEVICE static auto + MakeDefaultETensorDescriptor(const index_t M, const index_t N, const index_t stride) + { + // TODO: enable vector write for C in ColMajor + if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(M, N), // arguments not matching with flatmm. + make_tuple(stride, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(1, stride), number<1>{}, number<1>{}); + } + } + CK_TILE_DEVICE static auto MakeABlockWindows(const std::array& as_ptr, const KernelArgs& kargs, @@ -672,24 +802,10 @@ struct UniversalGemmKernel [&](auto i) { using AiLayout = remove_cvref_t>; using AiDataType = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - static_cast(as_ptr[i]), - make_tuple(kargs.M, k_size), - make_tuple(kargs.stride_As[i], 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - static_cast(as_ptr[i]), - make_tuple(k_size, kargs.M), - make_tuple(kargs.stride_As[i], 1), - number{}, - number<1>{}); - } + + return make_tensor_view( + static_cast(as_ptr[i]), + MakeDefaultATensorDescriptor(kargs.M, kargs.stride_As[i], k_size)); }, number{}); @@ -749,87 +865,10 @@ struct UniversalGemmKernel [&](auto i) { using BiLayout = remove_cvref_t>; using BiDataType = remove_cvref_t>; - if constexpr(std::is_same_v) - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = k_size / K1; - constexpr index_t VectorSizeB = - std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view( - static_cast(bs_ptr[i]), b_n_k_desc); - } - else - { - return make_naive_tensor_view( - bs_ptr[i], - make_tuple(k_size, kargs.N), - make_tuple(kargs.stride_Bs[i], 1), - number{}, - number<1>{}); - } - } - else - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = k_size / K1; - constexpr index_t VectorSizeB = - std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - return make_tensor_view( - static_cast(bs_ptr[i]), b_n_k_desc); - } - else - { - if constexpr(GemmPipeline::Preshuffle) - { - index_t kFlatK = - GemmPipeline::BlockGemmShape::flatKPerWarp * - (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - - return make_naive_tensor_view( - bs_ptr[i], - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - bs_ptr[i], - make_tuple(kargs.N, k_size), - make_tuple(kargs.stride_Bs[i], 1), - number{}, - number<1>{}); - } - } - } + return make_tensor_view( + static_cast(bs_ptr[i]), + MakeDefaultBTensorDescriptor( + kargs.N, kargs.K, kargs.stride_Bs[i], k_size)); }, number{}); @@ -900,24 +939,10 @@ struct UniversalGemmKernel [&](auto i) { using DiLayout = remove_cvref_t>; using DDataType_ = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.N, kargs.M), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } + return make_tensor_view( + static_cast(ds_ptr[i]), + MakeDefaultDTensorDescriptor( + kargs.M, kargs.N, kargs.stride_Ds[i])); }, number{}); @@ -973,26 +998,8 @@ struct UniversalGemmKernel const index_t i_n) { // Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews) - const auto& e_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_E, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_E), - number<1>{}, - number<1>{}); - } - }(); + const auto& e_tensor_view = make_tensor_view( + e_ptr, MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E)); // Step 2: Create padded view (from MakeGemmPadViews) const auto& e_pad_view = [&]() {