mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[rocm-libraries] ROCm/rocm-libraries#5095 (commit 7e55766)
[CK_TILE] Enable MXFP6 for MX GEMM op ## Motivation Add support for MXFP6 in the MX GEMM op in CK-Tile. Depends on https://github.com/ROCm/rocm-libraries/pull/4594 ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
a5d0200ccf
commit
d7c761e060
@@ -9,7 +9,8 @@ endif()
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_mx_gemm_fp4 test_mx_gemm_fp4.cpp)
|
||||
target_compile_options(test_ck_tile_mx_gemm_fp4 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_mx_gemm_fp6 test_mx_gemm_fp6.cpp)
|
||||
target_compile_options(test_ck_tile_mx_gemm_fp6 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
|
||||
add_gtest_executable(test_ck_tile_mx_gemm_fp8 test_mx_gemm_fp8.cpp)
|
||||
target_compile_options(test_ck_tile_mx_gemm_fp8 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
|
||||
@@ -87,6 +87,13 @@ struct MXfp4_GemmConfig16 : MxGemmConfig
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
struct MXfp6_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
struct MXfp8_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
|
||||
30
test/ck_tile/gemm_mx/test_mx_gemm_fp6.cpp
Normal file
30
test/ck_tile/gemm_mx/test_mx_gemm_fp6.cpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_mx_gemm_config.hpp"
|
||||
#include "test_mx_gemm_util.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using MxFp6Types = ::testing::Types<
|
||||
std::tuple<ck_tile::pk_fp6x16_t, ck_tile::pk_fp6x16_t, MXfp6_GemmConfig16, Row, Col, Row>>;
|
||||
|
||||
template <typename TypeParam>
|
||||
class TestMxGemmFp6 : public TestMxGemmUtil<std::tuple_element_t<0, TypeParam>,
|
||||
std::tuple_element_t<1, TypeParam>,
|
||||
std::tuple_element_t<2, TypeParam>,
|
||||
std::tuple_element_t<3, TypeParam>,
|
||||
std::tuple_element_t<4, TypeParam>,
|
||||
std::tuple_element_t<5, TypeParam>>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMxGemmFp6, MxFp6Types);
|
||||
|
||||
TYPED_TEST(TestMxGemmFp6, BasicSizes)
|
||||
{
|
||||
this->Run(64, 64, 256);
|
||||
this->Run(128, 128, 256);
|
||||
this->Run(64, 128, 512);
|
||||
}
|
||||
@@ -4,7 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/check_err.hpp"
|
||||
|
||||
Reference in New Issue
Block a user