mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Add permuteN optimization to remove lds operation in c_shuffle (#2764)
* permuteN optimization to remove lds operation in c_shuffle * add the change log --------- Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -276,6 +276,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -298,6 +300,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
|
||||
@@ -241,6 +241,26 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
|
||||
template <typename CDataType>
|
||||
bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
const ck_tile::HostTensor<CDataType>& c_m_n_ref,
|
||||
@@ -346,7 +366,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
|
||||
if constexpr(preshuffle)
|
||||
{
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
|
||||
if constexpr(GemmConfig::TiledMMAPermuteN)
|
||||
{
|
||||
std::cout << "Run with PermuteN" << std::endl;
|
||||
return shuffle_b_permuteN<GemmConfig>(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Run without PermuteN" << std::endl;
|
||||
return shuffle_b<GemmConfig>(b_k_n);
|
||||
}
|
||||
}();
|
||||
// shuffled buffer B for device implementation
|
||||
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user