[CK-TILE] Default epilogue, adding support for D (#2629)

* Extend 2d-epilogue, D support

* Added tests & update

* Remove unused attribute

* Extend tests

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Mateusz Ozga
2025-08-26 04:29:35 +02:00
committed by GitHub
parent 99d27aca17
commit d43228fbca
10 changed files with 624 additions and 428 deletions

View File

@@ -5,6 +5,8 @@ if(CK_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_gemm_multi_d test_gemm_multi_d.cpp)
target_compile_definitions(test_ck_tile_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_gemm_multi_d_cshuffle test_gemm_multi_d_cshuffle.cpp)
add_gtest_executable(test_gemm_multi_d_default2d test_gemm_multi_d_default2d.cpp)
target_compile_definitions(test_gemm_multi_d_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_definitions(test_gemm_multi_d_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -18,22 +18,23 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, CDataType, CDElementWiseFn
std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F16, ElementWiseAddAdd>,
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>,
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>,
std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd>,
std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F16, ElementWiseAddAdd>,
// Has cshuffle epilogue enabled
// ALayout, BLayout, D0Layout, D1Layout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, CDElementWiseFn, UseCshuffleEpilog
std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F16, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F16, ElementWiseAddAdd, std::true_type>,
std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply>,
std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F32, MultiplyMultiply>,
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply>,
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply>,
std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply>,
std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F32, MultiplyMultiply>
std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply, std::true_type>,
std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F32, MultiplyMultiply, std::true_type>,
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply, std::true_type>,
std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply, std::true_type>,
std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply, std::true_type>,
std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F32, MultiplyMultiply, std::true_type>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGemmMultiD, KernelTypes);
#include "test_gemm_multi_d_ut_cases.inc"
#include "test_gemm_multi_d_ut_cases_cshuffle.inc"

View File

@@ -0,0 +1,43 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_gemm_multi_d_util.hpp"
// Short aliases for the element types used in the typed-test tuples in this file.
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using F8 = ck_tile::fp8_t;
// Short aliases for the row-major / column-major GEMM tensor layouts.
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
// Type configurations for the typed test suite. Every tuple ends in std::false_type,
// i.e. the CShuffle epilogue is disabled for all of them.
// The legend below lists the tuple slots in positional order.
// NOTE(review): the ElementWiseAddAdd list previously contained the F16/F16/F32/F32/F32/F16
// row twice; the exact duplicate was removed. The MultiplyMultiply list starts with an
// F16/F16/F16/F16/F32/F16 row instead -- confirm whether that was the intended AddAdd row too.
using KernelTypes = ::testing::Types<
    // Has cshuffle epilogue disabled
    // ALayout, BLayout, D0Layout, D1Layout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, CDElementWiseFn, UseCshuffleEpilog
    std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd, std::false_type>,
    std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd, std::false_type>,
    std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, ElementWiseAddAdd, std::false_type>,
    std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, BF16, ElementWiseAddAdd, std::false_type>,
    std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, BF16, ElementWiseAddAdd, std::false_type>,
    std::tuple< Row, Col, Row, Row, Row, F8, F8, F16, F16, F32, F16, ElementWiseAddAdd, std::false_type>,
    std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply, std::false_type>,
    std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply, std::false_type>,
    std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply, std::false_type>,
    std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply, std::false_type>,
    std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, BF16, MultiplyMultiply, std::false_type>,
    std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, BF16, MultiplyMultiply, std::false_type>,
    std::tuple< Row, Col, Row, Row, Row, F8, F8, F16, F16, F32, F16, MultiplyMultiply, std::false_type>
    >;
// clang-format on
// Register the TestCkTileGemmMultiD fixture against every tuple in KernelTypes.
TYPED_TEST_SUITE(TestCkTileGemmMultiD, KernelTypes);
// Included after the suite registration; presumably this .inc holds the TYPED_TEST
// bodies for the default-2D-epilogue variant -- confirm against the .inc file.
#include "test_gemm_multi_d_ut_cases_default2d.inc"

View File

@@ -1,334 +0,0 @@
#pragma once
// "AddKBatch1" family: one typed test per M x N x K problem size, all with kBatch = 1.
// Each case only forwards its sizes to the fixture's Run(); the test body itself does
// not select an element-wise operation.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x512x256)
{
    this->Run(/*M=*/256, /*N=*/512, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x256x256)
{
    this->Run(/*M=*/512, /*N=*/256, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x512x256)
{
    this->Run(/*M=*/512, /*N=*/512, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x256x256)
{
    this->Run(/*M=*/256, /*N=*/256, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x768x256)
{
    this->Run(/*M=*/512, /*N=*/768, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x1280x256)
{
    this->Run(/*M=*/512, /*N=*/1280, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x1280x256)
{
    this->Run(/*M=*/256, /*N=*/1280, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_768x512x256)
{
    this->Run(/*M=*/768, /*N=*/512, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x512x256)
{
    this->Run(/*M=*/1280, /*N=*/512, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x256x256)
{
    this->Run(/*M=*/1280, /*N=*/256, /*K=*/256, /*kBatch=*/1);
}
// "MultiplyMultiplyKBatch1" family: one typed test per M x N x K problem size, all with
// kBatch = 1. Each case only forwards its sizes to the fixture's Run(); the test body
// itself does not select an element-wise operation.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x512x256)
{
    this->Run(/*M=*/256, /*N=*/512, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x256x256)
{
    this->Run(/*M=*/512, /*N=*/256, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x512x256)
{
    this->Run(/*M=*/512, /*N=*/512, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x256x256)
{
    this->Run(/*M=*/256, /*N=*/256, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x768x256)
{
    this->Run(/*M=*/512, /*N=*/768, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x1280x256)
{
    this->Run(/*M=*/512, /*N=*/1280, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x1280x256)
{
    this->Run(/*M=*/256, /*N=*/1280, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_768x512x256)
{
    this->Run(/*M=*/768, /*N=*/512, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x512x256)
{
    this->Run(/*M=*/1280, /*N=*/512, /*K=*/256, /*kBatch=*/1);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x256x256)
{
    this->Run(/*M=*/1280, /*N=*/256, /*K=*/256, /*kBatch=*/1);
}
// "AddKBatch2" family: one typed test per M x N x K problem size, all with kBatch = 2
// and K = 512. Each case only forwards its sizes to the fixture's Run().
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x256x512)
{
    this->Run(/*M=*/256, /*N=*/256, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x768x512)
{
    this->Run(/*M=*/512, /*N=*/768, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x1280x512)
{
    this->Run(/*M=*/512, /*N=*/1280, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x1280x512)
{
    this->Run(/*M=*/256, /*N=*/1280, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_768x512x512)
{
    this->Run(/*M=*/768, /*N=*/512, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x512x512)
{
    this->Run(/*M=*/1280, /*N=*/512, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x256x512)
{
    this->Run(/*M=*/1280, /*N=*/256, /*K=*/512, /*kBatch=*/2);
}
// "MultiplyMultiplyKBatch2" family: one typed test per M x N x K problem size, all with
// kBatch = 2 and K = 512. Each case only forwards its sizes to the fixture's Run().
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x512x512)
{
    this->Run(/*M=*/256, /*N=*/512, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x256x512)
{
    this->Run(/*M=*/512, /*N=*/256, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x512x512)
{
    this->Run(/*M=*/512, /*N=*/512, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x256x512)
{
    this->Run(/*M=*/256, /*N=*/256, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x768x512)
{
    this->Run(/*M=*/512, /*N=*/768, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x1280x512)
{
    this->Run(/*M=*/512, /*N=*/1280, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x1280x512)
{
    this->Run(/*M=*/256, /*N=*/1280, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_768x512x512)
{
    this->Run(/*M=*/768, /*N=*/512, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x512x512)
{
    this->Run(/*M=*/1280, /*N=*/512, /*K=*/512, /*kBatch=*/2);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x256x512)
{
    this->Run(/*M=*/1280, /*N=*/256, /*K=*/512, /*kBatch=*/2);
}

View File

@@ -0,0 +1,211 @@
#pragma once
// "KBatch1CShuffle" family: one typed test per M x N x K problem size with kBatch = 1.
// Every case expects the fixture's Run() to return true.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x512x256)
{
    constexpr int kBatch{1};
    constexpr int M{256};
    constexpr int N{512};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x256x256)
{
    constexpr int kBatch{1};
    constexpr int M{512};
    constexpr int N{256};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x512x256)
{
    constexpr int kBatch{1};
    constexpr int M{512};
    constexpr int N{512};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x256x256)
{
    constexpr int kBatch{1};
    constexpr int M{256};
    constexpr int N{256};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x768x256)
{
    constexpr int kBatch{1};
    constexpr int M{512};
    constexpr int N{768};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_512x1280x256)
{
    constexpr int kBatch{1};
    constexpr int M{512};
    constexpr int N{1280};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_256x1280x256)
{
    constexpr int kBatch{1};
    constexpr int M{256};
    constexpr int N{1280};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_768x512x256)
{
    constexpr int kBatch{1};
    constexpr int M{768};
    constexpr int N{512};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_1280x512x256)
{
    constexpr int kBatch{1};
    constexpr int M{1280};
    constexpr int N{512};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1CShuffle_1280x256x256)
{
    constexpr int kBatch{1};
    constexpr int M{1280};
    constexpr int N{256};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
// "KBatch2CShuffle" family: one typed test per M x N x K problem size with kBatch = 2
// and K = 512. Every case expects the fixture's Run() to throw std::runtime_error.
// The numeric suffix of each test name is the M x N x K actually passed to Run().
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x512x512)
{
    constexpr int kBatch{2};
    constexpr int M{512};
    constexpr int N{512};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x512x512)
{
    constexpr int kBatch{2};
    constexpr int M{256};
    constexpr int N{512};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x256x512)
{
    constexpr int kBatch{2};
    constexpr int M{512};
    constexpr int N{256};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x256x512)
{
    constexpr int kBatch{2};
    constexpr int M{256};
    constexpr int N{256};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x768x512)
{
    constexpr int kBatch{2};
    constexpr int M{512};
    constexpr int N{768};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_512x1280x512)
{
    constexpr int kBatch{2};
    constexpr int M{512};
    constexpr int N{1280};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_256x1280x512)
{
    constexpr int kBatch{2};
    constexpr int M{256};
    constexpr int N{1280};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_768x512x512)
{
    constexpr int kBatch{2};
    constexpr int M{768};
    constexpr int N{512};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_1280x512x512)
{
    constexpr int kBatch{2};
    constexpr int M{1280};
    constexpr int N{512};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2CShuffle_1280x256x512)
{
    constexpr int kBatch{2};
    constexpr int M{1280};
    constexpr int N{256};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}

View File

@@ -0,0 +1,211 @@
#pragma once
// "KBatch1Default" family: one typed test per M x N x K problem size with kBatch = 1.
// Every case expects the fixture's Run() to return true.
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x512x256)
{
    constexpr int kBatch{1};
    constexpr int M{256};
    constexpr int N{512};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x256x256)
{
    constexpr int kBatch{1};
    constexpr int M{512};
    constexpr int N{256};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x512x256)
{
    constexpr int kBatch{1};
    constexpr int M{512};
    constexpr int N{512};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x256x256)
{
    constexpr int kBatch{1};
    constexpr int M{256};
    constexpr int N{256};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x768x256)
{
    constexpr int kBatch{1};
    constexpr int M{512};
    constexpr int N{768};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_512x1280x256)
{
    constexpr int kBatch{1};
    constexpr int M{512};
    constexpr int N{1280};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_256x1280x256)
{
    constexpr int kBatch{1};
    constexpr int M{256};
    constexpr int N{1280};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_768x512x256)
{
    constexpr int kBatch{1};
    constexpr int M{768};
    constexpr int N{512};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_1280x512x256)
{
    constexpr int kBatch{1};
    constexpr int M{1280};
    constexpr int N{512};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch1Default_1280x256x256)
{
    constexpr int kBatch{1};
    constexpr int M{1280};
    constexpr int N{256};
    constexpr int K{256};
    EXPECT_EQ(this->Run(M, N, K, kBatch), true);
}
// "KBatch2Default" family: one typed test per M x N x K problem size with kBatch = 2
// and K = 512. Every case expects the fixture's Run() to throw std::runtime_error.
// The numeric suffix of each test name is the M x N x K actually passed to Run().
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x512x512)
{
    constexpr int kBatch{2};
    constexpr int M{512};
    constexpr int N{512};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x512x512)
{
    constexpr int kBatch{2};
    constexpr int M{256};
    constexpr int N{512};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x256x512)
{
    constexpr int kBatch{2};
    constexpr int M{512};
    constexpr int N{256};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x256x512)
{
    constexpr int kBatch{2};
    constexpr int M{256};
    constexpr int N{256};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x768x512)
{
    constexpr int kBatch{2};
    constexpr int M{512};
    constexpr int N{768};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_512x1280x512)
{
    constexpr int kBatch{2};
    constexpr int M{512};
    constexpr int N{1280};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_256x1280x512)
{
    constexpr int kBatch{2};
    constexpr int M{256};
    constexpr int N{1280};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_768x512x512)
{
    constexpr int kBatch{2};
    constexpr int M{768};
    constexpr int N{512};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_1280x512x512)
{
    constexpr int kBatch{2};
    constexpr int M{1280};
    constexpr int N{512};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}
TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDKBatch2Default_1280x256x512)
{
    constexpr int kBatch{2};
    constexpr int M{1280};
    constexpr int N{256};
    constexpr int K{512};
    EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error);
}

View File

@@ -70,20 +70,21 @@ template <typename Tuple>
class TestCkTileGemmMultiD : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using D0Layout = std::tuple_element_t<2, Tuple>;
using D1Layout = std::tuple_element_t<3, Tuple>;
using ELayout = std::tuple_element_t<4, Tuple>;
using ADataType = std::tuple_element_t<5, Tuple>;
using BDataType = std::tuple_element_t<6, Tuple>;
using D0DataType = std::tuple_element_t<7, Tuple>;
using D1DataType = std::tuple_element_t<8, Tuple>;
using AccDataType = std::tuple_element_t<9, Tuple>;
using EDataType = std::tuple_element_t<10, Tuple>;
using CDElementWiseFn = std::tuple_element_t<11, Tuple>;
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using D0Layout = std::tuple_element_t<2, Tuple>;
using D1Layout = std::tuple_element_t<3, Tuple>;
using ELayout = std::tuple_element_t<4, Tuple>;
using ADataType = std::tuple_element_t<5, Tuple>;
using BDataType = std::tuple_element_t<6, Tuple>;
using D0DataType = std::tuple_element_t<7, Tuple>;
using D1DataType = std::tuple_element_t<8, Tuple>;
using AccDataType = std::tuple_element_t<9, Tuple>;
using EDataType = std::tuple_element_t<10, Tuple>;
using CDElementWiseFn = std::tuple_element_t<11, Tuple>;
using UseCshuffleEpilog = std::tuple_element_t<12, Tuple>;
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
template <typename ADataType,
typename BDataType,
@@ -169,7 +170,28 @@ class TestCkTileGemmMultiD : public ::testing::Test
tail_number_v>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
kPadM,
kPadN,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
true,
memory_operation>>;
using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
@@ -188,6 +210,9 @@ class TestCkTileGemmMultiD : public ::testing::Test
UniversalGemmProblem::TransposeC,
memory_operation>>;
using GemmEpilogue = std::
conditional_t<UseCshuffleEpilog::value, CShuffleGemmEpilogue, DefaultGemmEpilogue>;
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -218,6 +243,7 @@ class TestCkTileGemmMultiD : public ::testing::Test
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
std::cout << "Run without SplitK" << std::endl;
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
@@ -225,42 +251,19 @@ class TestCkTileGemmMultiD : public ::testing::Test
}
else
{
std::cout << "Run using SplitK" << std::endl;
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
if(has_hot_loop)
{
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "For compute pipeline tail number should always be Full, but have \""
<< tail_num << "\" which is not supported! PrefetchStages: "
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
else
{
std::ostringstream err;
err << "Num K loop must be larger than number of prefetech stages."
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
public:
void Run(const int M,
bool Run(const int M,
const int N,
const int K,
const int k_batch,
@@ -401,6 +404,6 @@ class TestCkTileGemmMultiD : public ::testing::Test
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
EXPECT_TRUE(pass);
return pass;
}
};