mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Cleanup and refactoring related to tile loading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Cleanup and refactoring done while implementing mixed precision for fp16/bf16 x fp8 Key changes: - Renamed load_interleaved_pk_type.hpp to load_and_convert_tile.hpp and refactored the API to use consistent naming conventions - Updated load_tile_transpose functions to use output parameters instead of return values for consistency - Removed unused variable declarations and simplified type deduction logic - Define load_tile_with_elementwise to use tuple types explicitly for clarity ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [x] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [X] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
86 lines
5.7 KiB
C++
86 lines
5.7 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
#pragma once
|
|
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_async_v1.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
|
|
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
|
|
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
|
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
|
#include "ck_tile/ops/common/streamk_common.hpp"
|
|
#include "ck_tile/ops/common/tensor_layout.hpp"
|
|
#include "ck_tile/ops/common/utils.hpp"
|