mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
[GFX1250][MX GEMM] Unified FLATMM GroupedGemm Implementation for MX Data Types (#8325) ## Motivation Design and test a unified FLATMM GroupedGemm interface so that it supports all MX FP8, FP6, and FP4 data types on both the gfx950 and gfx1250 architectures and works seamlessly across these platforms. ## Technical Details Implementation exposes Grouped Gemm interface for MX FLATMM and MX TDM FLATMM pipelines. ## Test Plan Add the following tests: - ck_tile/grouped_gemm_mx/test_grouped_gemm_mx_flatmm_non_tdm.cpp - ck_tile/grouped_gemm_mx/test_grouped_gemm_mx_flatmm_tdm.cpp - ck_tile/flatmm/test_mx_flatmm_persistent.cpp Verify on the gfx950 and gfx1250 architectures. ## Test Result All tests pass. Verified on A0 hardware with rocm-7.14.0a20260517 ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
30 lines
1.8 KiB
C++
30 lines
1.8 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
#pragma once
|
|
|
|
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
|
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
|
|
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
|
|
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
|
|
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
|
|
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
|
|
#include "ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp"
|
|
#include "ck_tile/ops/flatmm/kernel/grouped_mx_flatmm_kernel.hpp"
|
|
#include "ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp"
|
|
#include "ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp"
|
|
#include "ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp"
|
|
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
|
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
|
#include "ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp"
|
|
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
|
|
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
|
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
|
|
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
|
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
|
#include "ck_tile/ops/common/streamk_common.hpp"
|
|
#include "ck_tile/ops/common/tensor_layout.hpp"
|
|
#include "ck_tile/ops/common/utils.hpp"
|