// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include "device_prop.hpp" #include namespace ck_tile { namespace detail { template struct b_contiguous_items_per_access { // Default: 16 / sizeof(T) static constexpr int value = 16 / static_cast(sizeof(T)); }; template struct b_contiguous_items_per_access> { // PackedSize specified static constexpr int value = GemmConfig::BContiguousItemsPerAccess; }; } // namespace detail template auto shuffle_aq(const ck_tile::HostTensor* t, int block_aq_k) { if(t->get_lengths().size() != 2) { throw std::runtime_error("Host tensor is not rank 2 tensor."); } int m_ = t->get_lengths()[0]; int aqk_ = t->get_lengths()[1]; if(aqk_ % block_aq_k != 0) { throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k."); } ck_tile::HostTensor t_view({m_, aqk_ / block_aq_k, block_aq_k}); std::copy(t->begin(), t->end(), t_view.begin()); return ck_tile::reference_permute(t_view, {1, 0, 2}); } template auto shuffle_bq(const ck_tile::HostTensor* t, int block_bq_k) { const auto& lengths = t->get_lengths(); const size_t rank = lengths.size(); // Validate block_bq_k divisibility based on rank int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? lengths[0] : -1; if(bqk_dim < 0) { throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " + std::to_string(rank)); } if(bqk_dim % block_bq_k != 0) { throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k."); } // For TilePermuteN if(rank == 5) { // Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk] ck_tile::HostTensor t_view({static_cast(lengths[0]), static_cast(lengths[1]), static_cast(lengths[2]), static_cast(lengths[3]), bqk_dim / block_bq_k, block_bq_k}); std::copy(t->begin(), t->end(), t_view.begin()); return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5}); } else // rank == 2 { // Handle 2D tensor: [bqk, n] int n_ = lengths[1]; ck_tile::HostTensor t_view({n_, bqk_dim / block_bq_k, block_bq_k}); std::copy(t->begin(), t->end(), t_view.begin()); return ck_tile::reference_permute(t_view, {1, 0, 2}); } } template auto shuffle_b(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; if(ck_tile::is_gfx12_supported()) { constexpr int divisor = 2; constexpr int kABK1PerLane = 8; int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, gemmConfig.N_Warp_Tile, k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, divisor, kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); } else if(ck_tile::is_gfx11_supported()) { int divisor = 1; ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, gemmConfig.N_Warp_Tile, k_ / gemmConfig.K_Warp_Tile, divisor, gemmConfig.K_Warp_Tile / divisor}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } else { constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile; constexpr int ItemsPerAccess = std::min(detail::b_contiguous_items_per_access::value, GemmConfig::K_Warp_Tile / KLane); ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, gemmConfig.N_Warp_Tile, k_ / ItemsPerAccess, ItemsPerAccess}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 1, 3}); } } template auto shuffle_b(const ck_tile::HostTensor& t) { return shuffle_b(t, GemmConfig{}); } template auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; int bqk_ = t.get_lengths()[0]; constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; ck_tile::HostTensor t_view({n_ / (GemmConfig::N_Tile / group_n), GemmConfig::N_Warp, GemmConfig::N_Warp_Tile / group_n, NRepeat, bqk_}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4}); } template auto shuffle_b_permuteN(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig, number) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp; if(ck_tile::is_gfx12_supported()) { constexpr int divisor = 2; int kABK1PerLane = min(16 / static_cast(sizeof(T)), gemmConfig.K_Warp_Tile / divisor); int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, gemmConfig.N_Warp, gemmConfig.N_Warp_Tile, NRepeat, k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, divisor, kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 6, 2, 7}); } else { assert(NRepeat % BlockedXDLNPerWarp == 0); constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile; constexpr int ItemsPerAccess = std::min(detail::b_contiguous_items_per_access::value, GemmConfig::K_Warp_Tile / KLane); ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, gemmConfig.N_Warp, gemmConfig.N_Warp_Tile, NRepeat / BlockedXDLNPerWarp, BlockedXDLNPerWarp, k_ / ItemsPerAccess, ItemsPerAccess}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); } } template auto shuffle_b_permuteN(const ck_tile::HostTensor& t) { return shuffle_b_permuteN(t, GemmConfig{}, number{}); } template auto shuffle_b_v0(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; if(ck_tile::is_gfx11_supported()) { int divisor = 1; ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, FlatmmConfig::N_Warp_Tile, k_ / FlatmmConfig::K_Warp_Tile, divisor, FlatmmConfig::K_Warp_Tile / divisor}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } else { constexpr int MaxVecSize = 16 / sizeof(T); // because ck_tile::get_warp_size returns 64 in host side int KLane = (ck_tile::is_wave32() ? (ck_tile::get_warp_size() / 2) : (ck_tile::get_warp_size())) / FlatmmConfig::N_Warp_Tile; int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane); ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, FlatmmConfig::N_Warp_Tile, k_ / ItemsPerAccess, ItemsPerAccess}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 1, 3}); } } template auto shuffle_b_v1(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; constexpr int MaxVecSize = 16 / sizeof(T); constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile; constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane); constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp; ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Tile, FlatmmConfig::N_Warp, FlatmmConfig::N_Warp_Tile, NRepeat, k_ / ItemsPerAccess, ItemsPerAccess}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5}); } } // namespace ck_tile