Implement grouped gemm tile loop for RDNA4 (#3304)

* feat: grouped gemm tile loop support for RDNA4

* fix: removed extra parameter from grouped gemm example instance

* fix: FP8 check incorrectly enabling FP8 on RDNA3

[ROCm/composable_kernel commit: eb041079a3]
This commit is contained in:
Erwin Terpstra
2026-01-13 07:14:23 +01:00
committed by GitHub
parent 0d13ef7329
commit d69aeffd0d
44 changed files with 3067 additions and 1223 deletions

View File

@@ -1,11 +1,12 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
#include <tuple>
#include <vector>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
#include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp"
@@ -31,7 +32,7 @@ class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
#if defined(CK_USE_WMMA)
// The old XDL tests didn't fail if instances were not supported, so we want to keep that
// behaviour When compiling WMMA instances and WMMA is supported, then we'll fail if a
// behaviour. When compiling WMMA instances and WMMA is supported, then we'll fail if a
// specific case is not supported
this->fail_if_no_supported_instances_ =
ck::is_gfx11_supported() || ck::is_gfx12_supported();
@@ -44,28 +45,31 @@ using KernelTypes = ::testing::Types<
#if defined(CK_USE_WMMA)
// WWMA only. No reason to not have it for XDL, but the instance was not defined and it was not in the original test.
std::tuple< Col, Col, Row, BF16, BF16, BF16>,
ck::Tuple< Col, Col, Row, BF16, BF16, BF16>,
#endif
#if defined(CK_USE_XDL) && defined(__gfx9__)
#if defined(CK_USE_XDL) && !defined(CK_USE_WMMA)
// XDL only at the moment, instances for WMMA not defined
std::tuple< Row, Row, Row, BF16, I8, BF16>,
std::tuple< Row, Col, Row, BF16, I8, BF16>,
// (And XDL instances don't run on gfx11/12, so we conditionally keep them out)
ck::Tuple< Row, Row, Row, BF16, I8, BF16>,
ck::Tuple< Row, Col, Row, BF16, I8, BF16>,
#endif
#if (defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || (defined(CK_USE_WMMA) && defined(__gfx12__))
std::tuple< Row, Row, Row, F8, F16, F16>,
std::tuple< Row, Row, Row, F16, F8, F16>,
#if CK_USE_OCP_FP8 || CK_USE_FNUZ_FP8 || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_WMMA_FP8)
// FP8 instances. Unfortunately CK_ENABLE_FP8 is always defined when not explicitly disabled, even if FP8 is
// not supported for any included architecture.
ck::Tuple< Row, Row, Row, F8, F16, F16>,
ck::Tuple< Row, Row, Row, F16, F8, F16>,
#endif
std::tuple< Row, Row, Row, F16, F16, F16>,
std::tuple< Row, Col, Row, F16, F16, F16>,
std::tuple< Col, Row, Row, F16, F16, F16>,
std::tuple< Col, Col, Row, F16, F16, F16>,
ck::Tuple< Row, Row, Row, F16, F16, F16>,
ck::Tuple< Row, Col, Row, F16, F16, F16>,
ck::Tuple< Col, Row, Row, F16, F16, F16>,
ck::Tuple< Col, Col, Row, F16, F16, F16>,
std::tuple< Row, Row, Row, BF16, BF16, BF16>,
std::tuple< Row, Col, Row, BF16, BF16, BF16>,
std::tuple< Col, Row, Row, BF16, BF16, BF16>
ck::Tuple< Row, Row, Row, BF16, BF16, BF16>,
ck::Tuple< Row, Col, Row, BF16, BF16, BF16>,
ck::Tuple< Col, Row, Row, BF16, BF16, BF16>
>;
// clang-format on