[CK TILE] simplify function GetKBPerLoad

This commit is contained in:
Cong Ma
2026-01-22 17:49:39 -05:00
parent 080fa14140
commit dc83e285e1

View File

@@ -39,21 +39,17 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
/**
 * @brief Per-thread element count of B for one load in the weight-preshuffle pipeline.
 *
 * Computed as WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(), i.e. the warp
 * tile's I1 x I2 extent divided evenly across the lanes of a warp. This single
 * expression replaces the former per-architecture `scale` factor and the
 * hard-coded WarpTile::at(I1) == 32 / == 16 cases.
 * NOTE(review): I1 / I2 are presumably the N and K warp-tile dimensions - confirm
 * against the BlockGemmShape definition.
 *
 * @tparam Problem Provides `BDataType` and `BlockGemmShape` (its `WarpTile` is read here).
 * @return The element count as a compile-time `index_t`.
 *
 * Compilation fails if the resulting load is not exactly 16 bytes per thread.
 */
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    using TileShape = typename Problem::BlockGemmShape;

    constexpr index_t k_b_per_load =
        TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size();

    /* The k_b_per_load should meet the requirement that each thread loads 16 bytes in
     * Preshuffle B */
    static_assert(k_b_per_load * sizeof(BDataType) == 16,
                  "GetKBPerLoad: each thread must load exactly 16 bytes of preshuffled B");
    return k_b_per_load;
}
template <typename Problem>