composable_kernel/test/ck_tile/flatmm/mx_flatmm_instance.cpp.in
Andriy Roshchenko d5c9215064 [rocm-libraries] ROCm/rocm-libraries#7359 (commit dd62f9f)
[CK_TILE][GFX1250] Enable MX GEMM FLATMM with ASYNC

## Motivation

Enables the MX GEMM FLATMM pipeline on gfx1250. This pipeline uses an
asynchronous load instruction for tensor A, complementing the existing
MX GEMM FLATMM pipeline that uses TDM loads. At this time, only the
FLATMM MX pipelines are enabled on gfx1250.

## Technical Details

The existing gfx950 implementation was extended to support the gfx1250
architecture. All three MX FP data types (FP8, FP6, FP4) are supported on
both ASICs.
Note that while the TDM pipeline uses an emulated 32x32x128 warp-tile
instruction, this submission relies on the built-in 16x16x128 instruction,
issued 4 times per warp.

## Test Plan

Existing `test/ck_tile/flatmm` tests were extended to cover the new gfx1250
functionality.

To facilitate testing during development, the
`example/ck_tile/18_flatmm/script/smoke_test_mx.sh` script was
introduced to verify various combinations of supported data types and
pipeline versions.

## Test Result

This submission is expected to work on both gfx950 and gfx1250
hardware for all reasonable problem sizes and all MX FP8/FP6/FP4 data types.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
- [x] Relies on #6978 and should only be merged after those changes are
merged into `develop`.
2026-05-29 17:02:45 +00:00


// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "mx_flatmm_instance.hpp"
// clang-format off
#define MXFLATMM_ARCH_TRAITS @MXFLATMM_ARCH_TRAITS@
#define A_DATA_TYPE @A_DATA_TYPE@
#define B_DATA_TYPE @B_DATA_TYPE@
#define C_DATA_TYPE @C_DATA_TYPE@
#define A_LAYOUT @A_LAYOUT@
#define B_LAYOUT @B_LAYOUT@
#define C_LAYOUT @C_LAYOUT@
#define PERSISTENT @PERSISTENT@
#define SPLIT_K @SPLIT_K@
#define HAS_HOT_LOOP @HAS_HOT_LOOP@
#define TAIL_NUMBER @TAIL_NUMBER@
// clang-format on
using FP4 = ck_tile::pk_fp4_t;
using FP8 = ck_tile::fp8_t;
using FP6 = ck_tile::pk_fp6x16_t;
using FP16 = ck_tile::fp16_t;
using BF16 = ck_tile::bf16_t;
using ROW = ck_tile::tensor_layout::gemm::RowMajor;
using COL = ck_tile::tensor_layout::gemm::ColumnMajor;
using ScaleType = ck_tile::e8m0_t;
inline constexpr auto ODD = ck_tile::TailNumber::Odd;
inline constexpr auto EVEN = ck_tile::TailNumber::Even;
inline constexpr int ScaleGranularityM = 1;
inline constexpr int ScaleGranularityN = 1;
inline constexpr int ScaleGranularityK = 32;
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>;
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>;
template float mx_flatmm_calc<MXFLATMM_ARCH_TRAITS,
A_DATA_TYPE,
B_DATA_TYPE,
/*DsDatatype*/ ck_tile::tuple<>,
/*AccDataType*/ float,
C_DATA_TYPE,
A_LAYOUT,
B_LAYOUT,
/*DsLayout*/ ck_tile::tuple<>,
C_LAYOUT,
ScaleM,
ScaleN,
PERSISTENT,
/*CDEElementWise*/ ck_tile::element_wise::PassThrough,
SPLIT_K,
HAS_HOT_LOOP,
TAIL_NUMBER>(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);