[CK_TILE] Add permuteN optimization to remove lds operation in c_shuffle (#2764)

* permuteN optimization to remove lds operation in c_shuffle

* add the change log

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
lalala-sh
2025-09-09 13:02:48 +08:00
committed by GitHub
parent 92b07380d3
commit 75570d0fa8
5 changed files with 189 additions and 4 deletions

View File

@@ -276,6 +276,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
template <typename PrecType>
@@ -298,6 +300,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>

View File

@@ -241,6 +241,26 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
template <typename CDataType>
bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
const ck_tile::HostTensor<CDataType>& c_m_n_ref,
@@ -346,7 +366,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
if constexpr(preshuffle)
{
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
if constexpr(GemmConfig::TiledMMAPermuteN)
{
std::cout << "Run with PermuteN" << std::endl;
return shuffle_b_permuteN<GemmConfig>(b_k_n);
}
else
{
std::cout << "Run without PermuteN" << std::endl;
return shuffle_b<GemmConfig>(b_k_n);
}
}();
// shuffled buffer B for device implementation
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
}