mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Wmma support for multiple ABD GEMM (#2803)
* multi_abd wmma support: - Add multiple A and B support to multiple D implementation (gridwise level) - Add multi_abd GEMM (device level) - Add instances (xdl parity) - Add tests (both xdl and wmma) - Add examples - Add ckProfiler support (both xdl and wmma) * Fix bug in device print function * Fix unused template parameter * Fix batched gemm for multiABD gridwise implementation * Fix gemm_universal_reduce with multiABDs gridwise implementation --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -55,6 +55,155 @@ struct DeviceGemmMultipleABD : public BaseOperator
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
// GEMM:
|
||||
// input : A0[M, K], B0[K, N],
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleABDSplitK : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumATensor = AsDataType::Size();
|
||||
static constexpr index_t NumBTensor = BsDataType::Size();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
|
||||
std::array<const void*, NumBTensor> p_bs,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
std::array<ck::index_t, NumATensor> StrideAs,
|
||||
std::array<ck::index_t, NumBTensor> StrideBs,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
ck::index_t StrideE,
|
||||
ck::index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
/// @brief Wrapper for backward compatibility that allows to use instances of
|
||||
/// DeviceGemmMultipleABDSplitK in contexts where DeviceGemmMultipleABD is expected.
|
||||
///
|
||||
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
|
||||
/// The only difference between API of DeviceGemmMultipleABD and DeviceGemmMultipleABDSplitK
|
||||
/// is that DeviceGemmMultipleABDSplitK::MakeArgumentPointer requires an additional parameter
|
||||
/// KBatch which is explicitly passed as 1 by this wrapper.
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleABDSplitKWrapper : public DeviceGemmMultipleABD<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
|
||||
using DeviceOp = DeviceGemmMultipleABDSplitK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
|
||||
static constexpr index_t NumATensor = AsDataType::Size();
|
||||
static constexpr index_t NumBTensor = BsDataType::Size();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
|
||||
explicit DeviceGemmMultipleABDSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
|
||||
: p_op_(std::move(p_op))
|
||||
{
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return p_op_->IsSupportedArgument(p_arg);
|
||||
}
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
|
||||
std::array<const void*, NumBTensor> p_bs,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
std::array<ck::index_t, NumATensor> StrideAs,
|
||||
std::array<ck::index_t, NumBTensor> StrideBs,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
ck::index_t StrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return p_op_->MakeArgumentPointer(p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
1, // KBatch
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return p_op_->MakeInvokerPointer();
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<DeviceOp> p_op_;
|
||||
|
||||
#endif // __HIPCC_RTC__
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -64,9 +64,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
// shift A matrices pointer for splitk
|
||||
typename GridwiseGemm::AsGridPointer p_as_grid_shift;
|
||||
static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ =
|
||||
remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::AsDataType_>>;
|
||||
p_as_grid_shift(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
|
||||
splitk_batch_offset.a_k_split_offset[i] + a_batch_offset;
|
||||
});
|
||||
|
||||
// shift B matrices pointer for splitk
|
||||
typename GridwiseGemm::BsGridPointer p_bs_grid_shift;
|
||||
static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ =
|
||||
remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::BsDataType_>>;
|
||||
p_bs_grid_shift(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
|
||||
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
|
||||
});
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset,
|
||||
p_as_grid_shift,
|
||||
p_bs_grid_shift,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
|
||||
p_shared,
|
||||
@@ -278,8 +296,8 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
Tuple<>, // DsDataType
|
||||
@@ -346,15 +364,15 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CElementwiseOperation cde_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
: GridwiseGemm::Argument(p_a_grid_,
|
||||
p_b_grid_,
|
||||
: GridwiseGemm::Argument(std::array<const void*, 1>{p_a_grid_},
|
||||
std::array<const void*, 1>{p_b_grid_},
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
p_c_grid_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
std::array<index_t, 1>{StrideA_},
|
||||
std::array<index_t, 1>{StrideB_},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC_,
|
||||
k_batch_,
|
||||
@@ -423,26 +441,33 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
|
||||
|
||||
// Packed sizes are 1 for all implemented data types but we include it anyway
|
||||
// for future compatibility.
|
||||
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
|
||||
sizeof(ADataType) / GridwiseGemm::APackedSize;
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / GridwiseGemm::BPackedSize;
|
||||
|
||||
// Note: the grid descriptors and size_a / size_b do *not* take batching into
|
||||
// account, so we have to manually multiply overall buffer sizes for rotating
|
||||
// memory by batch.
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
arg_,
|
||||
stream_config.rotating_count,
|
||||
arg_.Batch * size_a_buffer,
|
||||
arg_.Batch * size_b_buffer);
|
||||
std::array<std::size_t, 1> size_as_buffers;
|
||||
size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
|
||||
sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch;
|
||||
|
||||
std::array<std::size_t, 1> size_bs_buffers;
|
||||
size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
|
||||
sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch;
|
||||
|
||||
ck::utility::RotatingMemWrapperMultiABD<Argument,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
Tuple<>>
|
||||
rotating_mem(arg_,
|
||||
stream_config.rotating_count,
|
||||
size_as_buffers,
|
||||
size_bs_buffers,
|
||||
std::array<std::size_t, 0>{});
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
|
||||
@@ -0,0 +1,422 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/// @brief \"Universal\" GEMM operation with SplitK support and multiple D tensors.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This GEMM operation implements the following mathematical equation:
|
||||
/// E{M,N} = CDE_op(A_op(As{M,K}...) * B_op(Bs{K,N}...), Ds{M,N}...)
|
||||
/// Where As, Bs, Ds are input tensors and E is the output tensor. The A/B_op are
|
||||
/// elementwise
|
||||
// operations that could be applied on each tensor respectively. The CDE_op is an
|
||||
// elementwise operation applied to the C and all D tensors.
|
||||
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
|
||||
/// scenarios. That's why it's called \"universal\". It's universal through it's design
|
||||
/// and versatilty.
|
||||
///
|
||||
/// @note This Kernel implementation supports SplitK algorithm. It can be configured
|
||||
/// to split the dot product accumulated over the K dimension into multiple working groups.
|
||||
/// The partial products of different workgroups are then reduced using the AtomicAdd
|
||||
/// operation.
|
||||
///
|
||||
/// @tparam AsLayout A tensors data layouts.
|
||||
/// @tparam BsLayout B tensors data layouts.
|
||||
/// @tparam DsLayout D tensors data layouts.
|
||||
/// @tparam ELayout E tensor data layout.
|
||||
/// @tparam AsDataType A tensors data types.
|
||||
/// @tparam BsDataType B tensors data types.
|
||||
/// @tparam DsDataType D tensors data types.
|
||||
/// @tparam EDataType E tensor data type.
|
||||
/// @tparam AccDataType The accumulation data type related to the hardware
|
||||
/// matrix-multiplication instruction.
|
||||
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
|
||||
/// LDS memory during \"CShuffle\" data layout optimization.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
|
||||
/// GEMM) and D input tensors.
|
||||
/// @tparam GemmSpec Determines used "padding" version.
|
||||
/// @tparam BlockSize The number of threads within workgroup.
|
||||
/// @tparam MPerBlock The input/output data tile size in the M dimension.
|
||||
/// @tparam NPerBlock The input/output data tile size in the N dimension.
|
||||
/// @tparam KPerBlock The input data tile size in the K dimension.
|
||||
/// @tparam AK1 The vector load size from global memory for A tensor.
|
||||
/// @tparam BK1 The vector load size from global memory for B tensor.
|
||||
/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
|
||||
/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
|
||||
/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
|
||||
/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
|
||||
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question, "How many threads can be
|
||||
/// arranged on each input data axis?"
|
||||
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question: "How many threads to
|
||||
/// arrange on each input data axis?"
|
||||
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in M dimension.
|
||||
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in N dimension.
|
||||
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// thread distribution used for storing data into output
|
||||
/// tensor across output data layout dimensions.
|
||||
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
|
||||
/// Used when loading data from D tensors and storing data
|
||||
/// to output tensor.
|
||||
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
|
||||
/// intrawave).
|
||||
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
|
||||
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory. Currently not supported!
|
||||
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory (pre-shuffled).
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceGemmMultipleABD_Wmma_CShuffleV3
|
||||
: public DeviceGemmMultipleABDSplitK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
// Note: Pass multiple layout but then using only the first one
|
||||
// This is to replicate xdl functionality but it should be extended
|
||||
using ALayout = remove_cvref_t<tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
using DeviceGemmCommon =
|
||||
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
BlockSize,
|
||||
AK1,
|
||||
BK1,
|
||||
GemmSpec,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
// Invoker
|
||||
using Invoker = typename DeviceGemmCommon::Invoker;
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
return DeviceGemmCommon::IsSupportedArgument(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(std::array<const void*, GridwiseGemm::NumATensor> p_as,
|
||||
std::array<const void*, GridwiseGemm::NumBTensor> p_bs,
|
||||
std::array<const void*, GridwiseGemm::NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
std::array<ck::index_t, GridwiseGemm::NumATensor> StrideAs,
|
||||
std::array<ck::index_t, GridwiseGemm::NumBTensor> StrideBs,
|
||||
std::array<index_t, GridwiseGemm::NumDTensor> StrideDs,
|
||||
index_t StrideE,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
static_cast<EDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::array<const void*, GridwiseGemm::NumATensor> p_as,
|
||||
std::array<const void*, GridwiseGemm::NumBTensor> p_bs,
|
||||
std::array<const void*, GridwiseGemm::NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
std::array<ck::index_t, GridwiseGemm::NumATensor> StrideAs,
|
||||
std::array<ck::index_t, GridwiseGemm::NumBTensor> StrideBs,
|
||||
std::array<ck::index_t, GridwiseGemm::NumDTensor> StrideDs,
|
||||
index_t StrideE,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
static_cast<EDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmMultipleABD_Wmma_CShuffleV3"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", ";
|
||||
static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
|
||||
using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
|
||||
|
||||
str << std::string(ALayout_::name)[0];
|
||||
});
|
||||
static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
|
||||
using BLayout_ = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
|
||||
|
||||
str << std::string(BLayout_::name)[0];
|
||||
});
|
||||
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
str << std::string(DLayout::name)[0];
|
||||
});
|
||||
str << std::string(ELayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerWmma << "x"<<NPerWmma << ", "
|
||||
<< "WaveMap: "
|
||||
<< MRepeat << "x" << NRepeat << ", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "KPack: "
|
||||
<< GridwiseGemm::KPack;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
REGISTER_EXTRA_PRINTING_METHODS
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -193,8 +193,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -244,8 +244,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3
|
||||
|
||||
using DeviceGemmCommon =
|
||||
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
MPerBlock,
|
||||
@@ -291,15 +291,15 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
return Argument{std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
p_ds,
|
||||
static_cast<EDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
@@ -328,15 +328,15 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
p_ds,
|
||||
static_cast<EDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
|
||||
@@ -182,8 +182,8 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
Tuple<>, // DsDataType
|
||||
@@ -233,8 +233,8 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
|
||||
using DeviceGemmCommon =
|
||||
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
MPerBlock,
|
||||
@@ -283,15 +283,15 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
return Argument{std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
KBatch,
|
||||
@@ -317,15 +317,15 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
KBatch,
|
||||
|
||||
@@ -91,8 +91,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
Tuple<>, // DsDataType
|
||||
@@ -144,8 +145,8 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
|
||||
using DeviceGemmCommon =
|
||||
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
MPerBlock,
|
||||
@@ -195,15 +196,15 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
return Argument{std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
StrideScaleB,
|
||||
@@ -233,15 +234,15 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
StrideScaleB,
|
||||
|
||||
@@ -23,8 +23,8 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t MPerBlock,
|
||||
@@ -88,15 +88,24 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
|
||||
|
||||
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
|
||||
sizeof(ADataType) / GridwiseGemm::APackedSize;
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / GridwiseGemm::BPackedSize;
|
||||
std::array<std::size_t, GridwiseGemm::NumATensor> size_as_buffers;
|
||||
static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
size_as_buffers[i] = a_grid_desc_ak0_m_ak1[i].GetElementSpaceSize() *
|
||||
sizeof(ADataType) / GridwiseGemm::APackedSize;
|
||||
});
|
||||
|
||||
std::array<std::size_t, GridwiseGemm::NumBTensor> size_bs_buffers;
|
||||
static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
size_bs_buffers[i] = b_grid_desc_bk0_n_bk1[i].GetElementSpaceSize() *
|
||||
sizeof(BDataType) / GridwiseGemm::BPackedSize;
|
||||
});
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
@@ -108,12 +117,13 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
|
||||
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
|
||||
arg_,
|
||||
stream_config.rotating_count,
|
||||
size_a_buffer,
|
||||
size_b_buffer,
|
||||
size_ds_buffers);
|
||||
ck::utility::
|
||||
RotatingMemWrapperMultiABD<Argument, AsDataType, BsDataType, DsDataType>
|
||||
rotating_mem(arg_,
|
||||
stream_config.rotating_count,
|
||||
size_as_buffers,
|
||||
size_bs_buffers,
|
||||
size_ds_buffers);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
|
||||
@@ -98,8 +98,8 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
BLayout,
|
||||
Tuple<>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
GemmAccDataType,
|
||||
ReduceDataType,
|
||||
Tuple<>,
|
||||
@@ -147,15 +147,15 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
|
||||
struct Argument : public GridwiseGemm::Argument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
Argument(std::array<const void*, 1> p_a_grid_,
|
||||
std::array<const void*, 1> p_b_grid_,
|
||||
const ::std::array<const void*, NumDTensor> p_ds_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
std::array<index_t, 1> StrideA_,
|
||||
std::array<index_t, 1> StrideB_,
|
||||
const ::std::array<index_t, NumDTensor> stride_ds_,
|
||||
index_t StrideC_,
|
||||
index_t KBatch_,
|
||||
@@ -430,15 +430,15 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
return Argument{std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
p_ds,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
stride_ds,
|
||||
StrideC,
|
||||
KBatch,
|
||||
@@ -472,15 +472,15 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return ::std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
return ::std::make_unique<Argument>(std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
DsStrides,
|
||||
StrideC,
|
||||
KSplit,
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
@@ -39,8 +40,8 @@ namespace ck {
|
||||
/// @tparam BLayout B tensor data layout.
|
||||
/// @tparam DsLayout D tensors data layouts.
|
||||
/// @tparam ELayout E tensor data layout.
|
||||
/// @tparam ADataType A tensor data type.
|
||||
/// @tparam BDataType B tensor data type.
|
||||
/// @tparam AsDataType A tensors data types.
|
||||
/// @tparam BsDataType B tensors data types.
|
||||
/// @tparam AccDataType The accumulation data type related to the hardware
|
||||
/// matrix-multiplication instruction.
|
||||
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
|
||||
@@ -129,8 +130,8 @@ template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
@@ -181,8 +182,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -233,8 +234,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -305,8 +306,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
using Base::CalculateMPadded;
|
||||
using Base::CalculateNBlock;
|
||||
using Base::CalculateNPadded;
|
||||
using Base::MakeAGridDescriptor_AK0_M_AK1;
|
||||
using Base::MakeBGridDescriptor_BK0_N_BK1;
|
||||
using Base::MakeAsGridDescriptor_AK0_M_AK1;
|
||||
using Base::MakeBsGridDescriptor_BK0_N_BK1;
|
||||
using Base::MakeDEGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
@@ -320,24 +321,30 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
|
||||
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
|
||||
|
||||
using Base::NumATensor;
|
||||
using Base::NumBTensor;
|
||||
using Base::NumDTensor;
|
||||
using typename Base::AsGridPointer;
|
||||
using typename Base::BsGridPointer;
|
||||
using typename Base::DsGridPointer;
|
||||
using AsDataType_ = AsDataType;
|
||||
using BsDataType_ = BsDataType;
|
||||
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
std::array<index_t, NumATensor> StrideAs_,
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
StrideA{StrideA_},
|
||||
StrideB{StrideB_},
|
||||
StrideAs{StrideAs_},
|
||||
StrideBs{StrideBs_},
|
||||
StrideDs{StrideDs_},
|
||||
StrideE{StrideE_},
|
||||
KBatch{KBatch_},
|
||||
@@ -355,7 +362,15 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
__host__ void Print() const
|
||||
{
|
||||
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
|
||||
<< "SAs: {";
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
|
||||
});
|
||||
std::cout << "}, " << "SBs: {";
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
|
||||
});
|
||||
std::cout << "}, ";
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
std::cout << "SDs: { ";
|
||||
@@ -373,8 +388,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t StrideA;
|
||||
index_t StrideB;
|
||||
std::array<index_t, NumATensor> StrideAs;
|
||||
std::array<index_t, NumBTensor> StrideBs;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
index_t KBatch;
|
||||
@@ -391,15 +406,15 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
// Argument
|
||||
struct Argument : public tensor_operation::device::BaseArgument, public Problem
|
||||
{
|
||||
__host__ Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
|
||||
std::array<const void*, NumBTensor> p_bs_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
EDataType* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
std::array<index_t, NumATensor> StrideAs_,
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t k_batch_,
|
||||
@@ -407,9 +422,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
BElementwiseOperation b_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
: Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_},
|
||||
p_as_grid{},
|
||||
p_bs_grid{},
|
||||
p_ds_grid{},
|
||||
p_e_grid{p_e_grid_},
|
||||
a_element_op{a_element_op_},
|
||||
@@ -417,9 +432,27 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
cde_element_op{cde_element_op_},
|
||||
is_reduce(is_reduce_)
|
||||
{
|
||||
// populate pointer, desc for As
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
|
||||
// A pointer
|
||||
p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
|
||||
});
|
||||
|
||||
// populate pointer, desc for Bs
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
|
||||
// B pointer
|
||||
p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
|
||||
});
|
||||
|
||||
// populate pointer, desc for Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
// D pointer
|
||||
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
|
||||
});
|
||||
}
|
||||
@@ -434,8 +467,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
return (Problem::KBatch > 1) && (!is_reduce);
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid;
|
||||
const BDataType* p_b_grid;
|
||||
AsGridPointer p_as_grid;
|
||||
BsGridPointer p_bs_grid;
|
||||
DsGridPointer p_ds_grid;
|
||||
EDataType* p_e_grid;
|
||||
|
||||
@@ -452,29 +485,39 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
|
||||
{
|
||||
// Note: in xdl implementation multiple AB supports one layout
|
||||
// but multiple strides, so we create an array of offsets with
|
||||
// the same values.
|
||||
// It should be fixed later on. Once we will have a thread transfer
|
||||
// more flexible.
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
static_for<0, NumATensor, 1>{}(
|
||||
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
static_for<0, NumATensor, 1>{}(
|
||||
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
b_k_split_offset = k_id * karg.KRead / BPackedSize;
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
b_k_split_offset = k_id * k0_offset / BPackedSize;
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -497,8 +540,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
std::array<index_t, NumATensor> a_k_split_offset;
|
||||
std::array<index_t, NumBTensor> b_k_split_offset;
|
||||
index_t c_reduce_offset;
|
||||
};
|
||||
|
||||
@@ -514,8 +557,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
__device__ static void Run(AsGridPointer& p_as_grid,
|
||||
BsGridPointer& p_bs_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
@@ -524,10 +567,10 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
|
||||
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
@@ -562,20 +605,20 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(b_scale_struct),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_a_grid,
|
||||
p_b_grid,
|
||||
TailNum>(p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
@@ -595,10 +638,26 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
__device__ static void
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
AsGridPointer p_as_grid_splitk;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
|
||||
splitk_batch_offset.a_k_split_offset[i];
|
||||
});
|
||||
|
||||
// shift B matrices pointer for splitk
|
||||
BsGridPointer p_bs_grid_splitk;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
|
||||
splitk_batch_offset.b_k_split_offset[i];
|
||||
});
|
||||
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
|
||||
p_as_grid_splitk,
|
||||
p_bs_grid_splitk,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
|
||||
@@ -22,8 +22,9 @@ template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename BScaleType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
@@ -76,8 +77,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -123,15 +124,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
PermuteA,
|
||||
PermuteB>
|
||||
{
|
||||
using BScaleType = ck::half_t;
|
||||
|
||||
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -202,8 +201,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
using Base::CalculateMPadded;
|
||||
using Base::CalculateNBlock;
|
||||
using Base::CalculateNPadded;
|
||||
using Base::MakeAGridDescriptor_AK0_M_AK1;
|
||||
using Base::MakeBGridDescriptor_BK0_N_BK1;
|
||||
using Base::MakeAsGridDescriptor_AK0_M_AK1;
|
||||
using Base::MakeBsGridDescriptor_BK0_N_BK1;
|
||||
using Base::MakeDEGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
@@ -217,7 +216,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
|
||||
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
|
||||
|
||||
using Base::NumATensor;
|
||||
using Base::NumBTensor;
|
||||
using Base::NumDTensor;
|
||||
using typename Base::AsGridPointer;
|
||||
using typename Base::BsGridPointer;
|
||||
using typename Base::DsGridPointer;
|
||||
|
||||
struct Problem
|
||||
@@ -225,8 +228,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
__host__ Problem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
std::array<index_t, NumATensor> StrideAs_,
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t StrideScaleB_,
|
||||
@@ -234,8 +237,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
StrideA{StrideA_},
|
||||
StrideB{StrideB_},
|
||||
StrideAs{StrideAs_},
|
||||
StrideBs{StrideBs_},
|
||||
StrideDs{StrideDs_},
|
||||
StrideE{StrideE_},
|
||||
StrideScaleB{StrideScaleB_},
|
||||
@@ -254,7 +257,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
__host__ void Print() const
|
||||
{
|
||||
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
|
||||
<< "SAs: {";
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
|
||||
});
|
||||
std::cout << "}, " << "SBs: {";
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
|
||||
});
|
||||
std::cout << "}, ";
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
std::cout << "SDs: { ";
|
||||
@@ -273,8 +284,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t StrideA;
|
||||
index_t StrideB;
|
||||
std::array<index_t, NumATensor> StrideAs;
|
||||
std::array<index_t, NumBTensor> StrideBs;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
index_t StrideScaleB;
|
||||
@@ -292,15 +303,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
// Argument
|
||||
struct Argument : public tensor_operation::device::BaseArgument, public Problem
|
||||
{
|
||||
__host__ Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
|
||||
std::array<const void*, NumBTensor> p_bs_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
EDataType* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
std::array<index_t, NumATensor> StrideAs_,
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t StrideScaleB_,
|
||||
@@ -310,9 +321,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
BElementwiseOperation b_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, StrideScaleB_, k_batch_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
: Problem{M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideAs_,
|
||||
StrideBs_,
|
||||
StrideDs_,
|
||||
StrideE_,
|
||||
StrideScaleB_,
|
||||
k_batch_},
|
||||
p_as_grid{},
|
||||
p_bs_grid{},
|
||||
p_ds_grid{},
|
||||
p_e_grid{p_e_grid_},
|
||||
p_b_scale_grid{p_b_scale_grid_},
|
||||
@@ -321,6 +340,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
cde_element_op{cde_element_op_},
|
||||
is_reduce(is_reduce_)
|
||||
{
|
||||
// populate pointer, desc for As
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
|
||||
// A pointer
|
||||
p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
|
||||
});
|
||||
|
||||
// populate pointer, desc for Bs
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
|
||||
// B pointer
|
||||
p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
|
||||
});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
@@ -338,8 +373,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
return (Problem::KBatch > 1) && (!is_reduce);
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid;
|
||||
const BDataType* p_b_grid;
|
||||
AsGridPointer p_as_grid;
|
||||
BsGridPointer p_bs_grid;
|
||||
DsGridPointer p_ds_grid;
|
||||
EDataType* p_e_grid;
|
||||
|
||||
@@ -355,29 +390,39 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
|
||||
{
|
||||
// Note: in xdl implementation multiple AB supports one layout
|
||||
// but multiple strides, so we create an array of offsets with
|
||||
// the same values.
|
||||
// It should be fixed later on. Once we will have a thread transfer
|
||||
// more flexible.
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
static_for<0, NumATensor, 1>{}(
|
||||
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
static_for<0, NumATensor, 1>{}(
|
||||
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
b_k_split_offset = k_id * karg.KRead / BPackedSize;
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
b_k_split_offset = k_id * k0_offset / BPackedSize;
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -410,8 +455,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
std::array<index_t, NumATensor> a_k_split_offset;
|
||||
std::array<index_t, NumBTensor> b_k_split_offset;
|
||||
index_t scale_k_split_offset; // New member for scale matrix offset
|
||||
index_t c_reduce_offset;
|
||||
};
|
||||
@@ -423,7 +468,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
|
||||
|
||||
template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK, typename BScaleType>
|
||||
template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK>
|
||||
__device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
|
||||
const BScaleType* p_b_scale_grid,
|
||||
index_t block_n_id)
|
||||
@@ -488,8 +533,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
__device__ static void Run(AsGridPointer& p_as_grid,
|
||||
BsGridPointer& p_bs_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
const BScaleType* p_b_scale_grid,
|
||||
@@ -499,10 +544,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
|
||||
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
@@ -542,20 +587,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(b_scale_struct),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_a_grid,
|
||||
p_b_grid,
|
||||
TailNum>(p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
@@ -575,10 +620,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
__device__ static void
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
AsGridPointer p_as_grid_splitk;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
|
||||
splitk_batch_offset.a_k_split_offset[i];
|
||||
});
|
||||
|
||||
// shift B matrices pointer for splitk
|
||||
BsGridPointer p_bs_grid_splitk;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
|
||||
splitk_batch_offset.b_k_split_offset[i];
|
||||
});
|
||||
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
|
||||
p_as_grid_splitk,
|
||||
p_bs_grid_splitk,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
|
||||
p_shared,
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
@@ -61,8 +62,8 @@ template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
@@ -119,6 +120,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
static constexpr index_t NumATensor = AsDataType::Size();
|
||||
static constexpr index_t NumBTensor = BsDataType::Size();
|
||||
|
||||
using LDSTypeA =
|
||||
typename std::conditional<(NumATensor > 1),
|
||||
ComputeTypeA,
|
||||
remove_cvref_t<tuple_element_t<0, AsDataType>>>::type;
|
||||
using LDSTypeB =
|
||||
typename std::conditional<(NumBTensor > 1),
|
||||
ComputeTypeB,
|
||||
remove_cvref_t<tuple_element_t<0, BsDataType>>>::type;
|
||||
|
||||
static constexpr auto EShuffleBlockTransferScalarPerVector =
|
||||
CDEShuffleBlockTransferScalarPerVectors{}[I0];
|
||||
|
||||
@@ -136,14 +149,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
|
||||
if constexpr(is_same_v<remove_cvref_t<LDSTypeA>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
if constexpr(is_same_v<remove_cvref_t<LDSTypeB>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
@@ -230,6 +243,31 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto MakeAsGridPointer()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
|
||||
return static_cast<const ADataType_*>(nullptr);
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
}
|
||||
|
||||
static constexpr auto MakeBsGridPointer()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
|
||||
return static_cast<const BDataType_*>(nullptr);
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
}
|
||||
|
||||
using AsGridPointer = decltype(MakeAsGridPointer());
|
||||
using BsGridPointer = decltype(MakeBsGridPointer());
|
||||
|
||||
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
|
||||
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
|
||||
{
|
||||
@@ -314,6 +352,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeAsGridDescriptor_AK0_M_AK1(const index_t M,
|
||||
const index_t MPad,
|
||||
const index_t K,
|
||||
const index_t KPad,
|
||||
const std::array<index_t, NumATensor>& StrideAs,
|
||||
const index_t AK0)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeAGridDescriptor_AK0_M_AK1(M, MPad, K, KPad, StrideAs[i], AK0);
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
|
||||
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
|
||||
{
|
||||
@@ -330,7 +383,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
|
||||
static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
|
||||
static_assert(!(is_same_v<remove_cvref_t<LDSTypeB>, pk_i4_t> &&
|
||||
GemmSpec != GemmSpecialization::Default),
|
||||
"pk_i4_t does not support padding");
|
||||
|
||||
@@ -424,6 +477,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeBsGridDescriptor_BK0_N_BK1(const index_t K,
|
||||
const index_t KPad,
|
||||
const index_t N,
|
||||
const index_t NPad,
|
||||
const std::array<index_t, NumBTensor>& StrideBs,
|
||||
const index_t BK0)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeBGridDescriptor_BK0_N_BK1(K, KPad, N, NPad, StrideBs[i], BK0);
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
}
|
||||
|
||||
template <typename ABlockDesc_AK0_M_AK1>
|
||||
__host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&)
|
||||
{
|
||||
@@ -557,7 +625,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
// in some cases.
|
||||
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
|
||||
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize;
|
||||
constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
@@ -604,20 +672,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
constexpr auto KThreadRead = 64 / MPerWmma;
|
||||
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
|
||||
constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
|
||||
? 1
|
||||
: 128 / (AK1Number * M0 * sizeof(ADataType));
|
||||
: 128 / (AK1Number * M0 * sizeof(LDSTypeA));
|
||||
constexpr auto KThreadReadPerm =
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
|
||||
// 1<=mpair<=n0
|
||||
constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128)
|
||||
constexpr auto mpair = (AK1Number * MPerWmma * sizeof(LDSTypeA) > 128)
|
||||
? 1
|
||||
: ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0
|
||||
: ((128 / (AK1Number * MPerWmma * sizeof(LDSTypeA))) > M0
|
||||
? M0
|
||||
: 128 / (AK1Number * MPerWmma * sizeof(ADataType)));
|
||||
: 128 / (AK1Number * MPerWmma * sizeof(LDSTypeA)));
|
||||
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
@@ -694,7 +762,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
// NLdsLayer * K0 as logical Bank
|
||||
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
|
||||
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeB) / BPackedSize;
|
||||
constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
@@ -738,20 +806,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
constexpr auto KThreadRead = 64 / NPerWmma;
|
||||
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
|
||||
constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128)
|
||||
? 1
|
||||
: 128 / (BK1Number * N0 * sizeof(BDataType));
|
||||
: 128 / (BK1Number * N0 * sizeof(LDSTypeB));
|
||||
constexpr auto KThreadReadPerm =
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
|
||||
// 1<=npair<=n0
|
||||
constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128)
|
||||
constexpr auto npair = (BK1Number * NPerWmma * sizeof(LDSTypeB) > 128)
|
||||
? 1
|
||||
: ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0
|
||||
: ((128 / (BK1Number * NPerWmma * sizeof(LDSTypeB))) > N0
|
||||
? N0
|
||||
: 128 / (BK1Number * NPerWmma * sizeof(BDataType)));
|
||||
: 128 / (BK1Number * NPerWmma * sizeof(LDSTypeB)));
|
||||
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
@@ -836,8 +904,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
BlkGemmPipelineVer,
|
||||
BlkGemmPipeSched,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
LDSTypeA,
|
||||
LDSTypeB,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
AccDataType,
|
||||
@@ -1120,11 +1188,24 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
|
||||
.GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
|
||||
b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
|
||||
return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
|
||||
b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize),
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
template <index_t numElements, typename Type>
|
||||
__device__ __forceinline__ static auto get_first_element_workaround(Type& array)
|
||||
{
|
||||
if constexpr(numElements > 1)
|
||||
{
|
||||
return array;
|
||||
}
|
||||
else
|
||||
{
|
||||
return array[I0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -1133,13 +1214,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
__device__ static void Run(AsGridPointer p_as_grid,
|
||||
BsGridPointer p_bs_grid,
|
||||
DsGridPointer p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const AGridDesc_AK0_M_K1& as_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& bs_grid_desc_bk0_n_bk1,
|
||||
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
@@ -1152,10 +1233,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
const index_t& num_k_block_per_scale,
|
||||
BScaleStruct& b_scale_struct)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
const auto as_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
|
||||
const auto bs_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -1183,66 +1274,144 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
// workaround because v7r2 is not as general as v4r1
|
||||
auto get_a_blockwise_transfer = [&]() {
|
||||
if constexpr(NumATensor > 1)
|
||||
{
|
||||
const auto idx_as_block_begin = generate_tuple(
|
||||
[&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
|
||||
Number<NumATensor>{});
|
||||
|
||||
return ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
AsDataType,
|
||||
Tuple<LDSTypeA>,
|
||||
AGridDesc_AK0_M_K1,
|
||||
decltype(tie(a_block_desc_ak0_m_ak1)),
|
||||
AElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
uniform_sequence_gen_t<NumATensor, AThreadTransferSrcResetCoordinateAfterRun>,
|
||||
Sequence<true>,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1,
|
||||
idx_as_block_begin,
|
||||
tie(a_block_desc_ak0_m_ak1),
|
||||
make_tuple(make_multi_index(0, 0, 0)),
|
||||
a_element_op};
|
||||
}
|
||||
else
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_v4r1<
|
||||
ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
remove_cvref_t<tuple_element_t<0, AsDataType>>,
|
||||
remove_cvref_t<tuple_element_t<0, AsDataType>>,
|
||||
decltype(as_grid_desc_ak0_m_ak1[I0]),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
as_grid_desc_ak0_m_ak1[I0],
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
}
|
||||
};
|
||||
|
||||
auto a_blockwise_copy = get_a_blockwise_transfer();
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
// workaround because v7r2 is not as general as v4r1
|
||||
auto get_b_blockwise_transfer = [&]() {
|
||||
if constexpr(NumBTensor > 1)
|
||||
{
|
||||
const auto idx_bs_block_begin = generate_tuple(
|
||||
[&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
|
||||
Number<NumBTensor>{});
|
||||
|
||||
return ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
BsDataType,
|
||||
Tuple<LDSTypeB>,
|
||||
BGridDesc_BK0_N_K1,
|
||||
decltype(tie(b_block_desc_bk0_n_bk1)),
|
||||
BElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
uniform_sequence_gen_t<NumBTensor, BThreadTransferSrcResetCoordinateAfterRun>,
|
||||
Sequence<true>,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1,
|
||||
idx_bs_block_begin,
|
||||
tie(b_block_desc_bk0_n_bk1),
|
||||
make_tuple(make_multi_index(0, 0, 0)),
|
||||
b_element_op};
|
||||
}
|
||||
else
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_v4r1<
|
||||
ThisThreadBlock,
|
||||
BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
remove_cvref_t<tuple_element_t<0, BsDataType>>,
|
||||
remove_cvref_t<tuple_element_t<0, BsDataType>>,
|
||||
decltype(bs_grid_desc_bk0_n_bk1[I0]),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
bs_grid_desc_bk0_n_bk1[I0],
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
}
|
||||
};
|
||||
|
||||
auto b_blockwise_copy = get_b_blockwise_transfer();
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
@@ -1250,12 +1419,12 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
// Cast after lds
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
|
||||
sizeof(ADataType) /
|
||||
APackedSize),
|
||||
reinterpret_cast<LDSTypeB*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
|
||||
sizeof(LDSTypeA) /
|
||||
APackedSize),
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
|
||||
@@ -1267,25 +1436,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
(as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
b_scale_struct,
|
||||
num_k_block_main_loop,
|
||||
num_k_block_per_scale);
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
get_first_element_workaround<NumATensor>(as_grid_buf),
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
get_first_element_workaround<NumBTensor>(bs_grid_desc_bk0_n_bk1),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
get_first_element_workaround<NumBTensor>(bs_grid_buf),
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
b_scale_struct,
|
||||
num_k_block_main_loop,
|
||||
num_k_block_per_scale);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user