mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 00:39:02 +00:00
[CK_TILE][GFX1250] Enable MX GEMM FLATMM with ASYNC ## Motivation Enables the MX GEMM FLATMM pipeline on gfx1250. The pipeline loads tensor A with an async load instruction, complementing the existing MX GEMM FLATMM pipeline that uses TDM loads. At this time, only the FLATMM MX pipelines are enabled on gfx1250. ## Technical Details The existing gfx950 implementation was extended to support the gfx1250 architecture. All three MX FP data types (FP8, FP6, FP4) are supported on both ASICs. Note that while the TDM pipeline uses an emulated 32x32x128 warp-tile instruction, this submission relies on the built-in 16x16x128 instruction, issued 4 times per warp. ## Test Plan The existing `test/ck_tile/flatmm` tests were extended to cover the new gfx1250 functionality. To facilitate testing during development, the `example/ck_tile/18_flatmm/script/smoke_test_mx.sh` script was added to verify various combinations of supported data types and pipeline versions. ## Test Result This submission is expected to work on both gfx950 and gfx1250 hardware for all reasonable sizes and all MX FP8/FP6/FP4 data types. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. - [x] Relies on #6978 and should be merged only after those changes are merged into `develop`.
58 lines
2.2 KiB
C++
58 lines
2.2 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include "mx_flatmm_instance.hpp"
|
|
|
|
// clang-format off
// Per-instance configuration of this translation unit.
// Every `@NAME@` token is a build-time placeholder in the `@VAR@` form used by
// CMake's configure_file(); presumably the build generates one copy of this
// file per configuration and substitutes the values (TODO confirm against the
// generating CMakeLists). Keep the placeholder lines verbatim so substitution
// keeps matching.
// In this file the macros are used only by the explicit instantiation at the
// end. Macros expand at their point of use, so the names they expand to (for
// example the FP8/ROW/ODD aliases declared further down) only need to exist
// by then.

// First template argument of mx_flatmm_calc; presumably selects the
// architecture-specific traits (e.g. gfx950 vs gfx1250).
#define MXFLATMM_ARCH_TRAITS @MXFLATMM_ARCH_TRAITS@

// Element types of the A, B and C tensors (presumably one of the FP4/FP8/FP6/
// FP16/BF16 aliases declared further down).
#define A_DATA_TYPE @A_DATA_TYPE@
#define B_DATA_TYPE @B_DATA_TYPE@
#define C_DATA_TYPE @C_DATA_TYPE@

// Memory layouts of the A, B and C tensors (presumably ROW or COL).
#define A_LAYOUT @A_LAYOUT@
#define B_LAYOUT @B_LAYOUT@
#define C_LAYOUT @C_LAYOUT@

// Compile-time kernel switches forwarded as template arguments; presumably
// boolean (true/false) values.
#define PERSISTENT @PERSISTENT@
#define SPLIT_K @SPLIT_K@
#define HAS_HOT_LOOP @HAS_HOT_LOOP@

// Tail variant of the main loop; presumably expands to ODD or EVEN.
#define TAIL_NUMBER @TAIL_NUMBER@
// clang-format on
|
|
|
|
// Short names that the configured placeholders (A_LAYOUT, A_DATA_TYPE,
// TAIL_NUMBER, ...) can expand to.

// --- Tensor layouts ---------------------------------------------------------
using ROW = ck_tile::tensor_layout::gemm::RowMajor;
using COL = ck_tile::tensor_layout::gemm::ColumnMajor;

// --- Element types ----------------------------------------------------------
using FP8 = ck_tile::fp8_t;
using FP6 = ck_tile::pk_fp6x16_t;
using FP4 = ck_tile::pk_fp4_t;
using BF16 = ck_tile::bf16_t;
using FP16 = ck_tile::fp16_t;

// --- Main-loop tail variants ------------------------------------------------
inline constexpr ck_tile::TailNumber ODD = ck_tile::TailNumber::Odd;
inline constexpr ck_tile::TailNumber EVEN = ck_tile::TailNumber::Even;

// --- Block scales -----------------------------------------------------------
// Both operands use e8m0 scales. K is grouped in blocks of 32, and M/N use a
// granularity of 1. Presumably this means one scale per row (A) or column (B)
// for every 32 K elements; confirm against FlatmmScalePointer.
using ScaleType = ck_tile::e8m0_t;

inline constexpr int ScaleGranularityK = 32;
inline constexpr int ScaleGranularityM = 1;
inline constexpr int ScaleGranularityN = 1;

using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>;
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>;
|
|
|
|
// Explicit instantiation definition of mx_flatmm_calc for the one
// configuration selected by the macros at the top of this file. It forces this
// configuration's code to be emitted in this translation unit.
// The template arguments are positional and must match, in order, the
// declaration in mx_flatmm_instance.hpp.
// Fixed choices here:
//   - float accumulation;
//   - no D tensors (empty tuples for both their types and layouts);
//   - a PassThrough element-wise epilogue.
// The float return value is presumably the timing reported under the given
// stream_config (TODO confirm in mx_flatmm_instance.hpp).
template float mx_flatmm_calc<MXFLATMM_ARCH_TRAITS,
                              A_DATA_TYPE,
                              B_DATA_TYPE,
                              /*DsDatatype*/ ck_tile::tuple<>,
                              /*AccDataType*/ float,
                              C_DATA_TYPE,
                              A_LAYOUT,
                              B_LAYOUT,
                              /*DsLayout*/ ck_tile::tuple<>,
                              C_LAYOUT,
                              ScaleM,
                              ScaleN,
                              PERSISTENT,
                              /*CDEElementWise*/ ck_tile::element_wise::PassThrough,
                              SPLIT_K,
                              HAS_HOT_LOOP,
                              TAIL_NUMBER>(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
                                           const ck_tile::stream_config& s);
|