[CK TILE] Add new function get_k_warp_tile_for_preshuffle_b

This commit is contained in:
Cong Ma
2026-01-22 12:36:40 -05:00
parent f41f37da96
commit 080fa14140
3 changed files with 25 additions and 2 deletions

View File

@@ -255,7 +255,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
@@ -280,7 +280,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;

View File

@@ -94,7 +94,10 @@ int main(int argc, char* argv[])
auto result = arg_parser.parse(argc, argv);
if(!result)
{
arg_parser.print();
return -1;
}
try
{

View File

@@ -66,4 +66,24 @@ constexpr index_t get_k_warp_tile()
#endif
}
template <typename PrecType, index_t N_Warp_Tile>
constexpr index_t get_k_warp_tile_for_preshuffle_b()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
if constexpr(N_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
// K value is determined by the maximum bytes that can be loaded in a single instruction
// This K value is sufficient for MFMA/WMMA shapes: 16x16x16, 16x16x32, 32x32x16
const int kMaxBytesPerLoad = 16; // buffer load max 16 bytes
const int kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType);
const int KLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile;
return kMaxElementsPerLoad * KLanePerWarp;
#endif
}
} // namespace ck_tile