mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
add gemm_api and instances
This commit is contained in:
@@ -1,2 +1,27 @@
|
||||
# add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
|
||||
# add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
|
||||
|
||||
function (add_gemm_example TARGET_NAME MAIN_SRC)
|
||||
message("adding ${TARGET_NAME}")
|
||||
# not using add_example_executable() to add target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
|
||||
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
foreach(source IN LISTS ARGN)
|
||||
list(APPEND INSTANCE_SRCS ${source})
|
||||
endforeach()
|
||||
|
||||
target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS})
|
||||
|
||||
set(COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
|
||||
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
|
||||
endfunction(add_gemm_example TARGET_NAME MAIN_SRC)
|
||||
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_gemm_example(tile_example_gemm_universal universal_gemm.cpp ${INSTANCE_SRCS})
|
||||
|
||||
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
|
||||
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
@@ -9,13 +9,10 @@
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_basic.hpp"
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
@@ -103,6 +100,30 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
|
||||
{
|
||||
return gemm_<Row, Row, Row>(args, s);
|
||||
}
|
||||
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
|
||||
{
|
||||
return gemm_<Row, Col, Row>(args, s);
|
||||
}
|
||||
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
|
||||
{
|
||||
return gemm_<Col, Row, Row>(args, s);
|
||||
}
|
||||
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
|
||||
{
|
||||
return gemm_<Col, Col, Row>(args, s);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Wrong! Layouts not supported!\n");
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
@@ -51,6 +52,59 @@ using BDataType = Types::BDataType;
|
||||
using AccDataType = Types::AccDataType;
|
||||
using CDataType = Types::CDataType;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
struct gemm_traits
|
||||
{
|
||||
std::string data_type;
|
||||
bool is_a_rowmajor;
|
||||
bool is_b_rowmajor;
|
||||
bool is_c_rowmajor;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CDataType_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
ck_tile::index_t M_Tile_,
|
||||
ck_tile::index_t N_Tile_,
|
||||
ck_tile::index_t K_Tile_,
|
||||
ck_tile::index_t M_Warp_,
|
||||
ck_tile::index_t N_Warp_,
|
||||
ck_tile::index_t K_Warp_,
|
||||
ck_tile::index_t M_Warp_Tile_,
|
||||
ck_tile::index_t N_Warp_Tile_,
|
||||
ck_tile::index_t K_Warp_Tile_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_>
|
||||
struct gemm_traits_
|
||||
{
|
||||
using ADataType = ck_tile::remove_cvref_t<ADataType_>;
|
||||
using BDataType = ck_tile::remove_cvref_t<BDataType_>;
|
||||
using AccDataType = ck_tile::remove_cvref_t<AccDataType_>;
|
||||
using CDataType = ck_tile::remove_cvref_t<CDataType_>;
|
||||
using ALayout = ck_tile::remove_cvref_t<ALayout_>;
|
||||
using BLayout = ck_tile::remove_cvref_t<BLayout_>;
|
||||
using CLayout = ck_tile::remove_cvref_t<CLayout_>;
|
||||
static constexpr ck_tile::index_t M_Tile = M_Tile_;
|
||||
static constexpr ck_tile::index_t N_Tile = N_Tile_;
|
||||
static constexpr ck_tile::index_t K_Tile = K_Tile_;
|
||||
static constexpr ck_tile::index_t M_Warp = M_Warp_;
|
||||
static constexpr ck_tile::index_t N_Warp = N_Warp_;
|
||||
static constexpr ck_tile::index_t K_Warp = K_Warp_;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = M_Warp_Tile_;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -75,4 +129,9 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// host API
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
template <typename Traits_>
|
||||
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
float gemm(const gemm_traits& traits,
|
||||
const ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s);
|
||||
|
||||
482
example/ck_tile/03_gemm/instances/gemm_api.cpp
Normal file
482
example/ck_tile/03_gemm/instances/gemm_api.cpp
Normal file
@@ -0,0 +1,482 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_basic.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using FP32 = float;
|
||||
using FP16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CDataType_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
ck_tile::index_t M_Tile_,
|
||||
ck_tile::index_t N_Tile_,
|
||||
ck_tile::index_t K_Tile_,
|
||||
ck_tile::index_t M_Warp_,
|
||||
ck_tile::index_t N_Warp_,
|
||||
ck_tile::index_t K_Warp_,
|
||||
ck_tile::index_t M_Warp_Tile_,
|
||||
ck_tile::index_t N_Warp_Tile_,
|
||||
ck_tile::index_t K_Warp_Tile_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_>
|
||||
using trait_ = gemm_traits_<ADataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CDataType_,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
CLayout_,
|
||||
M_Tile_,
|
||||
N_Tile_,
|
||||
K_Tile_,
|
||||
M_Warp_,
|
||||
N_Warp_,
|
||||
K_Warp_,
|
||||
M_Warp_Tile_,
|
||||
N_Warp_Tile_,
|
||||
K_Warp_Tile_,
|
||||
kPadM_,
|
||||
kPadN_,
|
||||
kPadK_>;
|
||||
|
||||
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::stream_config& s)
|
||||
{
|
||||
if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
|
||||
{
|
||||
if(a.M > 512)
|
||||
{
|
||||
// universal gemm compute bound RR
|
||||
std::cout << "fp16 comp\n";
|
||||
return gemm_<trait_<FP16,
|
||||
FP16,
|
||||
FP32,
|
||||
FP16,
|
||||
Row,
|
||||
Row,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
else
|
||||
{
|
||||
// universal gemm memory bound RR
|
||||
std::cout << "fp16 mem\n";
|
||||
return gemm_<trait_<FP16,
|
||||
FP16,
|
||||
FP32,
|
||||
FP16,
|
||||
Row,
|
||||
Row,
|
||||
Row,
|
||||
128,
|
||||
32,
|
||||
64,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
}
|
||||
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
|
||||
{
|
||||
if(a.M > 512)
|
||||
{
|
||||
// universal gemm compute bound RC
|
||||
std::cout << "fp16 comp RC\n";
|
||||
return gemm_<trait_<FP16,
|
||||
FP16,
|
||||
FP32,
|
||||
FP16,
|
||||
Row,
|
||||
Col,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
else
|
||||
{
|
||||
// universal gemm memory bound RC
|
||||
std::cout << "fp16 mem RC\n";
|
||||
return gemm_<trait_<FP16,
|
||||
FP16,
|
||||
FP32,
|
||||
FP16,
|
||||
Row,
|
||||
Col,
|
||||
Row,
|
||||
128,
|
||||
32,
|
||||
64,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
}
|
||||
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
|
||||
{
|
||||
if(a.M > 512)
|
||||
{
|
||||
// universal gemm compute bound CR
|
||||
std::cout << "fp16 comp CR\n";
|
||||
return gemm_<trait_<FP16,
|
||||
FP16,
|
||||
FP32,
|
||||
FP16,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
else
|
||||
{
|
||||
// universal gemm memory bound CR
|
||||
std::cout << "fp16 mem CR\n";
|
||||
return gemm_<trait_<FP16,
|
||||
FP16,
|
||||
FP32,
|
||||
FP16,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
}
|
||||
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
|
||||
{
|
||||
if(a.M > 512)
|
||||
{
|
||||
// universal gemm compute bound CC
|
||||
std::cout << "fp16 comp CC\n";
|
||||
return gemm_<trait_<FP16,
|
||||
FP16,
|
||||
FP32,
|
||||
FP16,
|
||||
Col,
|
||||
Col,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
else
|
||||
{
|
||||
// universal gemm memory bound CC
|
||||
std::cout << "fp16 mem CC\n";
|
||||
return gemm_<trait_<FP16,
|
||||
FP16,
|
||||
FP32,
|
||||
FP16,
|
||||
Col,
|
||||
Col,
|
||||
Row,
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n");
|
||||
}
|
||||
}
|
||||
else if(t.data_type.compare("bf16") == 0)
|
||||
{
|
||||
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
|
||||
{
|
||||
if(a.M > 512)
|
||||
{
|
||||
// universal gemm compute bound RR
|
||||
std::cout << "bf16 comp\n";
|
||||
return gemm_<trait_<BF16,
|
||||
BF16,
|
||||
FP32,
|
||||
BF16,
|
||||
Row,
|
||||
Row,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
else
|
||||
{
|
||||
// universal gemm memory bound RR
|
||||
std::cout << "bf16 mem\n";
|
||||
return gemm_<trait_<BF16,
|
||||
BF16,
|
||||
FP32,
|
||||
BF16,
|
||||
Row,
|
||||
Row,
|
||||
Row,
|
||||
128,
|
||||
32,
|
||||
64,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
}
|
||||
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
|
||||
{
|
||||
if(a.M > 512)
|
||||
{
|
||||
// universal gemm compute bound RC
|
||||
std::cout << "bf16 comp RC\n";
|
||||
return gemm_<trait_<BF16,
|
||||
BF16,
|
||||
FP32,
|
||||
BF16,
|
||||
Row,
|
||||
Col,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
else
|
||||
{
|
||||
// universal gemm memory bound RC
|
||||
std::cout << "bf16 mem RC\n";
|
||||
return gemm_<trait_<BF16,
|
||||
BF16,
|
||||
FP32,
|
||||
BF16,
|
||||
Row,
|
||||
Col,
|
||||
Row,
|
||||
128,
|
||||
32,
|
||||
64,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
}
|
||||
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
|
||||
{
|
||||
if(a.M > 512)
|
||||
{
|
||||
// universal gemm compute bound CR
|
||||
std::cout << "bf16 comp CR\n";
|
||||
return gemm_<trait_<BF16,
|
||||
BF16,
|
||||
FP32,
|
||||
BF16,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
else
|
||||
{
|
||||
// universal gemm memory bound CR
|
||||
std::cout << "bf16 mem CR\n";
|
||||
return gemm_<trait_<BF16,
|
||||
BF16,
|
||||
FP32,
|
||||
BF16,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
}
|
||||
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
|
||||
{
|
||||
if(a.M > 512)
|
||||
{
|
||||
// universal gemm compute bound CC
|
||||
std::cout << "bf16 comp CC\n";
|
||||
return gemm_<trait_<BF16,
|
||||
BF16,
|
||||
FP32,
|
||||
BF16,
|
||||
Col,
|
||||
Col,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
else
|
||||
{
|
||||
// universal gemm memory bound CC
|
||||
std::cout << "bf16 mem CC\n";
|
||||
return gemm_<trait_<BF16,
|
||||
BF16,
|
||||
FP32,
|
||||
BF16,
|
||||
Col,
|
||||
Col,
|
||||
Row,
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(a, s);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Wrong! DataTypes not supported!\n");
|
||||
}
|
||||
|
||||
return 1.0f;
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_comp_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
ck_tile::bf16_t,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_comp_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
ck_tile::bf16_t,
|
||||
Col,
|
||||
Col,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_comp_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
ck_tile::bf16_t,
|
||||
Row,
|
||||
Row,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_comp_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
ck_tile::bf16_t,
|
||||
Row,
|
||||
Col,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_comp_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_comp_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
Col,
|
||||
Col,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_comp_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
Row,
|
||||
Row,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_comp_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
Row,
|
||||
Col,
|
||||
Row,
|
||||
256,
|
||||
256,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,206 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <iostream>
|
||||
#include "gemm_basic.hpp"
|
||||
|
||||
using A = ck_tile::GemmHostArgs;
|
||||
using S = ck_tile::stream_config;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CDataType_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
ck_tile::index_t M_Tile_,
|
||||
ck_tile::index_t N_Tile_,
|
||||
ck_tile::index_t K_Tile_,
|
||||
ck_tile::index_t M_Warp_,
|
||||
ck_tile::index_t N_Warp_,
|
||||
ck_tile::index_t K_Warp_,
|
||||
ck_tile::index_t M_Warp_Tile_,
|
||||
ck_tile::index_t N_Warp_Tile_,
|
||||
ck_tile::index_t K_Warp_Tile_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_>
|
||||
using trait_ = gemm_traits_<ADataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CDataType_,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
CLayout_,
|
||||
M_Tile_,
|
||||
N_Tile_,
|
||||
K_Tile_,
|
||||
M_Warp_,
|
||||
N_Warp_,
|
||||
K_Warp_,
|
||||
M_Warp_Tile_,
|
||||
N_Warp_Tile_,
|
||||
K_Warp_Tile_,
|
||||
kPadM_,
|
||||
kPadN_,
|
||||
kPadK_>;
|
||||
|
||||
template <typename Traits_>
|
||||
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<Traits_::M_Tile, Traits_::N_Tile, Traits_::K_Tile>,
|
||||
ck_tile::sequence<Traits_::M_Warp, Traits_::N_Warp, Traits_::K_Warp>,
|
||||
ck_tile::sequence<Traits_::M_Warp_Tile, Traits_::N_Warp_Tile, Traits_::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
|
||||
|
||||
using GemmEpilogue =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename Traits_::AccDataType,
|
||||
typename Traits_::CDataType,
|
||||
Traits_::kPadM,
|
||||
Traits_::kPadN>>;
|
||||
using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM,
|
||||
Traits_::kPadN,
|
||||
Traits_::kPadK,
|
||||
typename Traits_::ALayout,
|
||||
typename Traits_::BLayout,
|
||||
typename Traits_::CLayout>;
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
|
||||
ck_tile::GemmPipelineProblem<typename Traits_::ADataType,
|
||||
typename Traits_::BDataType,
|
||||
typename Traits_::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits>>;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<
|
||||
ck_tile::UniversalGemmPipelineProblem<typename Traits_::ADataType,
|
||||
typename Traits_::BDataType,
|
||||
typename Traits_::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
// Tail pipeline One to Seven
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Two)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
|
||||
{
|
||||
static_assert(BaseGemmPipeline::PrefetchStages > 3);
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Four)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Five)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Six)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Seven)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always Full - #PrefetchStages
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "When there's no hot loop, this tail number \"" << tail_num
|
||||
<< "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_mem_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
ck_tile::bf16_t,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_mem_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
ck_tile::bf16_t,
|
||||
Col,
|
||||
Col,
|
||||
Row,
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_mem_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
ck_tile::bf16_t,
|
||||
Row,
|
||||
Row,
|
||||
Row,
|
||||
128,
|
||||
32,
|
||||
64,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_mem_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
ck_tile::bf16_t,
|
||||
Row,
|
||||
Col,
|
||||
Row,
|
||||
128,
|
||||
32,
|
||||
64,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_mem_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_mem_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
Col,
|
||||
Col,
|
||||
Row,
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_mem_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
Row,
|
||||
Row,
|
||||
Row,
|
||||
128,
|
||||
32,
|
||||
64,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_universal_mem_instance_common.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template float gemm_<trait_<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
Row,
|
||||
Col,
|
||||
Row,
|
||||
128,
|
||||
32,
|
||||
64,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
false>>(const A&, const S&);
|
||||
@@ -0,0 +1,206 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <iostream>
|
||||
#include "gemm_basic.hpp"
|
||||
|
||||
using A = ck_tile::GemmHostArgs;
|
||||
using S = ck_tile::stream_config;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CDataType_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
ck_tile::index_t M_Tile_,
|
||||
ck_tile::index_t N_Tile_,
|
||||
ck_tile::index_t K_Tile_,
|
||||
ck_tile::index_t M_Warp_,
|
||||
ck_tile::index_t N_Warp_,
|
||||
ck_tile::index_t K_Warp_,
|
||||
ck_tile::index_t M_Warp_Tile_,
|
||||
ck_tile::index_t N_Warp_Tile_,
|
||||
ck_tile::index_t K_Warp_Tile_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_>
|
||||
using trait_ = gemm_traits_<ADataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CDataType_,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
CLayout_,
|
||||
M_Tile_,
|
||||
N_Tile_,
|
||||
K_Tile_,
|
||||
M_Warp_,
|
||||
N_Warp_,
|
||||
K_Warp_,
|
||||
M_Warp_Tile_,
|
||||
N_Warp_Tile_,
|
||||
K_Warp_Tile_,
|
||||
kPadM_,
|
||||
kPadN_,
|
||||
kPadK_>;
|
||||
|
||||
template <typename Traits_>
|
||||
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<Traits_::M_Tile, Traits_::N_Tile, Traits_::K_Tile>,
|
||||
ck_tile::sequence<Traits_::M_Warp, Traits_::N_Warp, Traits_::K_Warp>,
|
||||
ck_tile::sequence<Traits_::M_Warp_Tile, Traits_::N_Warp_Tile, Traits_::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
|
||||
|
||||
using GemmEpilogue =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename Traits_::AccDataType,
|
||||
typename Traits_::CDataType,
|
||||
Traits_::kPadM,
|
||||
Traits_::kPadN>>;
|
||||
using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM,
|
||||
Traits_::kPadN,
|
||||
Traits_::kPadK,
|
||||
typename Traits_::ALayout,
|
||||
typename Traits_::BLayout,
|
||||
typename Traits_::CLayout>;
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
|
||||
ck_tile::GemmPipelineProblem<typename Traits_::ADataType,
|
||||
typename Traits_::BDataType,
|
||||
typename Traits_::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits>>;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
|
||||
ck_tile::UniversalGemmPipelineProblem<typename Traits_::ADataType,
|
||||
typename Traits_::BDataType,
|
||||
typename Traits_::AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
ck_tile::GemmPipelineScheduler::Interwave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
// Tail pipeline One to Seven
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Two)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
|
||||
{
|
||||
static_assert(BaseGemmPipeline::PrefetchStages > 3);
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Four)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Five)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Six)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Seven)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always Full - #PrefetchStages
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "When there's no hot loop, this tail number \"" << tail_num
|
||||
<< "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
@@ -28,8 +28,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
gemm_traits traits{DataTypeTraits<ADataType>{}.name,
|
||||
std::is_same_v<ALayout, Row>,
|
||||
std::is_same_v<BLayout, Row>,
|
||||
std::is_same_v<CLayout, Row>};
|
||||
|
||||
float ave_time =
|
||||
gemm(traits, args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
@@ -210,9 +215,6 @@ int run_gemm_example(int argc, char* argv[])
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
@@ -224,16 +226,14 @@ int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
|
||||
// work.
|
||||
// else if(a_layout == "C" && b_layout == "C")
|
||||
// {
|
||||
// return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(a_layout == "C" && b_layout == "R")
|
||||
// {
|
||||
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
|
||||
// }
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
@@ -14,207 +14,6 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_basic.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
|
||||
#endif
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 4;
|
||||
constexpr ck_tile::index_t N_Warp = 1;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#endif
|
||||
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// ===============================================
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
|
||||
|
||||
using GemmEpilogue = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
|
||||
#endif
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<
|
||||
#endif
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
Traits,
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
ck_tile::GemmPipelineScheduler::Interwave,
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
#endif
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
// Tail pipeline One to Seven
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Two)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Four)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Five)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Six)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Seven)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always Full - #PrefetchStages
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "When there's no hot loop, this tail number \"" << tail_num
|
||||
<< "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
Reference in New Issue
Block a user