mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[ck_tile] remove duplicate functions in ck_tile (#3311)
* [ck_tile] remove duplicated shuffle_b and shuffle_b_permuteN * [ck_tile] move get_k_warp to gemm_shape * resolve code rebase error
This commit is contained in:
@@ -11,26 +11,6 @@
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
{
|
||||
@@ -67,7 +47,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 16;
|
||||
static const ck_tile::index_t N_Warp_Tile = 16;
|
||||
static const ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<BDataType, M_Warp_Tile>();
|
||||
static const ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<BDataType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem
|
||||
static constexpr bool TransposeC = false; // transpose c is not supported
|
||||
@@ -101,46 +82,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / N_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
k_ / K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
@@ -340,6 +281,14 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
}
|
||||
}
|
||||
|
||||
struct BShuffleGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Warp_Tile =
|
||||
TestCkTileGroupedGemmPreshuffle::N_Warp_Tile;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
TestCkTileGroupedGemmPreshuffle::K_Warp_Tile;
|
||||
};
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
@@ -424,7 +373,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
|
||||
// Host-side preshuffle of B
|
||||
auto b_shuffle_host = shuffle_b(b_k_n_tensors[i]);
|
||||
auto b_shuffle_host = ck_tile::shuffle_b<BShuffleGemmConfig>(b_k_n_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
Reference in New Issue
Block a user