Fix the vector load & fix the gfx950 compv4 error (#2831)

This commit is contained in:
Thomas Ning
2025-09-12 11:48:45 -07:00
committed by GitHub
parent 321627aec5
commit 1894a0dbc3
3 changed files with 35 additions and 20 deletions

View File

@@ -20,17 +20,18 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
// using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr bool single_load_tr_length =
(DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType)) ==
(WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size());
constexpr index_t vector_size =
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
constexpr auto wg_attr_num_access =
((is_a_load_tr<Problem> || is_b_load_tr<Problem>) && !single_load_tr_length)
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single;
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,