mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Wmma support for gemm_ab_scale (#3314)
* Support gemm_ab_scale: - Add tests - Integrate scaling implementation in multiple D - Generalize existing b_scale for ab_scale - Add instances - Generalize implementation for ScaleBlockM, ScaleBlockN, ScaleBlockK - Add support for all layouts supported by xdl - Fix splitk xdl * Fix copyright * Wmma support for gemm_blockscale_wp (#3315) * Support for preshuffle with ab scale - add support for b preshuffle in GridwiseGemm_wmma_cshuffle_v3_ab_scale - add support for AScaleLayout amnd BScaleLayout (can be different from ALayout and BLayout, respectively) - add Run method in v1 pipeline to support preshuffle + scaling - add support for preshuffle gemms in common invoker - Add splitk support * Fix copyright header
This commit is contained in:
@@ -105,6 +105,353 @@ struct DeviceGemmMultipleD_BlockScale_BPreshuffle : public BaseOperator
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename AScaleType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t StrideA,
|
||||
const ck::index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
const ck::index_t StrideE,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
index_t KBatch) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
};
|
||||
|
||||
/// @brief Wrapper for backward compatibility that allows to use instances of
|
||||
/// DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK in contexts where
|
||||
// DeviceGemmMultipleD_BlockScale_BPreshuffle is expected.
|
||||
///
|
||||
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
|
||||
/// The only difference between API of DeviceGemmMultipleD_BlockScale_BPreshuffle and
|
||||
// DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK is
|
||||
/// that DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK::MakeArgumentPointer requires
|
||||
// an additional parameter KBatch which is explicitly passed as 1 by this wrapper.
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename AScaleType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper
|
||||
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
AScaleType,
|
||||
BDataType,
|
||||
BScaleType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
AScaleType,
|
||||
BDataType,
|
||||
BScaleType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
|
||||
explicit DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper(std::unique_ptr<DeviceOp> p_op)
|
||||
: p_op_(std::move(p_op))
|
||||
{
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return p_op_->IsSupportedArgument(p_arg);
|
||||
}
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t StrideA,
|
||||
const ck::index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
const ck::index_t StrideE,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return p_op_->MakeArgumentPointer(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
p_a_scale,
|
||||
p_b_scale,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
1); // KBatch
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return p_op_->MakeInvokerPointer();
|
||||
}
|
||||
|
||||
int GetPreShuffleParameters() override { return p_op_->GetPreShuffleParameters(); }
|
||||
|
||||
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<DeviceOp> p_op_;
|
||||
|
||||
#endif // __HIPCC_RTC__
|
||||
};
|
||||
|
||||
// GEMM:
|
||||
// input : A[M, K], B[K, N],
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename AScaleType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleD_ABScaleSplitK : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t StrideA,
|
||||
const ck::index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
const ck::index_t StrideE,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
index_t KBatch) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual void SetKBatch(BaseArgument* arg, int KBatch) const = 0;
|
||||
};
|
||||
|
||||
/// @brief Wrapper for backward compatibility that allows to use instances of
|
||||
/// DeviceGemmMultipleD_ABScaleSplitK in contexts where DeviceGemmMultipleD_ABScale is
|
||||
/// expected.
|
||||
///
|
||||
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
|
||||
/// The only difference between API of DeviceGemmMultipleD_ABScale and
|
||||
/// DeviceGemmMultipleD_ABScaleSplitK is that
|
||||
/// DeviceGemmMultipleD_ABScaleSplitK::MakeArgumentPointer requires a additional parameter
|
||||
/// KBatch which is explicitly passed as 1 by this wrapper.
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename AScaleType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleD_ABScaleSplitKWrapper
|
||||
: public DeviceGemmMultipleD_ABScale<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
AScaleType,
|
||||
BDataType,
|
||||
BScaleType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
|
||||
using DeviceOp = DeviceGemmMultipleD_ABScaleSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
AScaleType,
|
||||
BDataType,
|
||||
BScaleType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
|
||||
explicit DeviceGemmMultipleD_ABScaleSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
|
||||
: p_op_(std::move(p_op))
|
||||
{
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return p_op_->IsSupportedArgument(p_arg);
|
||||
}
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t StrideA,
|
||||
const ck::index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
const ck::index_t StrideE,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return p_op_->MakeArgumentPointer(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
p_a_scale,
|
||||
p_b_scale,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
1); // KBatch
|
||||
}
|
||||
|
||||
void SetKBatch(BaseArgument* arg, int KBatch) const override { p_op_->SetKBatch(arg, KBatch); }
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return p_op_->MakeInvokerPointer();
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<DeviceOp> p_op_;
|
||||
|
||||
#endif // __HIPCC_RTC__
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
@@ -93,7 +93,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
p_bs_grid_shift,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
|
||||
karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_b_k_split_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
@@ -315,12 +316,13 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
|
||||
};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
Tuple<ADataType>,
|
||||
void, // AScaleType
|
||||
Tuple<BDataType>,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
@@ -332,6 +334,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
0, // ScaleBlockM
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
@@ -405,7 +408,9 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
|
||||
std::array<index_t, 1>{StrideB_},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC_,
|
||||
0, // StrideScaleA
|
||||
StrideScaleB_,
|
||||
nullptr,
|
||||
p_b_scale_grid_,
|
||||
k_batch_,
|
||||
a_element_op_,
|
||||
|
||||
@@ -0,0 +1,362 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockM, // scale block for M
|
||||
index_t ScaleBlockN, // scale block for N
|
||||
index_t ScaleBlockK, // scale block for K
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3
|
||||
: public DeviceGemmMultipleD_ABScaleSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
Tuple<ADataType>,
|
||||
AScaleDataType,
|
||||
Tuple<BDataType>,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
using DeviceGemmCommon =
|
||||
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
BlockSize,
|
||||
AK1,
|
||||
BK1,
|
||||
GemmSpec,
|
||||
CShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
// Invoker
|
||||
using Invoker = typename DeviceGemmCommon::Invoker;
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// with splitk the implementation doesn't work
|
||||
// when KRead % ScaleBlockK != 0, independently of K padding
|
||||
if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return DeviceGemmCommon::IsSupportedArgument(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
void SetKBatch(BaseArgument* base_arg, int KBatch) const override
|
||||
{
|
||||
auto& arg = *dynamic_cast<Argument*>(base_arg);
|
||||
arg.KBatch = KBatch;
|
||||
arg.KRead = GridwiseGemm::CalculateKRead(arg.K, KBatch);
|
||||
arg.KPadded = GridwiseGemm::CalculateKPadded(arg.K, KBatch);
|
||||
arg.AK0 = GridwiseGemm::CalculateAK0Padded(arg.K, KBatch);
|
||||
arg.BK0 = GridwiseGemm::CalculateBK0Padded(arg.K, KBatch);
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
const BScaleDataType* p_a_scale,
|
||||
const BScaleDataType* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation cde_element_op,
|
||||
index_t KBatch = 1)
|
||||
{
|
||||
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(M, ScaleBlockM);
|
||||
|
||||
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(N, ScaleBlockN);
|
||||
|
||||
return Argument{std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
p_ds,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
StrideDs,
|
||||
StrideC,
|
||||
StrideScaleA,
|
||||
StrideScaleB,
|
||||
p_a_scale,
|
||||
p_b_scale,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t KBatch = 1) override
|
||||
{
|
||||
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(M, ScaleBlockM);
|
||||
|
||||
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(N, ScaleBlockN);
|
||||
|
||||
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
StrideDs,
|
||||
StrideC,
|
||||
StrideScaleA,
|
||||
StrideScaleB,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemm_ABScale_Wmma_CShuffleV3"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerWmma<<"x"<<NPerWmma << ", "
|
||||
<< "WaveMap: "
|
||||
<< MRepeat<<"x" << NRepeat<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "KPack: "
|
||||
<< GridwiseGemm::KPack;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -18,52 +18,6 @@
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<e_data_type, ck::half_t> ||
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -202,270 +156,14 @@ struct DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
ComputeTypeB,
|
||||
true>; // IsBPreshuffle
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
/// @brief This function issues GPU kernel execution.
|
||||
/// @param arg The GPU kernel arguments.
|
||||
/// @param stream_config The HIP stream configuration helper structure.
|
||||
/// @return The kernel's average execution time (if time measurement is
|
||||
/// enabled).
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
|
||||
|
||||
std::array<std::size_t, 1> size_as_buffers;
|
||||
size_as_buffers[Number<0>{}] =
|
||||
a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
|
||||
sizeof(ADataType) / GridwiseGemm::APackedSize;
|
||||
|
||||
std::array<std::size_t, 1> size_bs_buffers;
|
||||
size_bs_buffers[Number<0>{}] =
|
||||
b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
|
||||
sizeof(BDataType) / GridwiseGemm::BPackedSize;
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
std::array<std::size_t, GridwiseGemm::NumDTensor> size_ds_buffers;
|
||||
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
size_ds_buffers[i] =
|
||||
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
|
||||
ck::utility::RotatingMemWrapperMultiABD<Argument,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
DsDataType>
|
||||
rotating_mem(arg_,
|
||||
stream_config.rotating_count,
|
||||
size_as_buffers,
|
||||
size_bs_buffers,
|
||||
size_ds_buffers);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy = []() {
|
||||
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
// ThreadwiseTensorSliceTransfer_v7r3 (used in ThreadGroupTensorSliceTransfer_v7r3) is
|
||||
// currently implemented in such a way that all SrcScalarPerVectors must be the same, so
|
||||
// if one of D matrices is column-major, then all SrcScalarPerVectors must be 1. On the
|
||||
// other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot
|
||||
// be odd.
|
||||
constexpr bool AtomicsImplementationExists =
|
||||
!(std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t> ||
|
||||
std::is_same_v<EDataType, int8_t>) ||
|
||||
(CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0);
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
using Invoker = typename DeviceGemmCommon::Invoker;
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return DeviceGemmCommon::IsSupportedArgument(arg);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,360 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockM, // scale block for M
|
||||
index_t ScaleBlockN, // scale block for N
|
||||
index_t ScaleBlockK, // scale block for K
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle
|
||||
: public DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using AScaleLayout = tensor_layout::gemm::ColumnMajor;
|
||||
using BScaleLayout = BLayout;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
Tuple<ADataType>,
|
||||
AScaleDataType,
|
||||
Tuple<BDataType>,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB,
|
||||
true,
|
||||
AScaleLayout,
|
||||
BScaleLayout>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
int GetPreShuffleParameters() override { return NPerWmma; }
|
||||
|
||||
using DeviceGemmCommon =
|
||||
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
BlockSize,
|
||||
AK1,
|
||||
BK1,
|
||||
GemmSpec,
|
||||
CShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
true>; // IsBPreshuffle
|
||||
|
||||
// Invoker
|
||||
using Invoker = typename DeviceGemmCommon::Invoker;
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// with splitk the implementation doesn't work
|
||||
// when KRead % ScaleBlockK != 0, independently of K padding
|
||||
if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return DeviceGemmCommon::IsSupportedArgument(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation cde_element_op,
|
||||
index_t KBatch)
|
||||
{
|
||||
index_t StrideScaleA = ck::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(M, ScaleBlockM);
|
||||
|
||||
index_t StrideScaleB = ck::is_same_v<BScaleLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(N, ScaleBlockN);
|
||||
|
||||
return Argument{std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
StrideDs,
|
||||
StrideC,
|
||||
StrideScaleA,
|
||||
StrideScaleB,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t KBatch) override
|
||||
{
|
||||
index_t StrideScaleA = ck::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(M, ScaleBlockM);
|
||||
|
||||
index_t StrideScaleB = ck::is_same_v<BScaleLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(N, ScaleBlockN);
|
||||
|
||||
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
StrideDs,
|
||||
StrideC,
|
||||
StrideScaleA,
|
||||
StrideScaleB,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerWmma<<"x"<<NPerWmma << ", "
|
||||
<< "WaveMap: "
|
||||
<< MRepeat<<"x" << NRepeat<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "KPack: "
|
||||
<< GridwiseGemm::KPack;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -262,6 +262,16 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
@@ -279,22 +289,47 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -315,6 +350,20 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
{
|
||||
auto& arg = *dynamic_cast<Argument*>(base_arg);
|
||||
arg.KBatch = KBatch;
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
arg.KRead = GridwiseGemm64::CalculateKRead(arg.K, KBatch);
|
||||
arg.KPadded = GridwiseGemm64::CalculateKPadded(arg.K, KBatch);
|
||||
arg.AK0 = GridwiseGemm64::CalculateAK0Padded(arg.K, KBatch);
|
||||
arg.BK0 = GridwiseGemm64::CalculateBK0Padded(arg.K, KBatch);
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.KRead = GridwiseGemm32::CalculateKRead(arg.K, KBatch);
|
||||
arg.KPadded = GridwiseGemm32::CalculateKPadded(arg.K, KBatch);
|
||||
arg.AK0 = GridwiseGemm32::CalculateAK0Padded(arg.K, KBatch);
|
||||
arg.BK0 = GridwiseGemm32::CalculateBK0Padded(arg.K, KBatch);
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
@@ -325,6 +374,13 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// with splitk the implementation doesn't work
|
||||
// when KRead % ScaleBlockK != 0, independently of K padding
|
||||
if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
|
||||
{
|
||||
return false;
|
||||
@@ -385,6 +441,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(M, ScaleBlockM);
|
||||
|
||||
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(N, ScaleBlockN);
|
||||
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
@@ -396,6 +460,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
StrideScaleA,
|
||||
StrideScaleB,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
1,
|
||||
@@ -425,6 +491,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(M, ScaleBlockM);
|
||||
|
||||
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(N, ScaleBlockN);
|
||||
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
@@ -436,6 +510,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
StrideScaleA,
|
||||
StrideScaleB,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
1,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
@@ -86,12 +86,13 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
{
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
Tuple<ADataType>,
|
||||
void, // AScaleType
|
||||
Tuple<BDataType>,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
@@ -103,6 +104,7 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
0, // ScaleBlockM
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
@@ -207,7 +209,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
std::array<index_t, 1>{StrideB},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
0, // StrideScaleA
|
||||
StrideScaleB,
|
||||
nullptr,
|
||||
p_b_scale,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
@@ -245,7 +249,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
std::array<index_t, 1>{StrideB},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
0, // StrideScaleA
|
||||
StrideScaleB,
|
||||
nullptr, // p_a_scale
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
|
||||
@@ -38,7 +38,8 @@ template <typename GridwiseGemm,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA,
|
||||
typename ComputeTypeB>
|
||||
typename ComputeTypeB,
|
||||
bool IsBPreShuffled = false>
|
||||
struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
{
|
||||
|
||||
@@ -189,61 +190,174 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
if constexpr(IsBPreShuffled)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: Implement
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
const auto kernel = kernel_gemm_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(IsBPreShuffled)
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
const auto kernel = kernel_gemm_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -299,6 +413,14 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(IsBPreShuffled)
|
||||
{
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user