[ck_tile] remove duplicate functions in ck_tile (#3311)

* [ck_tile] remove duplicated shuffle_b and shuffle_b_permuteN

* [ck_tile] move get_k_warp to gemm_shape

* resolve code rebase error
This commit is contained in:
linqunAMD
2025-12-15 23:13:00 +08:00
committed by GitHub
parent fe35ba5dac
commit 6d7299ff78
10 changed files with 123 additions and 306 deletions

View File

@@ -81,42 +81,3 @@ inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
return traits;
}
template <typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t,
ck_tile::index_t N_Warp_Tile,
ck_tile::index_t K_Warp_Tile)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int divisor = N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view(
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t,
ck_tile::index_t N_Warp_Tile,
ck_tile::index_t K_Warp_Tile,
ck_tile::index_t N_Tile,
ck_tile::index_t N_Warp)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int divisor = N_Warp_Tile == 32 ? 2 : 4;
int NRepeat = N_Tile / N_Warp_Tile / N_Warp;
ck_tile::HostTensor<T> t_view({n_ / N_Tile,
N_Warp,
N_Warp_Tile,
NRepeat,
k_ / K_Warp_Tile,
divisor,
K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}

View File

@@ -111,21 +111,30 @@ class GemmProfiler
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
struct GemmConfig
{
ck_tile::index_t N_Warp_Tile;
ck_tile::index_t K_Warp_Tile;
ck_tile::index_t N_Tile;
ck_tile::index_t N_Warp;
};
for(const auto& callable : callables)
{
ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims);
ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims);
ck_tile::index_t N_Tile = std::get<1>(config.tile_dims);
ck_tile::index_t N_Warp = std::get<1>(config.warp_dims);
GemmConfig gemmConfig = {};
gemmConfig.N_Warp_Tile = std::get<1>(config.warp_tile_dims);
gemmConfig.K_Warp_Tile = std::get<2>(config.warp_tile_dims);
gemmConfig.N_Tile = std::get<1>(config.tile_dims);
gemmConfig.N_Warp = std::get<1>(config.warp_dims);
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
if(config.permuteN)
{
return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp);
return ck_tile::shuffle_b_permuteN(b_k_n, gemmConfig);
}
else
{
return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
return ck_tile::shuffle_b(b_k_n, gemmConfig);
}
}();