composable_kernel/include/ck_tile/ops/flatmm.hpp
Andriy Roshchenko b8440b3aeb [rocm-libraries] ROCm/rocm-libraries#8325 (commit 559eaf6)
[GFX1250][MX GEMM] Unified FLATMM GroupedGemm Implementation for MX Data Types (#8325)

## Motivation

Design and test a unified FLATMM GroupedGemm interface that supports
all MX FP8, FP6, and FP4 data types on both the gfx950 and gfx1250
architectures, so that the same interface works across both platforms.
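For readers new to MX (microscaling) formats: in the OCP Microscaling specification, a block of 32 elements shares one E8M0 scale (a pure power of two), and each element is stored in a narrow format such as FP8, FP6, or FP4 (E2M1). The host-side sketch below decodes one MXFP4 block to show how the shared scale is applied. The helper names (`decode_e2m1`, `decode_e8m0`, `decode_mxfp4_block`) are hypothetical and not part of ck_tile, and the sketch assumes one FP4 code per byte (low nibble) for clarity, rather than the packed layout the kernels actually consume.

```cpp
#include <array>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper (not a ck_tile API): decode one FP4 E2M1 code.
// Bit 3 = sign, bits 2..1 = exponent (bias 1), bit 0 = mantissa.
float decode_e2m1(std::uint8_t code)
{
    const int sign     = (code >> 3) & 0x1;
    const int exponent = (code >> 1) & 0x3;
    const int mantissa = code & 0x1;
    float magnitude;
    if(exponent == 0)
        magnitude = 0.5f * static_cast<float>(mantissa); // subnormal: 0 or 0.5
    else
        magnitude = std::ldexp(1.0f + 0.5f * static_cast<float>(mantissa), exponent - 1);
    return sign ? -magnitude : magnitude;
}

// Hypothetical helper: decode an E8M0 shared scale, 2^(s - 127); 0xFF is NaN.
float decode_e8m0(std::uint8_t s)
{
    if(s == 0xFF)
        return std::numeric_limits<float>::quiet_NaN();
    return std::ldexp(1.0f, static_cast<int>(s) - 127);
}

constexpr int kMxBlockSize = 32; // elements that share one scale in the OCP MX spec

// Decode one MXFP4 block: out[i] = scale * element[i].
// Assumes one FP4 code per byte, stored in the low nibble (unpacked layout).
std::array<float, kMxBlockSize>
decode_mxfp4_block(const std::array<std::uint8_t, kMxBlockSize>& codes, std::uint8_t scale)
{
    std::array<float, kMxBlockSize> out{};
    const float s = decode_e8m0(scale);
    for(int i = 0; i < kMxBlockSize; ++i)
        out[i] = s * decode_e2m1(static_cast<std::uint8_t>(codes[i] & 0x0F));
    return out;
}
```

Because the scale is a power of two, applying it only shifts each element's exponent; FP8 and FP6 MX blocks follow the same pattern with a different element decoder.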

## Technical Details

The implementation exposes a Grouped GEMM interface for the MX FLATMM
and MX TDM FLATMM pipelines.
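Conceptually, a grouped GEMM runs several independent GEMMs, each with its own M, N, K and its own pointers, from a single kernel launch. The host-side sketch below illustrates one common way to distribute that work: the output tiles of all groups are laid out in one flat range, and each block id is mapped back to its (group, tile) pair through prefix sums. All names (`GemmProblem`, `grouped_gemm`, the tile sizes) are hypothetical; the sketch uses plain `float` data with row-major A and B and ignores the MX scale tensors and the FLATMM B-matrix layout, so it is not the ck_tile grouped kernel interface.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical description of one GEMM in the group (not the ck_tile argument type).
struct GemmProblem
{
    int M, N, K;
    const float* A; // M x K, row-major
    const float* B; // K x N, row-major
    float* C;       // M x N, row-major
};

// Illustrative output tile size used for work decomposition.
constexpr int kTileM = 2, kTileN = 2;

int num_tiles(const GemmProblem& p)
{
    return ((p.M + kTileM - 1) / kTileM) * ((p.N + kTileN - 1) / kTileN);
}

// Compute one output tile of one problem (plain reference loop, no MX scaling).
void compute_tile(const GemmProblem& p, int tile)
{
    const int tiles_n = (p.N + kTileN - 1) / kTileN;
    const int m0      = (tile / tiles_n) * kTileM;
    const int n0      = (tile % tiles_n) * kTileN;
    for(int m = m0; m < std::min(m0 + kTileM, p.M); ++m)
        for(int n = n0; n < std::min(n0 + kTileN, p.N); ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < p.K; ++k)
                acc += p.A[m * p.K + k] * p.B[k * p.N + n];
            p.C[m * p.N + n] = acc;
        }
}

// Grouped GEMM: one flat range of block ids covers every tile of every problem.
void grouped_gemm(const std::vector<GemmProblem>& problems)
{
    // tile_offset[g] = number of tiles in all groups before g (exclusive prefix sum).
    std::vector<int> tile_offset(problems.size() + 1, 0);
    for(std::size_t g = 0; g < problems.size(); ++g)
        tile_offset[g + 1] = tile_offset[g] + num_tiles(problems[g]);

    // On a GPU, each iteration would be a separate workgroup of a single launch.
    for(int block_id = 0; block_id < tile_offset.back(); ++block_id)
    {
        // Find g such that tile_offset[g] <= block_id < tile_offset[g + 1].
        const auto it = std::upper_bound(tile_offset.begin(), tile_offset.end(), block_id);
        const std::size_t g = static_cast<std::size_t>(it - tile_offset.begin()) - 1;
        compute_tile(problems[g], block_id - tile_offset[g]);
    }
}
```

Using `upper_bound` on the prefix sums means a group with zero tiles is skipped automatically, since the lookup lands on the next non-empty group.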

## Test Plan

Add the following tests:
 - ck_tile/grouped_gemm_mx/test_grouped_gemm_mx_flatmm_non_tdm.cpp
 - ck_tile/grouped_gemm_mx/test_grouped_gemm_mx_flatmm_tdm.cpp
 - ck_tile/flatmm/test_mx_flatmm_persistent.cpp

Verify on the gfx950 and gfx1250 architectures.

## Test Result

All tests pass. Verified on A0 hardware with rocm-7.14.0a20260517.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-06-15 16:12:33 +00:00

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
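#pragma once

// Umbrella header for the ck_tile FLATMM operators: includes the block,
// kernel (including the grouped and MX variants), and pipeline headers,
// plus the shared ck_tile/ops/common utilities.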
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/grouped_mx_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp"
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"