[rocm-libraries] ROCm/rocm-libraries#5095 (commit 7e55766)

[CK_TILE] Enable MXFP6 for MX GEMM op

## Motivation

Add support for MXFP6 in the MX GEMM op in CK-Tile.

Depends on https://github.com/ROCm/rocm-libraries/pull/4594

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Sami Remes
2026-03-20 01:08:52 +00:00
committed by assistant-librarian[bot]
parent a5d0200ccf
commit d7c761e060
13 changed files with 160 additions and 31 deletions

View File

@@ -9,7 +9,8 @@ endif()
if(GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_mx_gemm_fp4 test_mx_gemm_fp4.cpp)
target_compile_options(test_ck_tile_mx_gemm_fp4 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_mx_gemm_fp6 test_mx_gemm_fp6.cpp)
target_compile_options(test_ck_tile_mx_gemm_fp6 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_mx_gemm_fp8 test_mx_gemm_fp8.cpp)
target_compile_options(test_ck_tile_mx_gemm_fp8 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
else()

View File

@@ -87,6 +87,13 @@ struct MXfp4_GemmConfig16 : MxGemmConfig
static constexpr ck_tile::index_t K_Tile = 256;
};
struct MXfp6_GemmConfig16 : MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
};
struct MXfp8_GemmConfig16 : MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 64;

View File

@@ -0,0 +1,30 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_mx_gemm_config.hpp"
#include "test_mx_gemm_util.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using MxFp6Types = ::testing::Types<
std::tuple<ck_tile::pk_fp6x16_t, ck_tile::pk_fp6x16_t, MXfp6_GemmConfig16, Row, Col, Row>>;
template <typename TypeParam>
class TestMxGemmFp6 : public TestMxGemmUtil<std::tuple_element_t<0, TypeParam>,
std::tuple_element_t<1, TypeParam>,
std::tuple_element_t<2, TypeParam>,
std::tuple_element_t<3, TypeParam>,
std::tuple_element_t<4, TypeParam>,
std::tuple_element_t<5, TypeParam>>
{
};
TYPED_TEST_SUITE(TestMxGemmFp6, MxFp6Types);
TYPED_TEST(TestMxGemmFp6, BasicSizes)
{
this->Run(64, 64, 256);
this->Run(128, 128, 256);
this->Run(64, 128, 512);
}

View File

@@ -4,7 +4,6 @@
#pragma once
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/check_err.hpp"