[ck_tile] remove duplicate functions in ck_tile (#3311)

* [ck_tile] remove duplicated shuffle_b and shuffle_b_permuteN

* [ck_tile] move get_k_warp to gemm_shape

* resolve code rebase error
This commit is contained in:
linqunAMD
2025-12-15 23:13:00 +08:00
committed by GitHub
parent fe35ba5dac
commit 6d7299ff78
10 changed files with 123 additions and 306 deletions

View File

@@ -68,7 +68,7 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
@@ -78,10 +78,10 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
@@ -98,18 +98,24 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
gemmConfig.K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
return shuffle_b(t, GemmConfig{});
}
template <typename GemmConfig, typename T>
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
@@ -129,22 +135,22 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
@@ -161,17 +167,23 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
gemmConfig.K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
{
return shuffle_b_permuteN(t, GemmConfig{});
}
} // namespace ck_tile

View File

@@ -43,4 +43,26 @@ struct TileGemmShape
}
};
template <typename PrecType, index_t M_Warp_Tile, bool IsFlatMM = false>
constexpr index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
return 16;
#else
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32;
else
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64;
#endif
#endif
}
} // namespace ck_tile