Wavelet (inter-wave consumer-producer) GEMM (#310)

* wavelet gemm programming model support for CK

* GEMM pipeline update for wavelet progrmmaing model

* Updated wavelet programming pipeline

* fixes for global-write for math-wave

* fixed bug in global writes

* Updated comments for better readability

* fixed clang format errors

* added block_lds without barrier sync

* clean

* clean

* clean

* clean

* refactor

* prototype

4 layouts

fix default stride

all problem sizes

tidy

move file

update build script

restore old file

fix build

* refactor standalone test to use gemm test harness

* simplify gemm test

* update build script

* remove redundant

* early return when cmd arg doesn't match

* tidy

* report failure when result not validated

* tidy

* Add comment depicting B2C mapping pattern.

* Formatting & comments.

* Comparison with custom B2C mapping pattern.

* Example for wavelet gemm.

* Add wavelet to Gemm standalone test.

* Remove debug code.

* Remove dangling #endif directive.

Co-authored-by: root <Raman Jana>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: Anthony Chang <ac.chang@outlook.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
Raman R jana
2023-01-18 12:00:02 -06:00
committed by GitHub
parent d66421fe34
commit 1cfa87608a
15 changed files with 1652 additions and 6 deletions

View File

@@ -431,9 +431,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
@@ -471,6 +468,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
arg.block_2_etile_map_);
};
const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});

View File

@@ -486,7 +486,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);