Padding support for wave transfer (#3537)

* Add padding support with transpose

Also move check before writing storing is_src_valid during reading

* Add/modify instances to use wave transfer for gemm universal

Condition is changed so now the vectorsize of vmem reading and lds
writing must be equal to 8 in order to use the wave transfer

* Fix clang format

* Modify example

* Fix bwd data

* Add restriction for wave transfer with padding and transpose

Add test case which shows this limitation

* Fix validity checks 8 bit types

* Add validity check gemm_bias_add_reduce

* Add validity check grouped gemm tile loop

* Fix validity checks new flavours

* Minor fixes

* Fix clang format
This commit is contained in:
Enrico Degregori
2026-01-26 21:57:09 +01:00
committed by GitHub
parent bd5fec81af
commit 2e49b6b2f7
23 changed files with 385 additions and 50 deletions

View File

@@ -19,22 +19,22 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
PassThrough, PassThrough, PassThrough, GemmSpec,
256,
128, 256, 64,
8, 8,
16, 16,
2, 8,
S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 8, 8, 1,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, 1,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, 1,
1, 8, 8, 1,
1, 1,
S<1, 64, 1, 4>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;