mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Wavelet (inter-wave consumer-producer) GEMM (#310)
* wavelet gemm programming model support for CK * GEMM pipeline update for wavelet progrmmaing model * Updated wavelet programming pipeline * fixes for global-write for math-wave * fixed bug in global writes * Updated comments for better readability * fixed clang format errors * added block_lds without barrier sync * clean * clean * clean * clean * refactor * prototype 4 layouts fix default stride all problem sizes tidy move file update build script restore old file fix build * refactor standalone test to use gemm test harness * simplify gemm test * update build script * remove redundant * early return when cmd arg doesn't match * tidy * report failure when result not validated * tidy * Add comment depicting B2C mapping pattern. * Formatting & comments. * Comparison with custom B2C mapping pattern. * Example for wavelet gemm. * Add wavelet to Gemm standalone test. * Remove debug code. * Remove dangling #endif directive. Co-authored-by: root <Raman Jana> Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: Anthony Chang <ac.chang@outlook.com> Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -12,6 +12,8 @@ using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
@@ -29,7 +31,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
|
||||
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
|
||||
// clang-format on
|
||||
// // clang-format on
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
@@ -40,7 +42,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// clang-format on
|
||||
|
||||
using DeviceGemmInstance = DeviceGemmInstance0;
|
||||
using DeviceGemmInstance = DeviceGemmInstance1;
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
Reference in New Issue
Block a user