mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
Universal gemm splitk using reduce (with multi-d) (#1341)
* init for reduce_threadwise multi_d
* add reduce_threadwise_multi_d
* add reduce_multi_d
* clean
* start add an other splitk device op
* add reduce template parameter to SplitKBatchOffset
* add reduce c matrix
* clean up code
* change example data type to bf16
* add bf16Ai8B example
* remove reduce template parameter
* add splitk atomic status to v4
* example add multi d parameters
* device op add multi-d parameters
* add multi-d to reduce
* fix kbach=1 bug
* change B layout to col in bf16Ai8B example
* remove float adding struct
* change multi-d interface
* change file and class name
* remove multi-d of bf16Ai8B example
* change IsReduce function to IsReduceAdd
* change example layout to RRR from RCR
* according layout to set ds stride
* reset parameter layout
* add gemm universal reduce instance
* add reduce factory
* add profile_gemm_universal_reduce
* add reduce to profiler
* fix reduce instance
* fix profiler reduce compiling bug
* format
* format library instance code
* add mem instance for reduce library
* fix call instance names
* add workspace for reduce in ckProfiler
* format
* add mnpading to reduce library instance
* add fp16 instance to reduce of profiler
* change copyright time
* restore profiler cmake file
* add reduce text to instances
* add DsLayout and DsDataType to instances template parameter
* fixed gemm_reduce_multi_d
* add an example without multi_d
* Update common.hpp
* Update gtest.cmake
* Update gemm_xdl_splitk_reduce_bf16.cpp
* clean
* Update gtest.cmake
* format
* fixe api
* format
* default parameter change to RRR
* add vector_len for multi_d
* format
* Update gtest.cmake
* fix bf16A iBB elementwiseop
* add ReduceDataType
* move ReduceDataType to end position
* format
* remove googletest git method address
* fix copyright time
* update init data
---------
Co-authored-by: root <jizhan@amd.com>
Co-authored-by: letaoqin <letaoqin@amd.com>
Co-authored-by: Jing Zhang <jizhan@meta.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: c544eb4da0]
This commit is contained in:
@@ -38,6 +38,41 @@ struct DeviceGemmV2 : public BaseOperator
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmV2R1 : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> DsStrides,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t KSplit,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceReduceMultiD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
|
||||
const std::array<index_t, Rank> inStrides,
|
||||
const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths,
|
||||
const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides,
|
||||
const std::array<index_t, NumOutDim> outLengths,
|
||||
const std::array<index_t, NumOutDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
const void* in_dev,
|
||||
const std::array<const void*, NumDTensor> ds_dev,
|
||||
void* out_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const OutElementwiseOperation out_elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceReduceMultiDPtr = std::unique_ptr<DeviceReduceMultiD<InDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
OutElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,703 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ReduceDataType = CDataType,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
ReduceDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
PassThrough,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
struct Argument : public GridwiseGemm::Argument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
const std::array<const void*, NumDTensor> p_ds_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideC_,
|
||||
index_t k_batch_)
|
||||
: GridwiseGemm::Argument(p_a_grid_,
|
||||
p_b_grid_,
|
||||
reinterpret_cast<ReduceDataType*>(p_c_grid_),
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideC_,
|
||||
k_batch_,
|
||||
true),
|
||||
p_ds(p_ds_),
|
||||
StrideDs(StrideDs_)
|
||||
{
|
||||
}
|
||||
|
||||
const std::array<const void*, NumDTensor> p_ds;
|
||||
std::array<ck::index_t, NumDTensor> StrideDs;
|
||||
};
|
||||
|
||||
using ReduceAdd = ck::reduce::Add;
|
||||
using OutElementwiseOperation = CElementwiseOperation;
|
||||
|
||||
static constexpr auto DsVectorLengthSequence = generate_sequence_v2(
|
||||
[](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same<CLayout, DLayout>::value)
|
||||
return Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{};
|
||||
else
|
||||
return Number<1>{};
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
using DeviceReduceInstance = DeviceReduceThreadWiseMultiD<
|
||||
ReduceDataType, // InDataType,
|
||||
DsDataType, // DsDatatype
|
||||
GemmAccDataType, // AccDataType,
|
||||
CDataType, // OutDataType,
|
||||
3, // Rank
|
||||
1, // NumReduceDim
|
||||
ReduceAdd,
|
||||
PassThrough,
|
||||
OutElementwiseOperation,
|
||||
256, // BlockSize_,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_,
|
||||
1, // KThreadSliceSize_,
|
||||
0, // InSrcVectorDim_,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_
|
||||
decltype(DsVectorLengthSequence)>;
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
static constexpr index_t NumInDim = 3;
|
||||
static constexpr index_t NumOutDim = 2;
|
||||
|
||||
std::array<ck::index_t, NumInDim> in_lengths = {arg.KBatch, arg.M, arg.N};
|
||||
std::array<ck::index_t, NumOutDim> out_lengths = {arg.M, arg.N};
|
||||
|
||||
std::array<ck::index_t, NumInDim> in_strides;
|
||||
std::array<ck::index_t, NumOutDim> out_strides;
|
||||
if constexpr(std::is_same<CLayout, ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
in_strides = {arg.M * arg.N, arg.N, 1};
|
||||
out_strides = {arg.N, 1};
|
||||
}
|
||||
else
|
||||
{
|
||||
in_strides = {arg.M * arg.N, 1, arg.M};
|
||||
out_strides = {1, arg.M};
|
||||
}
|
||||
|
||||
std::array<int, 1> reduce_dims{0};
|
||||
|
||||
std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths;
|
||||
std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
DsLengths[i] = out_lengths;
|
||||
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same<DLayout, ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
DsStrides[i] = {arg.StrideDs[i], 1};
|
||||
}
|
||||
else
|
||||
{
|
||||
DsStrides[i] = {1, arg.StrideDs[i]};
|
||||
}
|
||||
});
|
||||
|
||||
auto reduce = DeviceReduceInstance{};
|
||||
|
||||
auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
|
||||
in_strides,
|
||||
DsLengths,
|
||||
DsStrides,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
reduce_dims,
|
||||
arg.p_workspace_,
|
||||
arg.p_ds,
|
||||
arg.p_c_grid,
|
||||
PassThrough{},
|
||||
OutElementwiseOperation{});
|
||||
|
||||
auto invoker_ptr = reduce.MakeInvokerPointer();
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(reduce.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The runtime parameters seems not supported by the device instance, exiting!");
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
auto arg = *dynamic_cast<const typename GridwiseGemm::Argument*>(&arg_);
|
||||
|
||||
if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
|
||||
std::is_same<CDataType, ReduceDataType>::value))
|
||||
{
|
||||
if(arg.p_workspace_ == nullptr)
|
||||
{
|
||||
throw std::runtime_error("using reduce , but empty workspace!");
|
||||
}
|
||||
|
||||
arg.p_c_grid = reinterpret_cast<ReduceDataType*>(arg.p_workspace_);
|
||||
}
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
|
||||
arg,
|
||||
stream_config.rotating_count,
|
||||
arg.M * arg.K * sizeof(ADataType),
|
||||
arg.K * arg.N * sizeof(BDataType));
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
// Tail number could be One to Seven
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::One>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
Run(kernel);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Two>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Three>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Four>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Five>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Six>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Seven>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tail number could be Odd or Even
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
|
||||
std::is_same<CDataType, ReduceDataType>::value))
|
||||
{
|
||||
// reduce c data
|
||||
ave_time += RunReduce(arg_, stream_config);
|
||||
}
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
const std::array<const void*, NumDTensor> p_ds,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
{
|
||||
return Argument{p_a, p_b, p_ds, p_c, M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmXdlUniversalReduce"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = *dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
|
||||
std::is_same<CDataType, ReduceDataType>::value))
|
||||
{
|
||||
std::cout << "using workspace" << std::endl;
|
||||
return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,412 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <array>
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
index_t BlockSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize,
|
||||
typename DsVectorSizeSequence>
|
||||
struct DeviceReduceThreadWiseMultiD : public DeviceReduceMultiD<InDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
|
||||
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
|
||||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
|
||||
(MThreadSliceSize % OutDstVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr index_t NumSrcDim = Rank;
|
||||
static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
|
||||
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
|
||||
|
||||
static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
|
||||
|
||||
static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
|
||||
const std::array<index_t, Rank>& inStrides)
|
||||
{
|
||||
const auto tupleSrcLengths =
|
||||
generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
|
||||
const auto tupleSrcStrides =
|
||||
generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
|
||||
|
||||
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
|
||||
|
||||
const auto in_grid_desc_m_k = [&]() {
|
||||
if constexpr(reduceAllDim)
|
||||
{
|
||||
const auto one_dim_inDesc = transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(tupleSrcLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return transform_tensor_descriptor(one_dim_inDesc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
1, one_dim_inDesc.GetLength(Number<0>{})))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
|
||||
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
|
||||
|
||||
const auto reduceDimLengths = generate_tuple(
|
||||
[&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
|
||||
const auto invariantDimLengths =
|
||||
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
inDesc,
|
||||
make_tuple(make_merge_transform(invariantDimLengths),
|
||||
make_merge_transform(reduceDimLengths)),
|
||||
make_tuple(InvariantDims{}, ReduceDims{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
|
||||
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
const auto inPad_M =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
const auto inPad_K =
|
||||
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
|
||||
|
||||
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
|
||||
in_grid_desc_m_k,
|
||||
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
|
||||
make_right_pad_transform(reduceLength, inPad_K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return (in_grid_desc_m_k_padded);
|
||||
};
|
||||
|
||||
static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
|
||||
const std::array<index_t, NumDstDim>& outStrides)
|
||||
{
|
||||
const auto tupleDstLengths =
|
||||
generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
|
||||
const auto tupleDstStrides =
|
||||
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
|
||||
|
||||
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
|
||||
|
||||
auto out_grid_desc_m = transform_tensor_descriptor(
|
||||
outDesc,
|
||||
make_tuple(make_merge_transform(tupleDstLengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
|
||||
|
||||
const auto outPad =
|
||||
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
|
||||
|
||||
auto out_grid_desc_m_padded = transform_tensor_descriptor(
|
||||
out_grid_desc_m,
|
||||
make_tuple(make_right_pad_transform(invariantLength, outPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return (out_grid_desc_m_padded);
|
||||
};
|
||||
|
||||
static auto
|
||||
MakeDsDescriptor(const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
|
||||
std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(DsLengths[i],
|
||||
DsStrides[i]);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
using InGridDesc_M_K = decltype(MakeSrc2dDescriptor({}, {}));
|
||||
using OutGridDesc_M = decltype(MakeDst1dDescriptor({}, {}));
|
||||
using DsGridDesc_M = decltype(MakeDsDescriptor({}, {}));
|
||||
|
||||
using GridwiseReduce =
|
||||
GridwiseReduction_mk_to_m_threadwise_multi_d<InDataType,
|
||||
DsDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
DsGridDesc_M,
|
||||
OutGridDesc_M,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
BlockSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
OutDstVectorSize,
|
||||
DsVectorSizeSequence>;
|
||||
|
||||
using DsGridPointer = typename GridwiseReduce::DsGridPointer;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, Rank> inLengths,
|
||||
const std::array<index_t, Rank> inStrides,
|
||||
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
|
||||
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides,
|
||||
const std::array<index_t, NumDstDim> outLengths,
|
||||
const std::array<index_t, NumDstDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
const InDataType* in_dev,
|
||||
const std::array<const void*, NumDTensor> ds_dev,
|
||||
OutDataType* out_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const OutElementwiseOperation out_elementwise_op)
|
||||
: DsLengths_{DsLengths},
|
||||
DsStrides_{DsStrides},
|
||||
outLengths_{outLengths},
|
||||
outStrides_{outStrides},
|
||||
in_dev_{in_dev},
|
||||
out_dev_{out_dev},
|
||||
in_elementwise_op_{in_elementwise_op},
|
||||
out_elementwise_op_{out_elementwise_op}
|
||||
{
|
||||
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
|
||||
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
|
||||
|
||||
std::tie(invariant_total_length, reduce_total_length) =
|
||||
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
|
||||
|
||||
if constexpr(NumInvariantDim == 0)
|
||||
invariant_lowest_length = 1;
|
||||
else
|
||||
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
|
||||
|
||||
reduce_lowest_length = inLengths_[Rank - 1];
|
||||
|
||||
numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
|
||||
|
||||
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
|
||||
M_BlockTileSize;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(ds_dev[i]);
|
||||
});
|
||||
|
||||
ds_grid_desc_m_ = MakeDsDescriptor(DsLengths, DsStrides);
|
||||
}
|
||||
|
||||
std::array<index_t, Rank> inLengths_;
|
||||
std::array<index_t, Rank> inStrides_;
|
||||
|
||||
std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths_;
|
||||
std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides_;
|
||||
|
||||
std::array<index_t, NumDstDim> outLengths_;
|
||||
std::array<index_t, NumDstDim> outStrides_;
|
||||
|
||||
const InDataType* in_dev_;
|
||||
OutDataType* out_dev_;
|
||||
|
||||
DsGridPointer p_ds_grid_;
|
||||
|
||||
InElementwiseOperation in_elementwise_op_;
|
||||
OutElementwiseOperation out_elementwise_op_;
|
||||
|
||||
DsGridDesc_M ds_grid_desc_m_;
|
||||
|
||||
index_t invariant_lowest_length;
|
||||
index_t reduce_lowest_length;
|
||||
long_index_t invariant_total_length;
|
||||
long_index_t reduce_total_length;
|
||||
|
||||
int numBlockTileIteration;
|
||||
size_t gridSize;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto in_grid_desc_m_k =
|
||||
DeviceReduceThreadWiseMultiD::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
|
||||
const auto out_grid_desc_m =
|
||||
DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
|
||||
|
||||
float avg_time = 0;
|
||||
|
||||
const auto kernel = kernel_reduce_threadwise_multi_d<GridwiseReduce,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
DsGridDesc_M,
|
||||
OutGridDesc_M,
|
||||
InElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
DsGridPointer>;
|
||||
|
||||
avg_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_m_k,
|
||||
arg.ds_grid_desc_m_,
|
||||
out_grid_desc_m,
|
||||
arg.in_elementwise_op_,
|
||||
arg.out_elementwise_op_,
|
||||
arg.in_dev_,
|
||||
arg.p_ds_grid_,
|
||||
arg.out_dev_);
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if constexpr(InSrcVectorDim == 0)
|
||||
{
|
||||
if constexpr(NumInvariantDim == 0)
|
||||
{
|
||||
return (false);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
|
||||
return (false);
|
||||
|
||||
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if(pArg->inStrides_[Rank - 1] != 1)
|
||||
return (false);
|
||||
|
||||
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
|
||||
return (false);
|
||||
};
|
||||
|
||||
// To improve
|
||||
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
|
||||
return (false);
|
||||
|
||||
std::cerr << "reduce_total_length = " << pArg->reduce_total_length
|
||||
<< " KThreadSliceSize = " << KThreadSliceSize << std::endl;
|
||||
|
||||
// cases with big reduce_total_length should be handled by Blockwise kernel
|
||||
if(pArg->reduce_total_length / KThreadSliceSize >= 32)
|
||||
return (false);
|
||||
|
||||
return (true);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
|
||||
const std::array<index_t, Rank> inStrides,
|
||||
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
|
||||
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides,
|
||||
const std::array<index_t, NumDstDim> outLengths,
|
||||
const std::array<index_t, NumDstDim> outStrides,
|
||||
const std::array<int, NumReduceDim> reduceDims,
|
||||
const void* in_dev,
|
||||
const std::array<const void*, NumDTensor> ds_dev,
|
||||
void* out_dev,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const OutElementwiseOperation out_elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(inLengths,
|
||||
inStrides,
|
||||
DsLengths,
|
||||
DsStrides,
|
||||
outLengths,
|
||||
outStrides,
|
||||
reduceDims,
|
||||
static_cast<const InDataType*>(in_dev),
|
||||
ds_dev,
|
||||
static_cast<OutDataType*>(out_dev),
|
||||
in_elementwise_op,
|
||||
out_elementwise_op);
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceReduceThreadWiseMultiD<" << BlockSize << ",";
|
||||
str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
|
||||
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,260 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/reduction_common.hpp"
|
||||
#include "ck/utility/reduction_operator.hpp"
|
||||
#include "ck/utility/reduction_functions_accumulate.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseReduction,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename DsGridDesc_M,
|
||||
typename OutGridDesc_M,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
typename DsGridPointer>
|
||||
__global__ void
|
||||
kernel_reduce_threadwise_multi_d(const InGridDesc_M_K in_grid_desc_m_k,
|
||||
const DsGridDesc_M ds_grid_desc_m,
|
||||
const OutGridDesc_M out_grid_desc_m,
|
||||
const InElementwiseOperation in_elementwise_op,
|
||||
const OutElementwiseOperation out_elementwise_op,
|
||||
const InDataType* const __restrict__ p_in_value_global,
|
||||
const DsGridPointer p_ds_value_global,
|
||||
OutDataType* const __restrict__ p_out_value_global)
|
||||
{
|
||||
GridwiseReduction::Run(in_grid_desc_m_k,
|
||||
ds_grid_desc_m,
|
||||
out_grid_desc_m,
|
||||
in_elementwise_op,
|
||||
out_elementwise_op,
|
||||
p_in_value_global,
|
||||
p_ds_value_global,
|
||||
p_out_value_global);
|
||||
}
|
||||
|
||||
template <typename InDataType,
|
||||
typename DsDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InGridDesc_M_K,
|
||||
typename DsGridDesc_M,
|
||||
typename OutGridDesc_M,
|
||||
typename ReduceOperation,
|
||||
typename InElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
InMemoryDataOperationEnum OutMemoryDataOperation,
|
||||
index_t BlockSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t InSrcVectorDim,
|
||||
index_t InSrcVectorSize,
|
||||
index_t OutDstVectorSize,
|
||||
typename DsVectorSize>
|
||||
struct GridwiseReduction_mk_to_m_threadwise_multi_d
|
||||
{
|
||||
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
|
||||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
|
||||
(MThreadSliceSize % OutDstVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
|
||||
static constexpr auto MakeDsGridPointer()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
return static_cast<const DDataType*>(nullptr);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
|
||||
const DsGridDesc_M& ds_grid_desc_m,
|
||||
const OutGridDesc_M& out_grid_desc_m,
|
||||
const InElementwiseOperation& in_elementwise_op,
|
||||
const OutElementwiseOperation& out_elementwise_op,
|
||||
const InDataType* const __restrict__ p_in_value_global,
|
||||
const DsGridPointer p_ds_grid,
|
||||
OutDataType* const __restrict__ p_out_value_global)
|
||||
{
|
||||
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
|
||||
ThreadReduceSrcDesc_M_K,
|
||||
ThreadReduceDstDesc_M,
|
||||
ReduceOperation,
|
||||
false>;
|
||||
|
||||
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
|
||||
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
ReduceOperation::template GetIdentityValue<InDataType>());
|
||||
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
in_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
|
||||
|
||||
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
|
||||
|
||||
auto threadwise_src_val_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<InDataType,
|
||||
AccDataType,
|
||||
InGridDesc_M_K,
|
||||
decltype(thread_buffer_desc),
|
||||
ThreadBufferLengths,
|
||||
ThreadBufferDimAccessOrder,
|
||||
InSrcVectorDim,
|
||||
InSrcVectorSize,
|
||||
1,
|
||||
false>(
|
||||
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
|
||||
|
||||
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
|
||||
|
||||
index_t reducedLength = 0;
|
||||
do
|
||||
{
|
||||
threadwise_src_val_load.Run(in_grid_desc_m_k,
|
||||
in_global_val_buf,
|
||||
thread_buffer_desc,
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op(in_thread_buf(Number<offset>{}),
|
||||
in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
|
||||
|
||||
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
|
||||
|
||||
reducedLength += KThreadSliceSize;
|
||||
} while(reducedLength < toReduceLength);
|
||||
|
||||
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
|
||||
|
||||
auto ds_thread_buf = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataTypePointer = remove_cvref_t<decltype(DsGridPointer{}[I])>;
|
||||
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
|
||||
|
||||
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MThreadSliceSize, true>{};
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
auto ds_global_buf = generate_tuple(
|
||||
[&](auto I) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ds_grid[I], ds_grid_desc_m[I].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
auto ds_global_load = generate_tuple(
|
||||
[&](auto I) {
|
||||
using DataTypePointer = remove_cvref_t<decltype(DsGridPointer{}[I])>;
|
||||
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
|
||||
|
||||
return ThreadwiseTensorSliceTransfer_v2<DataType,
|
||||
DataType,
|
||||
decltype(ds_grid_desc_m[I]),
|
||||
decltype(reduced_data_desc),
|
||||
Sequence<MThreadSliceSize>, // SliceLengths
|
||||
Sequence<0>, // DimAccessOrder
|
||||
InSrcVectorDim, // SrcVectorDim
|
||||
DsVectorSize{}[I],
|
||||
1, // SrcScalarStrideInVector
|
||||
true>{
|
||||
ds_grid_desc_m[I], make_multi_index(thread_global_1d_id * MThreadSliceSize)};
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto I) {
|
||||
ds_global_load(I).Run(ds_grid_desc_m[I],
|
||||
ds_global_buf[I],
|
||||
reduced_data_desc,
|
||||
make_tuple(I0),
|
||||
ds_thread_buf(I));
|
||||
});
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true> out_value_buf;
|
||||
|
||||
// if constexpr(NumDTensor > 0)
|
||||
{
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(accu_value_buf[I]),
|
||||
generate_tie(
|
||||
[&](auto Id) -> const auto& { return ds_thread_buf[Id][I]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
unpack2(out_elementwise_op, tie(out_value_buf(I)), c_ds_buf_refs);
|
||||
});
|
||||
}
|
||||
|
||||
auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<OutDataType,
|
||||
OutDataType,
|
||||
decltype(reduced_data_desc),
|
||||
OutGridDesc_M,
|
||||
PassThrough,
|
||||
Sequence<MThreadSliceSize>,
|
||||
Sequence<0>,
|
||||
0,
|
||||
OutDstVectorSize,
|
||||
OutMemoryDataOperation,
|
||||
1,
|
||||
false>(
|
||||
out_grid_desc_m,
|
||||
make_multi_index(thread_global_1d_id * MThreadSliceSize),
|
||||
PassThrough{});
|
||||
|
||||
threadwise_dst_store.Run(
|
||||
reduced_data_desc, make_tuple(I0), out_value_buf, out_grid_desc_m, dst_global_buf);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -42,7 +42,7 @@ __global__ void
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_c_grid,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared,
|
||||
karg);
|
||||
#else
|
||||
@@ -73,7 +73,7 @@ __global__ void
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_c_grid,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg);
|
||||
@@ -531,21 +531,35 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t k_batch_)
|
||||
index_t k_batch_,
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_}
|
||||
p_c_grid{p_c_grid_},
|
||||
is_reduce(is_reduce_)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ inline bool IsReduceAdd() const
|
||||
{
|
||||
return (Problem::KBatch > 1) && is_reduce;
|
||||
}
|
||||
|
||||
__host__ __device__ inline bool IsAtomicAdd() const
|
||||
{
|
||||
return (Problem::KBatch > 1) && (!is_reduce);
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid;
|
||||
const BDataType* p_b_grid;
|
||||
CDataType* p_c_grid;
|
||||
bool is_reduce;
|
||||
};
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
|
||||
__device__ SplitKBatchOffset(Argument& karg)
|
||||
{
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
@@ -574,10 +588,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
|
||||
}
|
||||
|
||||
if(karg.IsReduceAdd())
|
||||
{
|
||||
c_reduce_offset = blockIdx.z * karg.M * karg.N;
|
||||
}
|
||||
else
|
||||
{
|
||||
c_reduce_offset = 0;
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t c_reduce_offset;
|
||||
};
|
||||
|
||||
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -1080,16 +1104,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, float>::value))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
if(!karg.IsReduceAdd())
|
||||
{
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
if(karg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
if(karg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user