mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[CK-Tile] Universal gemm memory bound pipeline (#1558)
* CK-Tile GEMM with memory bound pipeline. * Memory bound gemm pipeline. * Fix not closed namespace. * Block gemm mem pipeline draft. * Do not use ck_tile:: within ck_tile namespace. * Refactoring & Move Layout info to pipeline problem. * Get hot loop and TailNum information before lunching kernel. * Fixes in pipeline. * Add comment to load_tile_raw and change variable naming style. * Few small changes & formatting. * Do not use macro. * Add gtests. * Use AccDataType for Output of MFMA instruction. * Formatting. * Refactor gemm examples. * Switch over to current block gemm. * Use currently available pipeline policy. * Refactoring and review comment.s * Fixes after merge. * Add missing include. * Add load tile overload which accepts output tensor as parameter. * This give 8% perf boost at the cost of using more registers. * Rename example. * Small changes. * Fix compilation err and lower K. * Support different layouts for A/B * Fix vector size for different layouts. * Rename Alignment into VectorSize * Unblock tests.
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -14,55 +15,36 @@ template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<BDataType>& b_n_k,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? b_n_k.mDesc.get_lengths()[0]
|
||||
: b_n_k.mDesc.get_lengths()[1];
|
||||
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? a_m_k.mDesc.get_lengths()[1]
|
||||
: a_m_k.mDesc.get_lengths()[0];
|
||||
const int M = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? a_m_k.mDesc.get_lengths()[0]
|
||||
: a_m_k.mDesc.get_lengths()[1];
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f = [&](auto m) {
|
||||
for(int n = 0; n < N; ++n)
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_acc = 0;
|
||||
ADataType v_a = a_element_op(a_m_k(m, k));
|
||||
BDataType v_b = b_element_op(b_k_n(k, n));
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? a_element_op(a_m_k(m, k))
|
||||
: a_element_op(a_m_k(k, m));
|
||||
BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? b_element_op(b_n_k(n, k))
|
||||
: b_element_op(b_n_k(k, n));
|
||||
|
||||
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
|
||||
ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? c_m_n(m, n)
|
||||
: c_m_n(n, m);
|
||||
c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
v_acc +=
|
||||
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
|
||||
Reference in New Issue
Block a user