mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
[ck_tile] remove duplicate functions in ck_tile (#3311)
* [ck_tile] remove duplicated shuffle_b and shuffle_b_permuteN * [ck_tile] move get_k_warp to gemm_shape * resolve code rebase error
This commit is contained in:
@@ -24,39 +24,6 @@ inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
|
||||
return combined_hash;
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
@@ -124,7 +91,8 @@ struct GemmConfigQuantDecode : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -140,7 +108,8 @@ struct GemmConfigRowColQuant : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -157,7 +126,7 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
@@ -176,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
@@ -206,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
@@ -236,7 +205,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
Reference in New Issue
Block a user