mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-10 16:28:38 +00:00
Add asynchronous XOR shuffle support to the Async GEMM pipeline and the MX GEMM pipeline (#7112)

## Motivation

The goal of this work is to apply XOR shuffle (swizzle) to the current `comp_async` GEMM pipeline and the `gemm_mx` pipeline. XOR swizzling helps avoid LDS bank conflicts: it redistributes data across the LDS banks so that threads simultaneously accessing different rows land on different banks.

## Technical Details

The implementation follows the approach already used in the existing eight-wave pipeline. XOR swizzle support is currently available for the FP8 and BF8 types; FP4 is additionally supported for MX GEMM. If the types do not match, or if the async vector width is an unsupported size, the pipeline falls back to the previously existing ("unswizzled") path.

## Test Plan

Run `test_ck_tile_gemm_pipeline_comp_async` for the Async GEMM pipeline. Run `test_ck_tile_mx_gemm_fp8` and `test_ck_tile_mx_gemm_fp4` for the MX GEMM pipeline.

## Test Result

The tests passed on the `Alola` cluster with MI350 hardware.

## Submission Checklist

- [X] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Fernando Jiménez <fernando.jimenez@streamhpc.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
140 lines
6.9 KiB
C++
140 lines
6.9 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
#pragma once
|
|
|
|
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
|
|
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
|
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
|
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
|
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
|
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
|
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
|
#include "ck_tile/core/arch/amd_buffer_coherence.hpp"
|
|
#include "ck_tile/core/arch/amd_cluster_load.hpp"
|
|
#include "ck_tile/core/arch/amd_tdm_descriptor.hpp"
|
|
#include "ck_tile/core/arch/amd_transpose_load_encoding.hpp"
|
|
#include "ck_tile/core/arch/amd_wave_read_first_lane.hpp"
|
|
#include "ck_tile/core/arch/arch.hpp"
|
|
#include "ck_tile/core/arch/barrier.hpp"
|
|
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
|
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
|
|
#include "ck_tile/core/arch/mma/mfma/mfma.hpp"
|
|
#include "ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp"
|
|
#include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp"
|
|
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
|
|
#include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp"
|
|
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
|
|
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
|
|
#include "ck_tile/core/arch/mma/mma_selector.hpp"
|
|
#include "ck_tile/core/arch/mma/mma_traits.hpp"
|
|
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
|
|
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/mfma/selector.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/scale.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/scale_traits.hpp"
|
|
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/sparse.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp"
|
|
#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp"
|
|
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
|
|
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
|
|
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
|
|
#include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp"
|
|
#include "ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp"
|
|
#include "ck_tile/core/arch/mma/wmma/wmma_selector.hpp"
|
|
#include "ck_tile/core/arch/mma/wmma/wmma_traits.hpp"
|
|
#include "ck_tile/core/arch/mma/wmma/wmma_transforms.hpp"
|
|
#include "ck_tile/core/arch/utility.hpp"
|
|
#include "ck_tile/core/arch/workgroup_barrier.hpp"
|
|
#include "ck_tile/core/config.hpp"
|
|
#include "ck_tile/core/container/array.hpp"
|
|
#include "ck_tile/core/container/container_helper.hpp"
|
|
#include "ck_tile/core/container/map.hpp"
|
|
#include "ck_tile/core/container/meta_data_buffer.hpp"
|
|
#include "ck_tile/core/container/multi_index.hpp"
|
|
#include "ck_tile/core/container/sequence.hpp"
|
|
#include "ck_tile/core/container/span.hpp"
|
|
#include "ck_tile/core/container/static_array.hpp"
|
|
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
|
#include "ck_tile/core/container/thread_buffer.hpp"
|
|
#include "ck_tile/core/container/tuple.hpp"
|
|
#include "ck_tile/core/numeric/bfloat16.hpp"
|
|
#include "ck_tile/core/numeric/e4m3.hpp"
|
|
#include "ck_tile/core/numeric/e5m3.hpp"
|
|
#include "ck_tile/core/numeric/e8m0.hpp"
|
|
#include "ck_tile/core/numeric/ext_vector_base.hpp"
|
|
#include "ck_tile/core/numeric/float8.hpp"
|
|
#include "ck_tile/core/numeric/float8_ext.hpp"
|
|
#include "ck_tile/core/numeric/half.hpp"
|
|
#include "ck_tile/core/numeric/int8.hpp"
|
|
#include "ck_tile/core/numeric/integer.hpp"
|
|
#include "ck_tile/core/numeric/integral_constant.hpp"
|
|
#include "ck_tile/core/numeric/math.hpp"
|
|
#include "ck_tile/core/numeric/math_v2.hpp"
|
|
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
|
#include "ck_tile/core/numeric/mxfp_scale.hpp"
|
|
#include "ck_tile/core/numeric/null_type.hpp"
|
|
#include "ck_tile/core/numeric/numeric.hpp"
|
|
#include "ck_tile/core/numeric/pk_f6.hpp"
|
|
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
|
#include "ck_tile/core/numeric/pk_int4.hpp"
|
|
#include "ck_tile/core/numeric/scale_util.hpp"
|
|
#include "ck_tile/core/numeric/type_convert.hpp"
|
|
#include "ck_tile/core/numeric/vector_type.hpp"
|
|
#include "ck_tile/core/tensor/buffer_view.hpp"
|
|
#include "ck_tile/core/tensor/load_tile.hpp"
|
|
#include "ck_tile/core/tensor/load_tile_transpose.hpp"
|
|
#include "ck_tile/core/tensor/null_tensor.hpp"
|
|
#include "ck_tile/core/tensor/null_tile_window.hpp"
|
|
#include "ck_tile/core/tensor/shuffle_tile.hpp"
|
|
#include "ck_tile/core/tensor/slice_tile.hpp"
|
|
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
|
#include "ck_tile/core/tensor/store_tile.hpp"
|
|
#include "ck_tile/core/tensor/sweep_tile.hpp"
|
|
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
|
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
|
|
#include "ck_tile/core/tensor/tensor_coordinate.hpp"
|
|
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
|
|
#include "ck_tile/core/tensor/tensor_view.hpp"
|
|
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
|
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
|
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
|
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
|
#include "ck_tile/core/tensor/tile_window.hpp"
|
|
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
|
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
|
#include "ck_tile/core/tensor/tile_window_utils.hpp"
|
|
#include "ck_tile/core/tensor/transpose_tile.hpp"
|
|
#include "ck_tile/core/tensor/update_tile.hpp"
|
|
#include "ck_tile/core/utility/bit_cast.hpp"
|
|
#include "ck_tile/core/utility/data_cache_prefetch.hpp"
|
|
#include "ck_tile/core/utility/debug.hpp"
|
|
#include "ck_tile/core/utility/env.hpp"
|
|
#include "ck_tile/core/utility/functional.hpp"
|
|
#include "ck_tile/core/utility/functional_with_tuple.hpp"
|
|
#include "ck_tile/core/utility/gemm_validation.hpp"
|
|
#include "ck_tile/core/utility/ignore.hpp"
|
|
#include "ck_tile/core/utility/literals.hpp"
|
|
#include "ck_tile/core/utility/magic_div.hpp"
|
|
#include "ck_tile/core/utility/mixed_prec_compute_type.hpp"
|
|
#include "ck_tile/core/utility/persistent_async_input_scheduler.hpp"
|
|
#include "ck_tile/core/utility/philox_rand.hpp"
|
|
#include "ck_tile/core/utility/print.hpp"
|
|
#include "ck_tile/core/utility/random.hpp"
|
|
#include "ck_tile/core/utility/reduce_operator.hpp"
|
|
#include "ck_tile/core/utility/reduce_operator_accumulate.hpp"
|
|
#include "ck_tile/core/utility/static_counter.hpp"
|
|
#include "ck_tile/core/utility/to_sequence.hpp"
|
|
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
|
#include "ck_tile/core/utility/type_traits.hpp"
|
|
#include "ck_tile/core/utility/unary_element_function.hpp"
|