Padding support for wave transfer (#3537)

* Add padding support with transpose

Also move check before writing storing is_src_valid during reading

* Add/modify instances to use wave transfer for gemm universal

Condition is changed so now the vectorsize of vmem reading and lds
writing must be equal to 8 in order to use the wave transfer

* Fix clang format

* Modify example

* Fix bwd data

* Add restriction for wave transfer with padding and transpose

Add test case which shows this limitation

* Fix validity checks 8 bit types

* Add validity check gemm_bias_add_reduce

* Add validity check grouped gemm tile loop

* Fix validity checks new flavours

* Minor fixes

* Fix clang format
This commit is contained in:
Enrico Degregori
2026-01-26 21:57:09 +01:00
committed by GitHub
parent bd5fec81af
commit 2e49b6b2f7
23 changed files with 385 additions and 50 deletions

View File

@@ -125,7 +125,7 @@ TYPED_TEST(TestGemmUniversal_FP16_KM_NK, MidLargeM)
TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK)
{
std::vector<int> Ms{127};
std::vector<int> Ms{127, 128};
constexpr int N = 512;
constexpr int K = 437;
@@ -139,7 +139,7 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK)
TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK)
{
std::vector<int> Ms{127};
std::vector<int> Ms{127, 128};
constexpr int N = 512;
constexpr int K = 437;
@@ -153,7 +153,7 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK)
TYPED_TEST(TestGemmUniversal_FP16_KM_KN, PaddK)
{
std::vector<int> Ms{127};
std::vector<int> Ms{127, 128};
constexpr int N = 512;
constexpr int K = 437;
@@ -169,7 +169,7 @@ TYPED_TEST(TestGemmUniversal_FP16_KM_KN, PaddK)
TYPED_TEST(TestGemmUniversal_FP16_KM_NK, PaddK)
{
std::vector<int> Ms{127};
std::vector<int> Ms{127, 128};
constexpr int N = 512;
constexpr int K = 437;