mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
[CK TILE] simplify function GetKBPerLoad
This commit is contained in:
@@ -39,21 +39,17 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t scale = 4;
|
||||
#else
|
||||
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
|
||||
#endif
|
||||
if constexpr(TileShape::WarpTile::at(I1) == 32)
|
||||
{
|
||||
return TileShape::WarpTile::at(I2) * scale / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16);
|
||||
return TileShape::WarpTile::at(I2) * scale / 4;
|
||||
}
|
||||
|
||||
constexpr index_t k_b_per_load =
|
||||
TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size();
|
||||
|
||||
/* The k_b_per_load should meet the requirement that each thread loads 16 bytes in
|
||||
* Preshuffle B */
|
||||
static_assert(k_b_per_load * sizeof(BDataType) == 16);
|
||||
|
||||
return k_b_per_load;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user