Mirror of https://github.com/ROCm/composable_kernel.git
Synced 2026-05-03 21:21:22 +00:00
initial stream-k implementation with example (#699)

* initial stream-k implementation with example
* fix unexpected change in err
* improve performance a bit by reorganizing the pipeline
* improve perf a bit by swizzling the block idx
* add profiler
* update example
* fix spelling
* shrink karg for stream-k
* support dynamic buffer using memory-coherence glc/slc bit from template
* control memory coherence while constructing the dynamic buffer
* update reduction for stream-k (not ready yet)
* add template parameter to make_dynamic_buffer to support amd_buffer coherence setting
* fix build issue
* fix several bugs
* now the result is correct, everything works (but has scratch)
* remove scratch by manually resetting the coordinate
* update device code
* fix a bug in final reduce
* fix something in example
* update async memset
* fix enum as camel case
* modify coherence enum name
* clean code and use atomic stream-k by default
* remove unused var
* throw exception on empty pointer
* fix format
* fix CI warning
* fix type in init
* modify CI error
* filter out on gfx10+
* restore changed example code

Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>
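For context before the diff: stream-k partitions the flattened K-loop iteration space of all output tiles evenly across a fixed number of persistent workgroups, instead of assigning whole tiles to workgroups. A minimal host-side sketch of that idea (hypothetical sizes; the commit's real logic lives in BlockToCTileMap_GemmStreamK below):

// Sketch of the stream-k split, not the library's actual API.
#include <cstdint>
#include <cstdio>

int main()
{
    // Assumed problem: 7 output tiles, each needing 12 K iterations,
    // executed by 4 persistent stream-k blocks.
    const uint32_t tiles = 7, iters_per_tile = 12, sk_blocks = 4;
    const uint32_t total_iters = tiles * iters_per_tile; // 84

    for(uint32_t b = 0; b < sk_blocks; ++b)
    {
        // Evenly split the flattened iteration space [0, total_iters).
        uint32_t start = total_iters * b / sk_blocks;
        uint32_t end   = total_iters * (b + 1) / sk_blocks;
        // A block may start or stop in the middle of a tile's K loop; the
        // partial results are then combined by atomics or a reduction pass.
        printf("block %u: iters [%u, %u) -> first tile %u, k-offset %u\n",
               b, start, end, start / iters_per_tile, start % iters_per_tile);
    }
}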
@@ -94,6 +94,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1
        }
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, index_t ThreadScratchId = 0>
    __device__ void RunRead(const SrcDesc& src_desc,
                            const SrcBuffer& src_buf,
@@ -0,0 +1,164 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp"

namespace ck {

// this version does the following things to avoid the scratch memory issue:
// 1. use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep a reference to the tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct a new tensor coordinate
template <typename ThreadGroup,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v6r1r2
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1r2(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
        const ElementwiseOperation& element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)

    {
        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.template Run<SrcBuffer, DstBuffer, DstInMemOp>(
                src_desc, src_buf, dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
        }
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_block_slice_origin)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r1r2<SrcData,
                                             DstData,
                                             SrcDesc,
                                             DstDesc,
                                             ElementwiseOperation,
                                             decltype(thread_slice_lengths),
                                             DimAccessOrder,
                                             VectorDim,
                                             ScalarPerVector,
                                             ThreadTransferSrcResetCoordinateAfterRun,
                                             ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
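The guard and origin arithmetic in the file above are the core of the block-to-thread work split: every thread whose id falls inside the thread cluster computes its multi-dimensional cluster index and offsets the block's slice origin by thread_cluster_idx * thread_slice_lengths. A standalone sketch of that mapping for a 2D tile (assumed sizes, plain integers instead of ck's compile-time descriptors):

// Illustration only; ck computes this with compile-time cluster descriptors.
#include <cstdint>
#include <cstdio>

int main()
{
    // Assumed 2D block tile of 8 x 32 elements, covered by a 4 x 8 thread
    // cluster; each thread therefore owns a 2 x 4 slice.
    const uint32_t cluster[2]      = {4, 8};
    const uint32_t thread_slice[2] = {2, 4};
    const uint32_t num_threads     = 64; // may exceed the 32-thread cluster

    for(uint32_t tid = 0; tid < num_threads; ++tid)
    {
        // Guard: only threads inside the cluster take part in the transfer.
        if(tid >= cluster[0] * cluster[1])
            continue;

        // Row-major equivalent of CalculateBottomIndex.
        uint32_t idx0 = tid / cluster[1];
        uint32_t idx1 = tid % cluster[1];

        // Per-thread slice origin relative to the block slice origin.
        printf("thread %2u -> slice origin (%u, %u)\n",
               tid, idx0 * thread_slice[0], idx1 * thread_slice[1]);
    }
}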
@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemmStreamK : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                              const void* p_b,
                                                              void* p_c,
                                                              ck::index_t M,
                                                              ck::index_t N,
                                                              ck::index_t K,
                                                              ck::index_t StrideA,
                                                              ck::index_t StrideB,
                                                              ck::index_t StrideC,
                                                              AElementwiseOperation a_element_op,
                                                              BElementwiseOperation b_element_op,
                                                              CElementwiseOperation c_element_op,
                                                              ck::index_t NumSKBlocks = 0) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceGemmStreamKPtr = std::unique_ptr<DeviceGemmStreamK<ALayout,
                                                               BLayout,
                                                               CLayout,
                                                               ADataType,
                                                               BDataType,
                                                               CDataType,
                                                               AElementwiseOperation,
                                                               BElementwiseOperation,
                                                               CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
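A hedged sketch of how this interface is typically driven from host code. The factory function, the Row layouts, and the PassThrough element ops are assumptions for illustration, not part of this header:

// Sketch under assumed types; follows the usual CK BaseOperator flow.
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using Row         = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Hypothetical factory standing in for any concrete DeviceGemmXdlStreamK
// instantiation; not provided by this header.
ck::tensor_operation::device::DeviceGemmStreamKPtr<Row, Row, Row, float, float, float,
                                                   PassThrough, PassThrough, PassThrough>
some_device_gemm_streamk_instance();

void run_gemm(const float* a, const float* b, float* c,
              ck::index_t M, ck::index_t N, ck::index_t K,
              ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    auto op = some_device_gemm_streamk_instance();

    auto arg = op->MakeArgumentPointer(a, b, c, M, N, K,
                                       StrideA, StrideB, StrideC,
                                       PassThrough{}, PassThrough{}, PassThrough{},
                                       /*NumSKBlocks=*/0); // 0: let the map decide

    auto invoker = op->MakeInvokerPointer();
    if(op->IsSupportedArgument(arg.get()))
        invoker->Run(arg.get(), StreamConfig{nullptr, /*time_kernel=*/true});
}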
@@ -0,0 +1,357 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t K1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          ck::index_t ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          ck::index_t BBlockLdsAddExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                                                       BLayout,
                                                       CLayout,
                                                       ADataType,
                                                       BDataType,
                                                       CDataType,
                                                       AElementwiseOperation,
                                                       BElementwiseOperation,
                                                       CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk<
        BlockSize,
        BlockToCTileMap_GemmStreamK<MPerBlock,
                                    NPerBlock,
                                    K0PerBlock * K1,
                                    StreamKReductionStrategy::Atomic>,
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CDataType,
        ALayout,
        BLayout,
        CLayout,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXDL,
        NPerXDL,
        K1,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CBlockTransferScalarPerVector_NWaveNPerXDL,
        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;

    using Argument = typename GridwiseGemm::Argument;

    // Invoker
    struct Invoker : public BaseInvoker
    {
        void Print(const Argument& karg) { karg.Print(); }

        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                Print(karg);
            }
            if(!GridwiseGemm::CheckValidity(karg))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk has invalid "
                    "setting");
            }

            dim3 grid_dims = karg.block_mapping.get_grid_dims();

            float ave_time = 0;

            const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;

            // TODO: remove clear buffer for streamk kernels
            if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
                         StreamKReductionStrategy::Atomic)
            {
                hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
                ave_time = launch_and_time_kernel(stream_config,
                                                  kernel,
                                                  grid_dims,
                                                  dim3(BlockSize),
                                                  0,
                                                  karg.p_a_grid,
                                                  karg.p_b_grid,
                                                  karg.p_c_grid,
                                                  karg.p_workspace_,
                                                  karg.M,
                                                  karg.N,
                                                  karg.K,
                                                  karg.StrideA,
                                                  karg.StrideB,
                                                  karg.StrideC,
                                                  karg.block_mapping);
            }
            else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
                              StreamKReductionStrategy::Reduction)
            {
                char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_) +
                                            karg.block_mapping.get_workspace_size_for_acc(
                                                sizeof(typename GridwiseGemm::FloatAcc));
                auto preprocess = [&]() {
                    hipGetErrorString(
                        hipMemsetAsync(workspace_semaphore,
                                       0,
                                       karg.block_mapping.get_workspace_size_for_semaphore(),
                                       stream_config.stream_id_));
                };

                ave_time = launch_and_time_kernel_with_preprocess(stream_config,
                                                                  preprocess,
                                                                  kernel,
                                                                  grid_dims,
                                                                  dim3(BlockSize),
                                                                  0,
                                                                  karg.p_a_grid,
                                                                  karg.p_b_grid,
                                                                  karg.p_c_grid,
                                                                  karg.p_workspace_,
                                                                  karg.M,
                                                                  karg.N,
                                                                  karg.K,
                                                                  karg.StrideA,
                                                                  karg.StrideB,
                                                                  karg.StrideC,
                                                                  karg.block_mapping);
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
    {
        const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
        if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
                     StreamKReductionStrategy::Reduction)
        {
            return p_arg->block_mapping.get_workspace_size(sizeof(typename GridwiseGemm::FloatAcc));
        }
        else
        {
            return 0;
        }
    }

    void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);

        pArg_->p_workspace_ = p_workspace;
    }

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& karg)
    {
        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
             ck::get_device_name() == "gfx942"))
        {
            return false;
        }
        return GridwiseGemm::CheckValidity(karg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation,
                             uint32_t NumSKBlocks = 0xffffffff)
    {
        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
        int occupancy, num_cu;
        hipError_t rtn;
        rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
        hip_check_error(rtn);

        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        rtn = hipGetDevice(&dev);
        hip_check_error(rtn);
        rtn = hipGetDeviceProperties(&dev_prop, dev);
        hip_check_error(rtn);
        num_cu = dev_prop.multiProcessorCount;

        return Argument{p_a,
                        p_b,
                        p_c,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        static_cast<uint32_t>(num_cu),
                        static_cast<uint32_t>(occupancy),
                        NumSKBlocks};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation,
                                                      index_t NumSKBlocks = 0) override
    {
        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
        int occupancy, num_cu;
        hipError_t rtn;
        rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
        hip_check_error(rtn);

        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        rtn = hipGetDevice(&dev);
        hip_check_error(rtn);
        rtn = hipGetDeviceProperties(&dev_prop, dev);
        hip_check_error(rtn);
        num_cu = dev_prop.multiProcessorCount;

        return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
                                          reinterpret_cast<const BDataType*>(p_b),
                                          reinterpret_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          static_cast<uint32_t>(num_cu),
                                          static_cast<uint32_t>(occupancy),
                                          static_cast<uint32_t>(NumSKBlocks));
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
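When the Reduction strategy is selected, the caller owns the workspace: query its size, allocate device memory, and attach it before running. A hedged host-side sketch of that flow (op construction and include paths are assumptions; GetWorkSpaceSize/SetWorkSpacePointer are the BaseOperator virtuals overridden above):

// Sketch: drive a stream-k op that may need a reduction workspace.
#include <cstddef>
#include <hip/hip_runtime.h>
#include "ck/tensor_operation/gpu/device/device_base.hpp" // assumed path
#include "ck/host_utility/hip_check_error.hpp"

void run_with_workspace(ck::tensor_operation::device::BaseOperator& op,
                        ck::tensor_operation::device::BaseInvoker& invoker,
                        ck::tensor_operation::device::BaseArgument* arg)
{
    // Returns 0 for the Atomic strategy; acc-buffer + semaphore bytes otherwise.
    std::size_t ws_bytes = op.GetWorkSpaceSize(arg);

    void* p_ws = nullptr;
    if(ws_bytes != 0)
    {
        hip_check_error(hipMalloc(&p_ws, ws_bytes));
        op.SetWorkSpacePointer(arg, p_ws);
    }

    invoker.Run(arg, StreamConfig{nullptr, /*time_kernel=*/true});

    if(p_ws != nullptr)
        hip_check_error(hipFree(p_ws));
}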
@@ -7,6 +7,8 @@
#include "ck/utility/number.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include <limits>
#include <stdlib.h>

namespace ck {

@@ -669,4 +671,406 @@ struct BlockToCTileMap_3DGrid_KSplit
    }
};

enum StreamKReductionStrategy
{
    Atomic = 0, // sk blocks use atomics to do the reduction
    Reduction,  // let some workgroups be responsible for the reduction operation
};

template <uint32_t MPerBlock_,
          uint32_t NPerBlock_,
          uint32_t KPerBlock_,
          StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
          uint32_t TileSwizzleSubM_ = 8>
struct BlockToCTileMap_GemmStreamK
{
    static constexpr uint32_t min_k_iters_per_sk_block = 2;
    static constexpr uint32_t MPerBlock = MPerBlock_;
    static constexpr uint32_t NPerBlock = NPerBlock_;
    static constexpr uint32_t KPerBlock = KPerBlock_;
    static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
    static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;

    //--------------------------------------
    // pass to device
    uint32_t sk_num_blocks;
    uint32_t sk_num_big_blocks;
    uint32_t dp_start_block_idx;
    uint32_t reduction_start_block_idx;
    uint32_t k_iters_per_big_block;
    MDiv2 n_tiles;
    MDiv k_iters_per_tile;
    MDiv eqav_tiles_big;    // for reduction
    MDiv eqav_tiles_little; // for reduction

    // MDiv tile_swizzle_sub_m_rem;
    //--------------------------------------

    // prefer construct on host
    BlockToCTileMap_GemmStreamK(uint32_t m,
                                uint32_t n,
                                uint32_t k,
                                uint32_t num_cu,
                                uint32_t occupancy,
                                uint32_t sk_blocks = 0xffffffff)
    {
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

        // one CU can hold one workgroup at a time, from the whole chip's point of view.
        // if the number of workgroups equals num_cu, we call it 1 dispatch;
        // if the number of workgroups is 2x num_cu, we call it 2 dispatches.
        // one dispatch can deliver as many workgroups as num_cu (full dispatch), or fewer
        // than num_cu (partial dispatch)
        //
        uint32_t full_dispatches = num_tiles / num_cu;
        uint32_t full_dispatch_tiles = full_dispatches * num_cu;
        uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;

        uint32_t sk_occupancy = occupancy;
        uint32_t dp_tiles = full_dispatch_tiles;
        uint32_t sk_tiles = partial_dispatche_tiles;

        if(full_dispatches < occupancy)
        {
            // in this case, we allocate all blocks as sk blocks
            // sk_occupancy = occupancy - full_dispatches;
            sk_occupancy = 1; // TODO: single occ seems better
            dp_tiles = full_dispatch_tiles;
            sk_tiles = partial_dispatche_tiles;
        }
        else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
        {
            // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
            //      occupancy = 3, full_dispatches = 5, 8, 11 ...
            //      occupancy = 4, full_dispatches = 7, 11 ...
            sk_occupancy = 1; // leave 1 slot for sk occupancy
            dp_tiles = full_dispatch_tiles;
            sk_tiles = partial_dispatche_tiles;
        }
        else
        {
            // otherwise, take 1 dispatch from dp and combine it with the partial
            // dispatch to construct the sk dispatch
            sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
            dp_tiles = full_dispatch_tiles - num_cu;
            sk_tiles = partial_dispatche_tiles + num_cu;
        }

        // uint32_t dp_iters_per_block = k_iters_per_tile.get();
        uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
        uint32_t dp_num_blocks = 0;

        {
            uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
            uint32_t max_sk_tiles =
                (sk_tiles >= num_cu) ? num_cu * sk_occupancy
                                     : math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);

            // if we used dp blocks for the sk tiles, how many iters would we need
            uint32_t dp_for_sk_iters = k_iters_per_tile.get();

            uint32_t best_sk_score =
                std::numeric_limits<int>::max(); // we need to find the smallest sk iters
            for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
                tentative_sk_blocks++)
            {
                uint32_t tentative_sk_iters_per_block =
                    (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
                uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
                uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;

                // TODO: carefully adjust this parameter
                //       the more sk_blocks_per_tile, the worse the overhead
                uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
                if(tentative_sk_blocks % sk_tiles != 0)
                {
                    // penalty for uneven divide
                    cross_sk_blocks_overhead +=
                        sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
                }

                uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;

                if(tentative_sk_score < best_sk_score)
                {
                    best_sk_score = tentative_sk_score;
                    sk_num_blocks = tentative_sk_blocks;
                }
            }

            if(best_sk_score >= dp_for_sk_iters)
            {
                sk_num_blocks = 0;
            }

            // give the caller a chance to control the number of sk blocks
            sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;

            if(sk_num_blocks == 0)
            {
                sk_num_big_blocks = 0;
                k_iters_per_big_block = 0;

                dp_num_blocks = num_tiles; // all tiles become dp blocks
                dp_start_block_idx = 0;
                sk_total_iters = 0; // no sk iters remain
            }
            else
            {
                // k_iters_per_sk_block is the floor of the average number of iters
                // each sk block loops over; we need to decide how many iters each
                // sk block covers.
                // let m = k_iters_per_sk_block
                // some of the sk blocks (little) will cover m iters, some (big) will cover m+1
                // we have
                // 1) l + b = sk_blocks
                // 2) l * m + b * (m + 1) = sk_total_iters
                //    => (l + b) * m + b = sk_total_iters
                //    => sk_blocks * m + b = sk_total_iters
                //    => b = sk_total_iters - m * sk_blocks
                // NOTE: big could be zero
                uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
                sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
                k_iters_per_big_block = k_iters_per_sk_block + 1;

                dp_num_blocks = dp_tiles;
                dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
            }
        }
        n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
            uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
            eqav_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
            eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
        }

#if 0
        printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
               "sk_num_blocks:%d, "
               "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
               "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
               "sk_tiles:%u, workspace(acc float):%u\n",
               num_cu,
               occupancy,
               get_grid_dims().x,
               num_tiles,
               dp_tiles,
               sk_num_big_blocks,
               sk_num_blocks,
               sk_total_iters,
               dp_start_block_idx,
               dp_iters_per_block,
               dp_num_blocks,
               k_iters_per_tile.get(),
               k_iters_per_big_block,
               reduction_start_block_idx,
               get_sk_tiles(),
               get_workspace_size(sizeof(float)));
#endif
    }

    __host__ __device__ uint32_t get_sk_total_iters() const
    {
        uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
                                  (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
        return sk_total_iters;
    }

    __host__ __device__ uint32_t get_sk_tiles() const
    {
        // tiles for sk
        uint32_t sk_total_iters = get_sk_total_iters();
        return k_iters_per_tile.div(sk_total_iters);
    }

    __host__ __device__ dim3 get_grid_dims() const
    {
        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
        }
        else
            return dim3(reduction_start_block_idx, 1, 1);
    }

    __device__ uint32_t get_block_idx() const
    {
        // TODO: swizzle block index for better locality
        return __builtin_amdgcn_readfirstlane(blockIdx.x);
    }

    __device__ void
    get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
    {
        if(block_idx < sk_num_big_blocks)
        {
            iter_start = block_idx * k_iters_per_big_block;
            iter_end   = iter_start + k_iters_per_big_block;
        }
        else if(block_idx < sk_num_blocks)
        {
            iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
                         (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
            iter_end = iter_start + (k_iters_per_big_block - 1);
        }
        else if(block_idx >= dp_start_block_idx)
        {
            uint32_t sk_total_iters     = get_sk_total_iters();
            uint32_t dp_iters_per_block = k_iters_per_tile.get();
            iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
            iter_end   = iter_start + dp_iters_per_block;
        }
    }

    __device__ uint32_t get_current_iter_length(uint32_t iter_start,
                                                uint32_t iter_end,
                                                uint32_t total_iter_length) const
    {
        uint32_t iter_length_mod, iter_length_quo /*unused*/;
        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
        uint32_t current_iter_length = math::min(
            iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
        return current_iter_length;
    }

    __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }

    __device__ void
    get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
    {
        k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
    }

    __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
    {
        uint32_t m_tile_idx, n_tile_idx;
        uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
        n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);

        // swizzle tile
        uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);

        uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;

        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
                                     ? tile_swizzle_sub_m
                                     : tile_swizzle_sub_m_rem;

        uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
        m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
        m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;

        uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;

        uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;

        n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
        m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                          n_tile_idx_with_adapt);
    }

    __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
    {
        static constexpr uint32_t alignment = 128;
        uint32_t acc_buffer_bytes =
            MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
        return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
    }

    __host__ __device__ uint32_t get_workspace_size_for_semaphore() const
    {
        return get_sk_tiles() * sizeof(uint32_t);
    }

    __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
    {
        return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
    }

    __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
                                                        const MDiv& eqav_tiles_) const
    {
        uint32_t tile_idx_       = tiles_ == 0 ? 0 : (tiles_ - 1);
        uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
        uint32_t quo_, rem_;
        eqav_tiles_.divmod(tile_idx_, quo_, rem_);
        return quo_ * max_eqav_tiles_ + rem_;
    }

    __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
                                                          uint32_t iters_per_sk_block_) const
    {
        return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
                                    1);
    }

    __host__ __device__ uint32_t get_total_acc_buffers() const
    {
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        uint32_t tiles_cover_little_blocks =
            get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);

        uint32_t total_intersec_big =
            get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
        uint32_t total_intersec_little =
            get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);

        return sk_num_blocks + total_intersec_big + total_intersec_little;
    }

    __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
    {
        // TODO: from big to little
        uint32_t tiles_cover_big_blocks =
            get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
        if(tile_idx_ < tiles_cover_big_blocks)
        {
            uint32_t touched_sk_blocks =
                (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
                k_iters_per_big_block;
            uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
            return touched_sk_blocks + current_intersec;
        }
        else
        {
            uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
            uint32_t tile_idx_little_reverse   = get_sk_tiles() - tile_idx_;
            uint32_t touched_sk_blocks =
                (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
                iters_per_little_sk_block;
            uint32_t current_intersec =
                get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
            return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
        }
    }

    __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
    {
        uint32_t iters_per_big_sk_block    = k_iters_per_big_block;
        uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
        if(block_idx_ < sk_num_big_blocks)
        {
            uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
                                                          k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
            return block_idx_ + current_intersec;
        }
        else
        {
            uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
            uint32_t touched_tiles = k_iters_per_tile.div(
                block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
            return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
        }
    }
};

} // namespace ck

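The big/little split in the constructor above follows from the two linear constraints given in the comments: with m = sk_total_iters / sk_num_blocks, solving l + b = sk_num_blocks and l*m + b*(m+1) = sk_total_iters gives b = sk_total_iters - m*sk_num_blocks. A small standalone check of that identity (assumed numbers):

// Verify the big/little sk-block split derived in the comments above.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t sk_total_iters = 84, sk_num_blocks = 10;

    uint32_t m      = sk_total_iters / sk_num_blocks;     // 8 iters per little block
    uint32_t big    = sk_total_iters - m * sk_num_blocks; // 4 big blocks (m + 1 iters)
    uint32_t little = sk_num_blocks - big;                // 6 little blocks

    // l * m + b * (m + 1) must reproduce the total iteration count.
    assert(little * m + big * (m + 1) == sk_total_iters);
    printf("big=%u (x%u iters), little=%u (x%u iters)\n", big, m + 1, little, m);
}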
@@ -0,0 +1,89 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"

namespace ck {

struct GridwiseGemmPipeline_v3
{
    __host__ __device__ static constexpr bool IsSupported(index_t)
    {
        // TODO: improve applicability
        return true;
    }

    template <typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        // global read 0
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // LDS write 0
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        num_loop--;

        while(num_loop > 0)
        {
            // global read for the next iteration, issued before the math below
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
            block_sync_lds();
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

            // math on the LDS tile of the current iteration
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
            // LDS write for the next iteration
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            num_loop--;
        }
        // tail
        {
            block_sync_lds();
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }
};

} // namespace ck
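For intuition, Run above implements a single-LDS-buffer pipeline: iteration i's tile is consumed from LDS while iteration i+1's global read is already in flight, with two LDS fences guarding the shared buffer. A scalar host-side analogue of the same schedule (illustration only; no real overlap happens on a CPU):

#include <cstdio>
#include <vector>

int main()
{
    // K split into num_loop chunks; "global" holds the source data,
    // "lds" stands in for the shared-memory staging buffer.
    const int num_loop = 4, chunk = 2;
    std::vector<int> global(num_loop * chunk, 1);

    int lds[chunk], reg[chunk]; // staging buffer and in-flight read
    int acc = 0, it = 0;

    // prologue: global read 0, then LDS write 0
    for(int j = 0; j < chunk; ++j) reg[j] = global[j];
    for(int j = 0; j < chunk; ++j) lds[j] = reg[j];

    for(int remaining = num_loop - 1; remaining > 0; --remaining)
    {
        ++it;
        // global read i+1 (would overlap the math below on a GPU)
        for(int j = 0; j < chunk; ++j) reg[j] = global[it * chunk + j];
        // math on LDS tile i (block_sync_lds would precede this)
        for(int j = 0; j < chunk; ++j) acc += lds[j];
        // LDS write i+1 (block_sync_lds would precede this)
        for(int j = 0; j < chunk; ++j) lds[j] = reg[j];
    }
    // tail: math on the last tile
    for(int j = 0; j < chunk; ++j) acc += lds[j];

    printf("acc = %d (expected %d)\n", acc, num_loop * chunk);
}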
File diff suppressed because it is too large.
@@ -0,0 +1,213 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"

namespace ck {

// Do the following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. don't save a reference to a tensor descriptor in the class; pass the tensor
//    descriptor in as an argument instead
// 2. don't construct a new tensor coordinate every time it is used; update and reuse the
//    same tensor coordinate instead
// 3. don't use a pointer to a VGPR buffer; use a vector instead

// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool SrcResetCoordinateAfterRun,
          bool DstResetCoordinateAfterRun>
struct ThreadwiseTensorSliceTransfer_v6r1r2
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    static constexpr auto I0 = Number<0>{};

    __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2(
        const SrcDesc& src_desc,
        const Index& src_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_slice_origin,
        const ElementwiseOperation& element_op)
        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
          element_op_(element_op)
    {
        static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
                      "wrong! cannot evenly divide");
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
        src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
        dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

    template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        // loop over space-filling curve
        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, num_access, 1>{}([&](auto idx_1d) {
            using src_vector_type = vector_type_maker_t<SrcData, ScalarPerVector>;
            using src_vector_t    = typename src_vector_type::type;

            using dst_vector_type = vector_type_maker_t<DstData, ScalarPerVector>;
            using dst_vector_t    = typename dst_vector_type::type;

            const bool is_src_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);

            // copy data from src_buf into src_vector_container
            auto src_vector_container = src_vector_type{
                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};

            auto dst_vector_container = dst_vector_type{};

            // apply pointwise operation
            static_for<0, ScalarPerVector, 1>{}([&](auto i) {
                SrcData v;

                // apply element-wise operation
                element_op_(v, src_vector_container.template AsType<SrcData>()[i]);

                // apply type convert
                dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
            });

            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
            dst_buf.template Update<DstInMemOp, dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector_container.template AsType<dst_vector_t>()[I0]);

            // move coordinate
            if constexpr(idx_1d.value != num_access - 1)
            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
                move_tensor_coordinate(
                    src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
                move_tensor_coordinate(
                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
            }
        });

        // move coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
            const auto src_reset_step =
                make_tensor_coordinate_step(src_desc, GetCoordinateResetStep());

            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
        }

        if constexpr(DstResetCoordinateAfterRun)
        {
            const auto dst_reset_step =
                make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep());

            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
        }
    }

    __device__ static constexpr auto GetCoordinateResetStep()
    {
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
        if constexpr(num_access == 0)
        {
            return typename SpaceFillingCurve::Index{};
        }
        else
        {
            constexpr auto reset_step =
                SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});

            return reset_step;
        }
    }

    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                       const Index& src_slice_origin_step_idx)
    {
        // if the src coord was not reset by RunRead(), adjust the step here
        const auto adjusted_step_idx = SrcResetCoordinateAfterRun
                                           ? src_slice_origin_step_idx
                                           : src_slice_origin_step_idx + GetCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                       const Index& dst_slice_origin_step_idx)
    {
        // if the dst coord was not reset by Run(), adjust the step here
        const auto adjusted_step_idx = DstResetCoordinateAfterRun
                                           ? dst_slice_origin_step_idx
                                           : dst_slice_origin_step_idx + GetCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    private:
    SrcCoord src_coord_;
    DstCoord dst_coord_;
    const ElementwiseOperation element_op_;
};

} // namespace ck
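Run above walks the slice along a space-filling curve in units of ScalarPerVector, and GetCoordinateResetStep is simply the step from the last access back to the first. A tiny standalone model of that traversal for a 2D slice (a plain row-major curve is assumed for illustration; the real SpaceFillingCurve may snake to keep steps small):

#include <cstdio>

int main()
{
    // Assumed 2 x 8 slice, vector dim = 1, ScalarPerVector = 4:
    // each access moves a 4-wide vector, so the curve has 2 * 8 / 4 = 4 stops.
    const int len0 = 2, len1 = 8, spv = 4;
    const int num_access = len0 * (len1 / spv);

    int first[2] = {0, 0}, last[2] = {0, 0};
    for(int a = 0; a < num_access; ++a)
    {
        int idx0 = a / (len1 / spv);
        int idx1 = (a % (len1 / spv)) * spv;
        printf("access %d at (%d, %d)\n", a, idx0, idx1);
        if(a == num_access - 1) { last[0] = idx0; last[1] = idx1; }
    }
    // The reset step carries the coordinate from the last access back to the first.
    printf("reset step = (%d, %d)\n", first[0] - last[0], first[1] - last[1]);
}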