mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
[CK_TILE] Multiple-ABD GEMM example (#2788)
* Multi ABD - initial commit
* Clang-format fix
* block gemm, unify the name of CDataType
* Apply changes to mem-pipeline
* Rollback prefix for DType and Layout
* Gemm Kernel Basic, rename
* WMMA config
* Grouped GEMM
* Clang-format
* Dropout, name
* Review v2
* Move element_wise fn to unary, remove old fns
* clang-format
* Fix issue review
* WP operator adjust to universal gemm
* v2 prepare
* Remove unused comment
* Remove vectorsize
* Rollback
* Adjust pipeline for abd
* Shuffle argument
* CI-fail fix quant
* Fix ag_br pipeline
* Failing tests
* Typo
* Single argument support
[ROCm/composable_kernel commit: 30ab1d6a71]
This commit is contained in:
@@ -26,6 +26,29 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
|
||||
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Load tile with elementwise function
|
||||
*
|
||||
* @note This function is a modification of the existing load function.
|
||||
* It has been extended with two additional parameters: it takes a tuple as input
|
||||
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
|
||||
* is additionally applied during a single read.
|
||||
*/
|
||||
template <typename TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
// TODO: Tile windows should works with unknow number of params
|
||||
// Load element_wise API works only when the input typle is a tuple-tyupe
|
||||
return tile_window[number<0>{}].load(
|
||||
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
|
||||
@@ -120,6 +120,116 @@ struct tile_window_with_static_distribution
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Load tile with elementwise function
|
||||
*
|
||||
* @note Load tile with elementwise — during value loading, an
|
||||
* elementwise function is executed for each A0, A1, … AN.
|
||||
* The values A0, A1, … AN are read by the same thread. In this way, we
|
||||
* reduce the amount of information loaded into the registers.
|
||||
* The same thread, during vectorized reading, accesses the same set of
|
||||
* data from A0, A1, A2, … AN.
|
||||
*/
|
||||
template <typename TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
load(dst_tensor,
|
||||
tile_window,
|
||||
elementwise,
|
||||
number<i_access_unsupport_>{},
|
||||
bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
typename TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
|
||||
const TileWindow_& tile_window,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
constexpr auto sizeOfTuple = TileWindow_::size();
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord =
|
||||
tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord =
|
||||
tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// read from bottom tensor
|
||||
const auto idx_vec_value = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return tile_window[number<jj>{}]
|
||||
.get_bottom_tensor_view()
|
||||
.template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
},
|
||||
number<sizeOfTuple>{});
|
||||
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
ck_tile::apply(
|
||||
[&](auto&&... t) {
|
||||
elementwise(dst_tensor.get_thread_buffer().template at<d>(),
|
||||
t.template get_as<
|
||||
typename Base::DataType>()[j / Traits::PackedSize]...);
|
||||
},
|
||||
idx_vec_value);
|
||||
});
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
@@ -857,6 +967,39 @@ CK_TILE_DEVICE void move_tile_window(
|
||||
window.move(step);
|
||||
}
|
||||
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t NumCoord>
|
||||
CK_TILE_DEVICE void move_tile_window(
|
||||
tuple<tile_window_with_static_distribution<TensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>>& window,
|
||||
const typename tile_window_with_static_distribution<TensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>::BottomTensorIndex& step)
|
||||
{
|
||||
using T = tuple<tile_window_with_static_distribution<TensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>>;
|
||||
|
||||
static constexpr auto N = T::size();
|
||||
static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
|
||||
}
|
||||
|
||||
template <typename TileWindowWithStaticDistributionType,
|
||||
typename StepType,
|
||||
typename std::enable_if_t<
|
||||
is_detected<is_tuple, TileWindowWithStaticDistributionType>::value>* = nullptr>
|
||||
CK_TILE_DEVICE void move_tile_window(TileWindowWithStaticDistributionType& window, StepType& step)
|
||||
{
|
||||
static constexpr auto N = TileWindowWithStaticDistributionType::size();
|
||||
static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This class provides description of tile windowed view on the device memory.
|
||||
*
|
||||
|
||||
@@ -261,6 +261,81 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDElementOp,
|
||||
typename ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>,
|
||||
typename BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>,
|
||||
typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
|
||||
CK_TILE_HOST void
|
||||
reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k,
|
||||
const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n,
|
||||
const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
|
||||
HostTensor<ADataType>& a_m_k,
|
||||
HostTensor<BDataType>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const CDElementOp& acc_element_op = {})
|
||||
{
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto as_m_k_tuple =
|
||||
generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number<AsDataType::size()>{});
|
||||
|
||||
auto bs_k_n_tuple =
|
||||
generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number<BsDataType::size()>{});
|
||||
|
||||
auto ds_m_n_tuple =
|
||||
generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number<DsDataType::size()>{});
|
||||
|
||||
// Apply elementwise function to A
|
||||
auto a_elementwise_fn = [&](auto i, auto j) {
|
||||
ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(a_elementwise_fn, M, K)(std::thread::hardware_concurrency());
|
||||
|
||||
// Apply elementwise function to B
|
||||
auto b_elementwise_fn = [&](auto i, auto j) {
|
||||
ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(b_elementwise_fn, K, N)(std::thread::hardware_concurrency());
|
||||
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0;
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_m_k(m, k);
|
||||
BDataType v_b = b_k_n(k, n);
|
||||
v_acc +=
|
||||
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
CDataType v_c = 0;
|
||||
|
||||
ck_tile::apply(
|
||||
[&](auto&&... t) {
|
||||
acc_element_op(v_c,
|
||||
ck_tile::type_convert<float>(v_acc),
|
||||
ck_tile::type_convert<float>(t(m, n))...);
|
||||
},
|
||||
ds_m_n_tuple);
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
|
||||
@@ -392,6 +392,23 @@ struct PassThrough
|
||||
}
|
||||
};
|
||||
|
||||
struct AddScale
|
||||
{
|
||||
template <typename E, typename... As>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const As&... as) const
|
||||
{
|
||||
// Start with the base value c
|
||||
float result = ck_tile::type_convert<float>(0.0f);
|
||||
|
||||
// Add by each D parameter using fold expression
|
||||
((result += ck_tile::type_convert<float>(as)), ...);
|
||||
|
||||
a = ck_tile::type_convert<E>(scale * result);
|
||||
}
|
||||
|
||||
float scale = 1.0;
|
||||
};
|
||||
|
||||
struct MultiDMultiply
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
|
||||
@@ -28,8 +28,8 @@ struct GetDataType<T>
|
||||
using type = typename T::DataType; // Use T::ScaleN::DataType
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename DsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
@@ -51,8 +51,8 @@ template <typename ADataType_,
|
||||
bool TiledMMAPermuteN_ = false>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
@@ -83,12 +83,27 @@ template <typename Problem_, typename Policy_ = void>
|
||||
struct CShuffleEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
|
||||
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
|
||||
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
|
||||
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<AsDataType>,
|
||||
remove_cvref_t<tuple<AsDataType>>>;
|
||||
|
||||
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<BsDataType>,
|
||||
remove_cvref_t<tuple<BsDataType>>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
|
||||
@@ -28,8 +28,8 @@ struct Default2DEpilogueProblem
|
||||
static constexpr index_t NumDTensor = 0;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename DsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
@@ -53,8 +53,8 @@ struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataTyp
|
||||
UseRawStore_,
|
||||
MemoryOperation_>
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using CLayout = remove_cvref_t<CLayout_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
@@ -157,14 +157,28 @@ struct Default2DEpilogue
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
|
||||
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<AsDataType>,
|
||||
remove_cvref_t<tuple<AsDataType>>>;
|
||||
|
||||
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<BsDataType>,
|
||||
remove_cvref_t<tuple<BsDataType>>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
|
||||
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
|
||||
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
|
||||
|
||||
@@ -90,10 +90,10 @@ struct BatchedGemmKernel
|
||||
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, CDataType>::value,
|
||||
"C/ELayout and C/EDataType must be scalars.");
|
||||
"C/CLayout and C/EDataType must be scalars.");
|
||||
|
||||
struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<>
|
||||
{
|
||||
|
||||
@@ -89,7 +89,7 @@ struct GemmKernel
|
||||
/// @brief Specify the layout configurations for A, B, E and D
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, E and D
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
@@ -106,10 +106,10 @@ struct GemmKernel
|
||||
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, ELayout>::value &&
|
||||
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, EDataType>::value,
|
||||
"C/ELayout and C/EDataType must be scalars.");
|
||||
"C/CLayout and C/EDataType must be scalars.");
|
||||
|
||||
static constexpr index_t NumATensor = 1;
|
||||
static constexpr index_t NumBTensor = 1;
|
||||
|
||||
193
include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp
Normal file
193
include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp
Normal file
@@ -0,0 +1,193 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The MultiABD GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref GemmKernelMultiABD "GemmKernelMultiABD" when creating
|
||||
/// kernel arguments object. It contain all necessary information required to build proper
|
||||
/// kernel argument and launch kernel on GPU. This structure defines the GEMM problem
|
||||
/// configuration by stating all required information like M,N,K sizes and respective strides.
|
||||
/// NumATensor describes the number of A tensors. The minimum number of tensors is 1(required).
|
||||
/// NumBTensor describes the number of B tensors. The minimum number of tensors is 1(required).
|
||||
/// NumDTensor describes the number of D tensors. The minimum number of tensors is 0(not
|
||||
/// required).
|
||||
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
|
||||
struct GemmMultiABDHostArgs
|
||||
{
|
||||
CK_TILE_HOST GemmMultiABDHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
|
||||
const std::array<const void*, NumBTensor>& bs_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
const std::array<index_t, NumATensor>& stride_As_,
|
||||
const std::array<index_t, NumBTensor>& stride_Bs_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
: as_ptr(as_ptr_),
|
||||
bs_ptr(bs_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
e_ptr(e_ptr_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
stride_As(stride_As_),
|
||||
stride_Bs(stride_Bs_),
|
||||
stride_Ds(stride_Ds_),
|
||||
stride_E(stride_E_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const std::array<const void*, NumATensor> as_ptr;
|
||||
const std::array<const void*, NumBTensor> bs_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
union
|
||||
{
|
||||
void* e_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
const std::array<index_t, NumATensor> stride_As;
|
||||
const std::array<index_t, NumBTensor> stride_Bs;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
union
|
||||
{
|
||||
index_t stride_E;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GemmKernelMultiABD
|
||||
{
|
||||
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
||||
/// functions.
|
||||
using UniversalGemmKernel =
|
||||
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
|
||||
/// @brief Specify the layout configurations for A, B, E and D
|
||||
using AsLayout = remove_cvref_t<typename GemmPipeline::AsLayout>;
|
||||
using BsLayout = remove_cvref_t<typename GemmPipeline::BsLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, E and D
|
||||
using AsDataType = remove_cvref_t<typename GemmPipeline::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename GemmPipeline::BsDataType>;
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
|
||||
/// @brief ALayout and ADataType are expected to be a tuple, not a scalar.
|
||||
static_assert(is_detected<is_tuple, AsLayout>::value &&
|
||||
is_detected<is_tuple, AsDataType>::value,
|
||||
"ALayout and ADataType must be a tuple.");
|
||||
|
||||
/// @brief BLayout and BDataType are expected to be a tuple, not a scalar.
|
||||
static_assert(is_detected<is_tuple, BsLayout>::value &&
|
||||
is_detected<is_tuple, BsDataType>::value,
|
||||
"BLayout and BDataType must be a tuple.");
|
||||
|
||||
/// @brief CLayout and EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, EDataType>::value,
|
||||
"CLayout and EDataType must be a scalar.");
|
||||
|
||||
/// @brief DsLayout and DsDataType are expected to be tuple, not a scalar.
|
||||
static_assert(is_detected<is_tuple, DsLayout>::value &&
|
||||
is_detected<is_tuple, DsDataType>::value &&
|
||||
DsLayout::size() == DsDataType::size() && DsLayout::size() > 0,
|
||||
"DsLayout and DsDataType must be tuples and must have the same size.");
|
||||
|
||||
/// @brief The sizes of NumATensor, NumBTensor and NumDTensor is set by the user."
|
||||
static constexpr index_t NumATensor = AsDataType::size();
|
||||
static constexpr index_t NumBTensor = BsDataType::size();
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
CK_TILE_HOST static auto GetName() -> const std::string
|
||||
{
|
||||
return UniversalGemmKernel::GetName();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::GridSize(M, N, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::BlockSize();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
MakeKernelArgs(const GemmMultiABDHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs) ->
|
||||
typename UniversalGemmKernel::KernelArgs
|
||||
{
|
||||
/// @brief Universal GEMM requires array objects and corresponding stride information for
|
||||
/// matrices A, B, and D.
|
||||
return UniversalGemmKernel::MakeKernelArgs(
|
||||
UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>(hostArgs.as_ptr,
|
||||
hostArgs.bs_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.k_batch,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_As,
|
||||
hostArgs.stride_Bs,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto
|
||||
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
|
||||
{
|
||||
// Currently MultiABD kernel doesn't support k_batch > 1
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
|
||||
{
|
||||
UniversalGemmKernel{}.template operator()(kargs);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -95,7 +95,7 @@ struct GemmKernelMultiD
|
||||
/// @brief Specify the layout configurations for A, B, E and D
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, E and D
|
||||
@@ -114,10 +114,10 @@ struct GemmKernelMultiD
|
||||
!is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars.");
|
||||
|
||||
/// @brief ELayout and EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, ELayout>::value &&
|
||||
/// @brief CLayout and EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, EDataType>::value,
|
||||
"ELayout and EDataType must be scalars.");
|
||||
"CLayout and EDataType must be scalars.");
|
||||
|
||||
/// @brief DsLayout and DsDataType are expected to be tuple, not a scalar.
|
||||
static_assert(is_detected<is_tuple, DsLayout>::value &&
|
||||
|
||||
@@ -120,10 +120,10 @@ struct GroupedGemmKernel
|
||||
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, CDataType>::value,
|
||||
"C/ELayout and C/EDataType must be scalars.");
|
||||
"C/CLayout and C/EDataType must be scalars.");
|
||||
|
||||
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
|
||||
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
@@ -364,12 +364,8 @@ struct GroupedGemmKernel
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
|
||||
b_block_window[Base::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}.template
|
||||
|
||||
@@ -157,23 +157,23 @@ struct UniversalGemmKernel
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
|
||||
static constexpr bool ADataTypeIsTuple =
|
||||
is_detected<is_tuple, typename GemmPipeline::ADataType>::value;
|
||||
is_detected<is_tuple, typename GemmPipeline::AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple =
|
||||
is_detected<is_tuple, typename GemmPipeline::BDataType>::value;
|
||||
is_detected<is_tuple, typename GemmPipeline::BsDataType>::value;
|
||||
static constexpr bool DDataTypeIsTuple =
|
||||
is_detected<is_tuple, typename EpiloguePipeline::DsDataType>::value;
|
||||
static constexpr bool ALayoutIsTuple =
|
||||
is_detected<is_tuple, typename GemmPipeline::ALayout>::value;
|
||||
is_detected<is_tuple, typename GemmPipeline::AsLayout>::value;
|
||||
static constexpr bool BLayoutIsTuple =
|
||||
is_detected<is_tuple, typename GemmPipeline::BLayout>::value;
|
||||
is_detected<is_tuple, typename GemmPipeline::BsLayout>::value;
|
||||
static constexpr bool DLayoutIsTuple =
|
||||
is_detected<is_tuple, typename EpiloguePipeline::DsLayout>::value;
|
||||
|
||||
using AsLayout = std::conditional_t<ALayoutIsTuple,
|
||||
remove_cvref_t<typename GemmPipeline::ALayout>,
|
||||
remove_cvref_t<typename GemmPipeline::AsLayout>,
|
||||
remove_cvref_t<tuple<typename GemmPipeline::ALayout>>>;
|
||||
using BsLayout = std::conditional_t<BLayoutIsTuple,
|
||||
remove_cvref_t<typename GemmPipeline::BLayout>,
|
||||
remove_cvref_t<typename GemmPipeline::BsLayout>,
|
||||
remove_cvref_t<tuple<typename GemmPipeline::BLayout>>>;
|
||||
|
||||
using DsLayout = std::conditional_t<DLayoutIsTuple,
|
||||
@@ -181,11 +181,11 @@ struct UniversalGemmKernel
|
||||
remove_cvref_t<tuple<typename EpiloguePipeline::DsLayout>>>;
|
||||
|
||||
using AsDataType = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<typename GemmPipeline::ADataType>,
|
||||
remove_cvref_t<typename GemmPipeline::AsDataType>,
|
||||
remove_cvref_t<tuple<typename GemmPipeline::ADataType>>>;
|
||||
|
||||
using BsDataType = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<typename GemmPipeline::BDataType>,
|
||||
remove_cvref_t<typename GemmPipeline::BsDataType>,
|
||||
remove_cvref_t<tuple<typename GemmPipeline::BDataType>>>;
|
||||
|
||||
using DsDataType =
|
||||
@@ -193,9 +193,12 @@ struct UniversalGemmKernel
|
||||
remove_cvref_t<typename EpiloguePipeline::DsDataType>,
|
||||
remove_cvref_t<tuple<typename EpiloguePipeline::DsDataType>>>;
|
||||
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
// Get the persistent kernel if the pipeline has it available
|
||||
@@ -483,7 +486,7 @@ struct UniversalGemmKernel
|
||||
bool DTesnorIsValid = {true};
|
||||
static_for<0, NumDTensor, 1>{}([&](auto index) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if(std::is_same_v<DiLayout, ELayout> == false)
|
||||
if(std::is_same_v<DiLayout, CLayout> == false)
|
||||
{
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
@@ -529,7 +532,7 @@ struct UniversalGemmKernel
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
@@ -724,7 +727,7 @@ struct UniversalGemmKernel
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
@@ -818,7 +821,7 @@ struct UniversalGemmKernel
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
@@ -975,8 +978,8 @@ struct UniversalGemmKernel
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile =
|
||||
GemmPipeline{}(as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
@@ -1031,8 +1034,13 @@ struct UniversalGemmKernel
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}(
|
||||
as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
|
||||
AElementWise{},
|
||||
bs_block_window,
|
||||
BElementWise{},
|
||||
num_loop,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
@@ -11,12 +11,17 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy>
|
||||
struct GemmPipelineAgBgCrImplBase
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
@@ -57,6 +62,13 @@ struct GemmPipelineAgBgCrImplBase
|
||||
store_tile(lds_tile_window, block_tile_tmp);
|
||||
}
|
||||
|
||||
template <typename DstTileWindow, typename SrcBlockTile>
|
||||
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
|
||||
const SrcBlockTile& src_block_tile) const
|
||||
{
|
||||
store_tile(lds_tile_window, src_block_tile);
|
||||
}
|
||||
|
||||
template <typename DstBlockTile, typename SrcTileWindow, bool LoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
|
||||
const SrcTileWindow& lds_tile_window,
|
||||
@@ -88,23 +100,100 @@ struct GemmPipelineAgBgCrImplBase
|
||||
return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
|
||||
}
|
||||
|
||||
template <typename DramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
dram_block_window_tmp[number<idx>{}].get_window_origin() + offset,
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<DramBlockWindowTmp::size()>{});
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename DramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename DramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
dram_block_window_tmp[number<idx>{}].get_window_origin() + offset,
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<DramBlockWindowTmp::size()>{});
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename DramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const ALdsTensorView& a_lds_block_view,
|
||||
const ALdsLoadTileDistr&,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
a_dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_lds_shape = []() {
|
||||
@@ -138,16 +227,8 @@ struct GemmPipelineAgBgCrImplBase
|
||||
const BLdsLoadTileDistr&,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
|
||||
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
b_dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
|
||||
|
||||
// TODO: Do we really need those two tile windows???
|
||||
// They're exactly same...
|
||||
|
||||
@@ -107,14 +107,23 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
@@ -386,17 +395,25 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -449,17 +466,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
@@ -470,45 +476,61 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
// global read 1
|
||||
|
||||
elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -520,38 +542,42 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -574,27 +600,26 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
@@ -602,13 +627,16 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
@@ -628,9 +656,13 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
* @note This is used by the persistent gemm kernel variants that don't determine
|
||||
* hot loop and tail number on the host side, e.g. grouped gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
@@ -639,7 +671,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
@@ -658,20 +690,97 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
* @note This is used by the kernel variants that are able to determine
|
||||
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Quant operator(), single input: This function runs the pipeline by wrapping it with
|
||||
* the tail handler.
|
||||
*
|
||||
* @note This is used by the persistent gemm kernel variants that don't determine
|
||||
* hot loop and tail number on the host side, e.g. grouped gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_number,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Quant operator(), single input: This function runs the pipeline using compile-time
|
||||
* known hot loop and tail number.
|
||||
* @param num_loop The number of loop iterations. This is determined at runtime due to e.g.
|
||||
* SplitK.
|
||||
* @note This is used by the kernel variants that are able to determine
|
||||
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -97,11 +97,24 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
using Base = BaseGemmPipelineAgBgCrCompV4<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
@@ -109,10 +122,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -244,18 +253,26 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -279,29 +296,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
////////////// global window & register /////////////////
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// A register tile for global load
|
||||
constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution();
|
||||
constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution();
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr));
|
||||
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr));
|
||||
ABlockTile a_global_load_tile;
|
||||
BBlockTile b_global_load_tile;
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
@@ -312,8 +306,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
|
||||
// global prefetch 0
|
||||
// global read 0
|
||||
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
////////////// LDS desc, window & register /////////////////
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
|
||||
@@ -343,34 +336,75 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// Generating a tuple with tile_windows for values A0, A1, ... AN
|
||||
auto a_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
// Generating a tuple with tile_windows for values B0, B1, ... BN
|
||||
auto b_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
// global read 1
|
||||
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
block_sync_lds();
|
||||
|
||||
constexpr auto ALdsTileDistr =
|
||||
@@ -423,27 +457,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
if(HasHotLoop)
|
||||
{
|
||||
@@ -461,31 +500,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window0, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window0, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
// gemm
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
HotLoopScheduler();
|
||||
@@ -501,32 +541,34 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window1, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window1, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
// gemm
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
HotLoopScheduler();
|
||||
@@ -548,23 +590,23 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
|
||||
}
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
}
|
||||
@@ -606,13 +648,17 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
public:
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0,
|
||||
@@ -628,27 +674,34 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
p_smem_1);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
@@ -658,7 +711,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
@@ -670,5 +723,69 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0,
|
||||
void* p_smem_1) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_number,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -41,15 +41,24 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
using Base = BaseGemmPipelineAgBgCrCompV5<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
@@ -121,17 +130,25 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -209,14 +226,16 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
BGemmTile b_tile_0, b_tile_1;
|
||||
|
||||
// Register tile for A and B.
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using ABlockTileDistr =
|
||||
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
using BBlockTileDistr =
|
||||
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
ABlockTile a_global_load_tile;
|
||||
BBlockTile b_global_load_tile;
|
||||
ABlockTile elementwise_As_res;
|
||||
BBlockTile elementwise_Bs_res;
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
@@ -248,33 +267,45 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
// define ping, pong steps here as lambda functions.
|
||||
auto MemoryOpsStep = [&](auto idx) {
|
||||
// Memory read half here.
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each
|
||||
// A0, A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each
|
||||
// B0, B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
if(idx == 0)
|
||||
@@ -351,13 +382,17 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
public:
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0) const
|
||||
@@ -371,21 +406,62 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
p_smem_0);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -157,14 +157,23 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
@@ -236,17 +245,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -310,8 +327,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using ABlockTileDistr =
|
||||
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
using BBlockTileDistr =
|
||||
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
@@ -334,10 +353,21 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
Base::GlobalPrefetch(
|
||||
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
@@ -348,32 +378,35 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}));
|
||||
}
|
||||
|
||||
// Global prefetch [1, PrefetchStages]
|
||||
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
a_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
b_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
});
|
||||
|
||||
// main body
|
||||
@@ -397,14 +430,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
transpose_tile2d(
|
||||
a_shuffle_tmp,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
a_element_func);
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
@@ -413,22 +445,23 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
transpose_tile2d(
|
||||
b_shuffle_tmp,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
b_element_func);
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
a_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
b_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
});
|
||||
|
||||
i += PrefetchStages;
|
||||
@@ -450,26 +483,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window,
|
||||
a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_element_func);
|
||||
a_block_tiles.get(number<prefetch_idx>{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window,
|
||||
b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_element_func);
|
||||
b_block_tiles.get(number<prefetch_idx>{}));
|
||||
}
|
||||
});
|
||||
|
||||
@@ -526,17 +557,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -600,8 +639,10 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using ABlockTileDistr =
|
||||
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
using BBlockTileDistr =
|
||||
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
@@ -623,10 +664,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
Base::GlobalPrefetch(
|
||||
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
@@ -637,32 +690,35 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}));
|
||||
}
|
||||
|
||||
// Global prefetch [1, PrefetchStages]
|
||||
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
a_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
b_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
});
|
||||
|
||||
// main body
|
||||
@@ -687,14 +743,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
transpose_tile2d(
|
||||
a_shuffle_tmp,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
a_element_func);
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
@@ -703,22 +758,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
transpose_tile2d(
|
||||
b_shuffle_tmp,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
b_element_func);
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
a_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
b_block_tiles.at(number<prefetch_idx>{}) =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
});
|
||||
|
||||
i += PrefetchStages;
|
||||
@@ -740,26 +797,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window,
|
||||
a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_element_func);
|
||||
a_block_tiles.get(number<prefetch_idx>{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window,
|
||||
b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_element_func);
|
||||
b_block_tiles.get(number<prefetch_idx>{}));
|
||||
}
|
||||
});
|
||||
|
||||
@@ -813,13 +868,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
@@ -833,9 +891,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
@@ -844,7 +906,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
@@ -856,20 +918,82 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_number,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -15,14 +15,23 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
@@ -81,17 +90,25 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
@@ -133,22 +150,30 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
auto as_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
auto bs_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<BsLayout::size()>{});
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window = make_tile_window(
|
||||
@@ -182,13 +207,22 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
{
|
||||
// move to 1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
@@ -198,13 +232,12 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write 0
|
||||
@@ -212,13 +245,12 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -226,8 +258,8 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// global read i + 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -237,22 +269,20 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, a_block_tile);
|
||||
store_tile(a_copy_lds_window,
|
||||
tile_elementwise_in(a_element_func, a_shuffle_tmp_loop));
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write i + 1
|
||||
@@ -260,14 +290,12 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, b_block_tile);
|
||||
store_tile(b_copy_lds_window,
|
||||
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
@@ -284,20 +312,40 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType & a) { return a; },
|
||||
[](auto& e, const ADataType & a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType & b) { return b; },
|
||||
[](auto& e, const BDataType & b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -15,30 +15,66 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV2
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Problem::VectorSizeA;
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Problem::VectorSizeB;
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// For the basic GEMM pipeline, DoubleSmemBuffer is naturally set to false.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegV2",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize));
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize));
|
||||
// clang-format on
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
@@ -56,17 +92,31 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
BPackedSize;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
@@ -98,32 +148,40 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
auto as_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
a_copy_dram_window.get_tile_distribution());
|
||||
as_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
auto bs_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<BsLayout::size()>{});
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
bs_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
@@ -153,28 +211,30 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
{
|
||||
// move to 1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
// global read 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
|
||||
// LDS write 0
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
// global read 1
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 2;
|
||||
@@ -189,20 +249,18 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
// global read i + 2
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
|
||||
// LDS write i + 1
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
// global read i + 2
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
iCounter--;
|
||||
|
||||
@@ -218,11 +276,9 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -241,12 +297,28 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType & a) { return a; },
|
||||
[](auto& e, const ADataType & a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType & b) { return b; },
|
||||
[](auto& e, const BDataType & b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -5,16 +5,19 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename EDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
typename ComputeDataType_ = AsDataType_,
|
||||
typename AElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
@@ -22,18 +25,49 @@ struct GemmPipelineProblemBase
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>; // actually AccDataType
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
|
||||
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Traits::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Traits::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
using AElementWise = remove_cvref_t<AElementWise_>;
|
||||
using BElementWise = remove_cvref_t<BElementWise_>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Traits::AsLayout>;
|
||||
using BsLayout = remove_cvref_t<typename Traits::BsLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr bool ComputeDataTypeIsTuple = is_detected<is_tuple, ComputeDataType_>::value;
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
|
||||
static constexpr bool ALayoutIsTuple = is_detected<is_tuple, AsLayout>::value;
|
||||
static constexpr bool BLayoutIsTuple = is_detected<is_tuple, BsLayout>::value;
|
||||
|
||||
using ComputeDataTypeTuple = std::conditional_t<ComputeDataTypeIsTuple,
|
||||
remove_cvref_t<ComputeDataType_>,
|
||||
remove_cvref_t<tuple<ComputeDataType_>>>;
|
||||
using AsLayoutTuple = std::
|
||||
conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
|
||||
using BsLayoutTuple = std::
|
||||
conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
|
||||
|
||||
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<AsDataType>,
|
||||
remove_cvref_t<tuple<AsDataType>>>;
|
||||
|
||||
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<BsDataType>,
|
||||
remove_cvref_t<tuple<BsDataType>>>;
|
||||
|
||||
using ComputeDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, ComputeDataTypeTuple>>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayoutTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayoutTuple>>;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
@@ -66,7 +100,7 @@ struct GemmPipelineProblemBase
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
if constexpr(std::is_same_v<AsLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
|
||||
@@ -84,7 +118,7 @@ struct GemmPipelineProblemBase
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<BsLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
|
||||
@@ -125,7 +159,7 @@ struct GemmPipelineProblemBase
|
||||
{
|
||||
return VectorSizeA_;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
else if constexpr(std::is_same_v<AsLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentA();
|
||||
}
|
||||
@@ -140,7 +174,7 @@ struct GemmPipelineProblemBase
|
||||
{
|
||||
return VectorSizeB_;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
else if constexpr(std::is_same_v<BsLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentB();
|
||||
}
|
||||
@@ -161,35 +195,40 @@ struct GemmPipelineProblemBase
|
||||
}();
|
||||
};
|
||||
|
||||
// Alias for GemmPipelineProblem
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename EDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
typename AElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename ComputeDataType_ = AsDataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
using GemmPipelineProblem = GemmPipelineProblemBase<AsDataType_,
|
||||
BsDataType_,
|
||||
EDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_,
|
||||
AElementWise_,
|
||||
BElementWise_,
|
||||
FixedVectorSize_,
|
||||
VectorSizeA_,
|
||||
VectorSizeB_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename EDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
typename AElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename ComputeDataType_ = AsDataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
@@ -197,18 +236,48 @@ struct UniversalGemmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>; // actually AccDataType
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
|
||||
using AElementWise = remove_cvref_t<AElementWise_>;
|
||||
using BElementWise = remove_cvref_t<BElementWise_>;
|
||||
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Traits::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Traits::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Traits::AsLayout>;
|
||||
using BsLayout = remove_cvref_t<typename Traits::BsLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr bool ComputeDataTypeIsTuple = is_detected<is_tuple, ComputeDataType_>::value;
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
|
||||
static constexpr bool ALayoutIsTuple = is_detected<is_tuple, AsLayout>::value;
|
||||
static constexpr bool BLayoutIsTuple = is_detected<is_tuple, BsLayout>::value;
|
||||
|
||||
using ComputeDataTypeTuple = std::conditional_t<ComputeDataTypeIsTuple,
|
||||
remove_cvref_t<ComputeDataType_>,
|
||||
remove_cvref_t<tuple<ComputeDataType_>>>;
|
||||
using AsLayoutTuple = std::
|
||||
conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
|
||||
using BsLayoutTuple = std::
|
||||
conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
|
||||
|
||||
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<AsDataType>,
|
||||
remove_cvref_t<tuple<AsDataType>>>;
|
||||
|
||||
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<BsDataType>,
|
||||
remove_cvref_t<tuple<BsDataType>>>;
|
||||
|
||||
using ComputeDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, ComputeDataTypeTuple>>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayoutTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayoutTuple>>;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
|
||||
@@ -356,11 +356,14 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
@@ -382,11 +385,14 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
@@ -482,8 +488,6 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
@@ -491,6 +495,8 @@ struct UniversalGemmBasePolicy
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using ALayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
|
||||
// Tile: MPerBlock X KPerBlock
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -518,8 +524,6 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
@@ -527,6 +531,8 @@ struct UniversalGemmBasePolicy
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
// Tile: KPerBlock X NPerBlock
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -554,7 +560,8 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ALayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
@@ -574,7 +581,8 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
|
||||
@@ -10,8 +10,8 @@ namespace ck_tile {
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename AsLayout_,
|
||||
typename BsLayout_,
|
||||
typename CLayout_,
|
||||
index_t NumWaveGroups_ = 1>
|
||||
struct TileGemmTraits
|
||||
@@ -23,9 +23,9 @@ struct TileGemmTraits
|
||||
// TODO this can't be hardcoded here! Should be in policy!
|
||||
static constexpr int _VectorSize = 16;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
using AsLayout = AsLayout_;
|
||||
using BsLayout = BsLayout_;
|
||||
using CLayout = CLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
@@ -36,8 +36,8 @@ template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool DoubleSmemBuffer_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename AsLayout_,
|
||||
typename BsLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false,
|
||||
@@ -52,9 +52,9 @@ struct TileGemmUniversalTraits
|
||||
static constexpr int _VectorSize = 16;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
using AsLayout = AsLayout_;
|
||||
using BsLayout = BsLayout_;
|
||||
using CLayout = CLayout_;
|
||||
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
|
||||
@@ -67,8 +67,8 @@ template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool DoubleSmemBuffer_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename AsLayout_,
|
||||
typename BsLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false>
|
||||
@@ -76,8 +76,8 @@ using PersistentTileGemmUniversalTraits = TileGemmUniversalTraits<kPadM_,
|
||||
kPadN_,
|
||||
kPadK_,
|
||||
DoubleSmemBuffer_,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
AsLayout_,
|
||||
BsLayout_,
|
||||
CLayout_,
|
||||
TransposeC_,
|
||||
UseStructuredSparsity_,
|
||||
|
||||
@@ -37,15 +37,24 @@ template <typename Problem, typename PipelinePolicy = UniversalWeightPreshuffleP
|
||||
struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockWeightPreshuffle =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockWeightPreshuffle<Problem>())>;
|
||||
@@ -188,7 +197,12 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
@@ -455,7 +469,33 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
|
||||
// Tuple-window overload of operator().
//
// Participates in overload resolution only when BOTH ADramBlockWindowTmp and
// BFlatBlockWindowTmp are detected as tuples (is_detected<is_tuple, ...>).
// It takes element 0 of each window tuple and forwards to another operator()
// overload of this pipeline, supplying an identity lambda over ADataType as
// the A element function. Tuple elements other than index 0 are not read here.
//
// a_element_func / b_element_func are accepted for interface compatibility but
// are not used in this body (hence [[maybe_unused]]).
//
// num_loop and p_smem are passed through unchanged; the return value is
// whatever the forwarded-to overload returns.
//
// NOTE(review): the non-tuple convenience overload in this same struct passes a
// two-argument lambda `(auto& e, const ADataType& a)`, while this overload
// passes a one-argument `(const ADataType& a)` lambda -- confirm which shape
// the forwarded-to overload actually expects.
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
[[maybe_unused]] const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
[[maybe_unused]] const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
// Unwrap the first window of each tuple and delegate; the caller-supplied
// element functions are intentionally not forwarded.
return operator()(
|
||||
a_dram_block_window_tmp[number<0>{}],
|
||||
[](const ADataType & a) { return a; },
|
||||
b_flat_dram_block_window_tmp[number<0>{}],
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
@@ -463,7 +503,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType & a) { return a; },
|
||||
[](auto& e, const ADataType & a) { e = a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
|
||||
@@ -53,14 +53,23 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
{
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockWeightPreshuffle =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockWeightPreshuffle<Problem>())>;
|
||||
@@ -502,7 +511,10 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
template <TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction>
|
||||
typename AElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
@@ -1001,8 +1013,37 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
// called from universal gemm kernel
//
// Tuple-window overload of operator().
//
// Participates in overload resolution only when BOTH ADramBlockWindowTmp and
// BFlatBlockWindowTmp are detected as tuples (is_detected<is_tuple, ...>).
// It takes element 0 of each window tuple and forwards to the
// operator()<TailNum> overload, supplying an identity lambda over ADataType as
// the A element function. Tuple elements other than index 0 are not read here.
//
// a_element_func / b_element_func are accepted for interface compatibility but
// are not used in this body (hence [[maybe_unused]]).
//
// num_loop, p_smem_ping and p_smem_pong are passed through unchanged; the
// return value is whatever the forwarded-to overload returns.
//
// NOTE(review): TailNum is not a template parameter of this overload;
// presumably it is a class-scope constant declared elsewhere -- confirm.
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
[[maybe_unused]] const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
[[maybe_unused]] const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
// Unwrap the first window of each tuple and delegate; the caller-supplied
// element functions are intentionally not forwarded.
return operator()<TailNum>(
|
||||
a_dram_block_window_tmp[number<0>{}],
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp[number<0>{}],
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
}
|
||||
|
||||
// called from general gemm kernel
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
@@ -1019,9 +1060,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
}
|
||||
|
||||
// called from grouped gemm kernel
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem_0,
|
||||
|
||||
@@ -44,6 +44,10 @@ struct TileGemmQuantTraits
|
||||
using AQLayout = AQLayout_;
|
||||
using BQLayout = BQLayout_;
|
||||
|
||||
// TODO: It should be replaced with a single value
|
||||
using AsLayout = ALayout_;
|
||||
using BsLayout = BLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
|
||||
Reference in New Issue
Block a user