mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
[CK TILE] Add new function get_k_warp_tile_for_preshuffle_b
This commit is contained in:
@@ -255,7 +255,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
@@ -280,7 +280,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
|
||||
@@ -94,7 +94,10 @@ int main(int argc, char* argv[])
|
||||
auto result = arg_parser.parse(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
arg_parser.print();
|
||||
return -1;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
|
||||
@@ -66,4 +66,24 @@ constexpr index_t get_k_warp_tile()
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename PrecType, index_t N_Warp_Tile>
|
||||
constexpr index_t get_k_warp_tile_for_preshuffle_b()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
|
||||
if constexpr(N_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
// K value is determined by the maximum bytes that can be loaded in a single instruction
|
||||
// This K value is sufficient for MFMA/WMMA shapes: 16x16x16, 16x16x32, 32x32x16
|
||||
const int kMaxBytesPerLoad = 16; // buffer load max 16 bytes
|
||||
const int kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType);
|
||||
const int KLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile;
|
||||
return kMaxElementsPerLoad * KLanePerWarp;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user