Wave Tile Transfer supporting global load with transpose (#3027)

* Initial implementation:

 - add new thread group transfer supporting transpose instruction
 - refactor AB transfer to switch between thread and wave tiles methods

* Add some comments and remove explicit wave and lane calculations

* Remove compiler option for performance

* fp16 example: use tuned instance

* Missing cleanup

* Integrate wave transfer in existing gemm and batched gemm instances

* Add fast instances

* extend implementation for 8 bit datatypes

packed types not supported

* Address review comments

* Optimize pipeline v1 and re-introduce compiler option

* Disable wave tile approach for b scale gemm

* Fix for clang20

* Avoid code duplication of amd_global_load_transpose_to_vgpr function

[ROCm/composable_kernel commit: 440358c168]
This commit is contained in:
Enrico Degregori
2025-10-16 20:33:56 +02:00
committed by GitHub
parent 06d76b160e
commit 1d9320c8f3
15 changed files with 1513 additions and 720 deletions

View File

@@ -26,17 +26,18 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
128,
128, 64,
64, 8, 8,
256,
128, 256, 64,
8, 8,
16, 16,
4, 2,
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
2, 8,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, 1,
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, 1,
1, 1, S<1, 32, 1, 4>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
1, 1,
S<1, 64, 1, 4>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::