mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
Implement grouped gemm tile loop for RDNA4 (#3304)
* feat: grouped gemm tile loop support for RDNA4
* fix: removed extra parameter from grouped gemm example instance
* fix: FP8 check incorrectly enabling FP8 on RDNA3
[ROCm/composable_kernel commit: eb041079a3]
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
@@ -31,7 +32,7 @@ class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// The old XDL tests didn't fail if instances were not supported, so we want to keep that
|
||||
// behaviour When compiling WMMA instances and WMMA is supported, then we'll fail if a
|
||||
// behaviour. When compiling WMMA instances and WMMA is supported, then we'll fail if a
|
||||
// specific case is not supported
|
||||
this->fail_if_no_supported_instances_ =
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported();
|
||||
@@ -44,28 +45,31 @@ using KernelTypes = ::testing::Types<
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// WWMA only. No reason to not have it for XDL, but the instance was not defined and it was not in the original test.
|
||||
std::tuple< Col, Col, Row, BF16, BF16, BF16>,
|
||||
ck::Tuple< Col, Col, Row, BF16, BF16, BF16>,
|
||||
#endif
|
||||
|
||||
#if defined(CK_USE_XDL) && defined(__gfx9__)
|
||||
#if defined(CK_USE_XDL) && !defined(CK_USE_WMMA)
|
||||
// XDL only at the moment, instances for WMMA not defined
|
||||
std::tuple< Row, Row, Row, BF16, I8, BF16>,
|
||||
std::tuple< Row, Col, Row, BF16, I8, BF16>,
|
||||
// (And XDL instances don't run on gfx11/12, so we conditionally keep them out)
|
||||
ck::Tuple< Row, Row, Row, BF16, I8, BF16>,
|
||||
ck::Tuple< Row, Col, Row, BF16, I8, BF16>,
|
||||
#endif
|
||||
|
||||
#if (defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || (defined(CK_USE_WMMA) && defined(__gfx12__))
|
||||
std::tuple< Row, Row, Row, F8, F16, F16>,
|
||||
std::tuple< Row, Row, Row, F16, F8, F16>,
|
||||
#if CK_USE_OCP_FP8 || CK_USE_FNUZ_FP8 || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_WMMA_FP8)
|
||||
// FP8 instances. Unfortunately CK_ENABLE_FP8 is always defined when not explicitly disabled, even if FP8 is
|
||||
// not supported for any included architecture.
|
||||
ck::Tuple< Row, Row, Row, F8, F16, F16>,
|
||||
ck::Tuple< Row, Row, Row, F16, F8, F16>,
|
||||
#endif
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F16>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F16>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F16>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F16>,
|
||||
ck::Tuple< Row, Row, Row, F16, F16, F16>,
|
||||
ck::Tuple< Row, Col, Row, F16, F16, F16>,
|
||||
ck::Tuple< Col, Row, Row, F16, F16, F16>,
|
||||
ck::Tuple< Col, Col, Row, F16, F16, F16>,
|
||||
|
||||
std::tuple< Row, Row, Row, BF16, BF16, BF16>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, BF16>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, BF16>
|
||||
ck::Tuple< Row, Row, Row, BF16, BF16, BF16>,
|
||||
ck::Tuple< Row, Col, Row, BF16, BF16, BF16>,
|
||||
ck::Tuple< Col, Row, Row, BF16, BF16, BF16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user