mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
Wavelet (inter-wave consumer-producer) GEMM (#310)
* wavelet gemm programming model support for CK
* GEMM pipeline update for wavelet programming model
* Updated wavelet programming pipeline
* fixes for global-write for math-wave
* fixed bug in global writes
* Updated comments for better readability
* fixed clang format errors
* added block_lds without barrier sync
* clean
* clean
* clean
* clean
* refactor
* prototype
4 layouts
fix default stride
all problem sizes
tidy
move file
update build script
restore old file
fix build
* refactor standalone test to use gemm test harness
* simplify gemm test
* update build script
* remove redundant
* early return when cmd arg doesn't match
* tidy
* report failure when result not validated
* tidy
* Add comment depicting B2C mapping pattern.
* Formatting & comments.
* Comparison with custom B2C mapping pattern.
* Example for wavelet gemm.
* Add wavelet to Gemm standalone test.
* Remove debug code.
* Remove dangling #endif directive.
Co-authored-by: root <Raman Jana>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: Anthony Chang <ac.chang@outlook.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: 1cfa87608a]
This commit is contained in:
@@ -431,9 +431,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
const index_t grid_size =
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
@@ -471,6 +468,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
arg.block_2_etile_map_);
|
||||
};
|
||||
|
||||
const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
|
||||
@@ -486,7 +486,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user