mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Merge remote-tracking branch 'origin/develop' into myamlak/cgemm
This commit is contained in:
@@ -109,6 +109,10 @@
|
||||
// experimental feature: use __builtin_memcpy instead of union to do bit_cast
|
||||
#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1
|
||||
|
||||
// experimental feature: optimize for inter-wave scheduling policy
|
||||
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING 0
|
||||
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS 1
|
||||
|
||||
// hack: have underlying assumption that need to be satsified, otherwise it's a bug
|
||||
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
|
||||
// thread-invariant, otherwise it's a bug
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
// "_PACKAGE_" to avoid name contentions: the macros like
|
||||
// HIP_VERSION_MAJOR are defined in HIP_VERSION.h.
|
||||
// clang-format off
|
||||
#define CK_HIP_PACKAGE_VERSION_MAJOR @CK_HIP_VERSION_MAJOR@
|
||||
#define CK_HIP_PACKAGE_VERSION_MINOR @CK_HIP_VERSION_MINOR@
|
||||
#define CK_HIP_PACKAGE_VERSION_PATCH @CK_HIP_VERSION_PATCH@
|
||||
// clang-format on
|
||||
|
||||
#ifndef CK_HIP_PACKAGE_VERSION_MAJOR
|
||||
#define CK_HIP_PACKAGE_VERSION_MAJOR 0
|
||||
#endif
|
||||
#ifndef CK_HIP_PACKAGE_VERSION_MINOR
|
||||
#define CK_HIP_PACKAGE_VERSION_MINOR 0
|
||||
#endif
|
||||
#ifndef CK_HIP_PACKAGE_VERSION_PATCH
|
||||
#define CK_HIP_PACKAGE_VERSION_PATCH 0
|
||||
#endif
|
||||
// 3 decimal digits for major and minor, 6 digits for patch number.
|
||||
// Max number is 999,999,999999 == 0xE8,D4A5,0FFF that fits into 64-bit math.
|
||||
#if CK_HIP_PACKAGE_VERSION_MAJOR > 999 || CK_HIP_PACKAGE_VERSION_MAJOR > 999 || \
|
||||
CK_HIP_PACKAGE_VERSION_PATCH > 999999
|
||||
#error "Too big HIP version number(s)"
|
||||
#endif
|
||||
#define CK_HIP_PACKAGE_VERSION_FLAT \
|
||||
((CK_HIP_PACKAGE_VERSION_MAJOR * 1000ULL + CK_HIP_PACKAGE_VERSION_MINOR) * 1000000 + \
|
||||
CK_HIP_PACKAGE_VERSION_PATCH)
|
||||
3
include/ck/options.hpp.in
Normal file
3
include/ck/options.hpp.in
Normal file
@@ -0,0 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
#cmakedefine01 CK_TIME_KERNEL
|
||||
10
include/ck/stream_config.hpp
Normal file
10
include/ck/stream_config.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
struct StreamConfig
|
||||
{
|
||||
hipStream_t stream_id_ = nullptr;
|
||||
bool time_kernel_ = false;
|
||||
};
|
||||
@@ -7,6 +7,21 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct LoopScheduler
|
||||
{
|
||||
Default,
|
||||
Interwave,
|
||||
};
|
||||
|
||||
constexpr LoopScheduler make_default_loop_scheduler()
|
||||
{
|
||||
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
|
||||
return LoopScheduler::Interwave;
|
||||
#else
|
||||
return LoopScheduler::Default;
|
||||
#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
@@ -302,7 +317,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
protected:
|
||||
// A[M0, M1, M2, KPerThread]
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
|
||||
@@ -339,4 +354,232 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
|
||||
};
|
||||
|
||||
// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro
|
||||
// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in
|
||||
// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the
|
||||
// default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename AK0MK1BlockDesc,
|
||||
typename BK0NK1BlockDesc,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
|
||||
struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
: public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
AK0MK1BlockDesc,
|
||||
BK0NK1BlockDesc,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
AK0MK1BlockDesc,
|
||||
BK0NK1BlockDesc,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
|
||||
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::A_K1;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
using Base::B_K1;
|
||||
using Base::c_thread_buf_;
|
||||
using Base::c_thread_desc_;
|
||||
using Base::CalculateAThreadOriginDataIndex;
|
||||
using Base::CalculateBThreadOriginDataIndex;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KPerThread;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
|
||||
|
||||
// 2-wave optimized blockwise gemm
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, k),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, k),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier();
|
||||
// NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, but except
|
||||
// the first, as we can shorten non-MAC cluster a bit and there's no observable negative
|
||||
// impact. The desired effect is waves in a workgroup executing MAC in sync. This avoids
|
||||
// some out-of-sync waves hijacking MAC resource from other workgroups and reducing the
|
||||
// chance of latency hiding by waiting for the rest of the workgroup at the eventual
|
||||
// sync point.
|
||||
if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
|
||||
{
|
||||
asm volatile("s_barrier" ::);
|
||||
__builtin_amdgcn_sched_barrier();
|
||||
}
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<FloatAB, KPack> a_thread_vec;
|
||||
vector_type<FloatAB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatAB>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, 0, 0, k_ + i))>{}];
|
||||
b_thread_vec.template AsType<FloatAB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, 0, 0, k_ + i))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// The block_sync_lds() here performs double duty:
|
||||
// A) safeguard against data hazard because barrier from blockwise_gemm is
|
||||
// moved here B) reduce VMEM FIFO congestion by applying small delays to
|
||||
// different wavefronts It is performed near the end of MAC cluster to
|
||||
// minimize lgkmcnt penalty
|
||||
if constexpr(k.value == KPerThread - KPerInnerLoop &&
|
||||
k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 &&
|
||||
n0.value == NRepeat - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier();
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier();
|
||||
}
|
||||
|
||||
// TODO: insert setprio in more precise manner since we
|
||||
// could have more than >1 MFMA instructions in single call
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier();
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_sched_barrier();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier();
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_sched_barrier();
|
||||
});
|
||||
}
|
||||
|
||||
protected:
|
||||
// A[M0, M1, M2, KPerInnerLoop]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
|
||||
|
||||
// B[N0, N1, N2, KPerInnerLoop]
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
|
||||
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
|
||||
|
||||
#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename AK0MK1BlockDesc,
|
||||
typename BK0NK1BlockDesc,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
LoopScheduler LoopSched>
|
||||
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
|
||||
{
|
||||
if constexpr(LoopSched == LoopScheduler::Default)
|
||||
{
|
||||
return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
AK0MK1BlockDesc,
|
||||
BK0NK1BlockDesc,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else if constexpr(LoopSched == LoopScheduler::Interwave)
|
||||
{
|
||||
return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
AK0MK1BlockDesc,
|
||||
BK0NK1BlockDesc,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
#ifndef DEVICE_BASE_HPP
|
||||
#define DEVICE_BASE_HPP
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "stream_config.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -22,7 +23,10 @@ struct BaseInvoker
|
||||
BaseInvoker(const BaseInvoker&) = default;
|
||||
BaseInvoker& operator=(const BaseInvoker&) = default;
|
||||
|
||||
virtual float Run(const BaseArgument*, int = 1) = 0;
|
||||
virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
|
||||
{
|
||||
return float{0};
|
||||
}
|
||||
|
||||
virtual ~BaseInvoker() {}
|
||||
};
|
||||
@@ -33,8 +37,8 @@ struct BaseOperator
|
||||
BaseOperator(const BaseOperator&) = default;
|
||||
BaseOperator& operator=(const BaseOperator&) = default;
|
||||
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) = 0;
|
||||
virtual std::string GetTypeString() const = 0;
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
|
||||
virtual std::string GetTypeString() const { return ""; }
|
||||
|
||||
virtual ~BaseOperator() {}
|
||||
};
|
||||
@@ -42,4 +46,3 @@ struct BaseOperator
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -106,6 +106,9 @@ __global__ void
|
||||
#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
|
||||
// version currently has compiler issues with register spill which further causes validation
|
||||
// failures.
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
@@ -154,7 +157,8 @@ template <typename ALayout,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
|
||||
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
@@ -600,7 +604,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
CReduceThreadClusterLengths_MPerBlock_NPerBlock,
|
||||
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>;
|
||||
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1));
|
||||
|
||||
@@ -688,7 +693,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int /* nrepeat */ = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
@@ -724,6 +729,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
float elapsed_time = 0.0f;
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1<
|
||||
@@ -743,26 +749,28 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.block_2_ctile_map_);
|
||||
elapsed_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -783,35 +791,38 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.block_2_ctile_map_);
|
||||
elapsed_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return 0;
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -428,7 +428,7 @@ struct DeviceBatchedGemmXdl
|
||||
{
|
||||
using Argument = DeviceBatchedGemmXdl::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
@@ -477,8 +477,8 @@ struct DeviceBatchedGemmXdl
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -511,8 +511,8 @@ struct DeviceBatchedGemmXdl
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -534,9 +534,10 @@ struct DeviceBatchedGemmXdl
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -55,7 +55,8 @@ template <typename ALayout,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
: public DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
@@ -376,7 +377,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock>;
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -448,7 +450,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
|
||||
@@ -478,146 +480,77 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
true>;
|
||||
|
||||
if(nrepeat == 0)
|
||||
{
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_c_grid_real_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_c_grid_real_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_real = c_real - aux needed here!!!
|
||||
// c_real = c_real - aux needed here!!!
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_c_grid_imag_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_c_grid_imag_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_imag = c_imag + aux needed here!!!
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time +=
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_c_grid_real_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// // c_real = c_real - aux needed here!!!
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_c_grid_imag_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_imag = c_imag + aux needed here!!!
|
||||
}
|
||||
// c_imag = c_imag + aux needed here!!!
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -634,155 +567,87 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
false>;
|
||||
|
||||
if(nrepeat == 0)
|
||||
{
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_c_grid_real_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_c_grid_real_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// // c_real = c_real - aux needed here!!!
|
||||
// // c_real = c_real - aux needed here!!!
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_c_grid_imag_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_c_grid_imag_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_imag = c_imag + aux needed here!!!
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time +=
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_c_grid_real_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_real = c_real - aux needed here!!!
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_c_grid_imag_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_imag = c_imag + aux needed here!!!
|
||||
}
|
||||
// c_imag = c_imag + aux needed here!!!
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -415,9 +415,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
ShowInfo(arg);
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
@@ -437,49 +438,27 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
float ave_time = 0;
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(nrepeat > 0)
|
||||
{
|
||||
ave_time =
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
hipGetErrorString(hipMemset(
|
||||
arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
|
||||
if(kbatch > 1 || nrepeat <= 0)
|
||||
{
|
||||
hipGetErrorString(hipMemset(
|
||||
arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
@@ -560,9 +539,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -531,7 +531,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
@@ -602,8 +602,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
true>;
|
||||
|
||||
ave_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -635,8 +635,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
false>;
|
||||
|
||||
ave_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -655,9 +655,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -642,7 +642,7 @@ struct
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
@@ -727,8 +727,8 @@ struct
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -771,8 +771,8 @@ struct
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -795,9 +795,10 @@ struct
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -605,7 +605,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
@@ -684,8 +684,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -723,8 +723,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -745,9 +745,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -568,7 +568,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
@@ -663,8 +663,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -697,8 +697,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -717,9 +717,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -450,7 +450,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
@@ -498,8 +498,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -529,8 +529,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -549,9 +549,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include "conv_fwd_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_conv_fwd.hpp"
|
||||
#include "common_header.hpp"
|
||||
@@ -92,7 +92,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto naive_conv3d_fwd =
|
||||
ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType,
|
||||
@@ -103,8 +103,8 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>;
|
||||
|
||||
float ave_time = launch_and_time_kernel(naive_conv3d_fwd,
|
||||
nrepeat,
|
||||
float ave_time = launch_and_time_kernel(stream_config,
|
||||
naive_conv3d_fwd,
|
||||
dim3(256),
|
||||
dim3(256),
|
||||
0,
|
||||
@@ -137,9 +137,10 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -438,7 +438,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
{
|
||||
std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
|
||||
@@ -487,8 +487,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
true>;
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -522,8 +522,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -547,9 +547,10 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1241,7 +1241,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
@@ -1316,8 +1316,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
|
||||
true>;
|
||||
|
||||
ave_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -1349,8 +1349,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
|
||||
false>;
|
||||
|
||||
ave_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -1369,9 +1369,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -747,7 +747,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
@@ -795,8 +795,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -826,8 +826,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -846,9 +846,10 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
|
||||
// version currently has compiler issues with register spill which further causes validation
|
||||
// failures.
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
@@ -62,7 +65,8 @@ template <typename ALayout,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
|
||||
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
@@ -422,7 +426,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
CReduceThreadClusterLengths_MPerBlock_NPerBlock,
|
||||
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>;
|
||||
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -498,7 +503,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int /* nrepeat */ = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
@@ -531,6 +536,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
float elapsed_time = 0.0f;
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
|
||||
@@ -549,24 +555,26 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
true>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
elapsed_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -586,33 +594,36 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
false>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
elapsed_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return 0;
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -290,7 +290,7 @@ struct DeviceGemmXdl
|
||||
{
|
||||
using Argument = DeviceGemmXdl::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
@@ -339,8 +339,8 @@ struct DeviceGemmXdl
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -370,8 +370,8 @@ struct DeviceGemmXdl
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -391,9 +391,10 @@ struct DeviceGemmXdl
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -264,7 +264,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
|
||||
{
|
||||
using Argument = DeviceGemmXdl_C_Shuffle_Bias_2d::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
@@ -320,8 +320,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -359,8 +359,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -382,9 +382,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -273,7 +273,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
@@ -329,8 +329,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -368,8 +368,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -391,9 +391,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -312,7 +312,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
@@ -374,8 +374,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -418,8 +418,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -443,9 +443,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
|
||||
// version currently has compiler issues with register spill which further causes validation
|
||||
// failures.
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
@@ -54,7 +57,8 @@ template <typename ALayout,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceGemm_Xdl_CShuffle
|
||||
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
@@ -375,7 +379,8 @@ struct DeviceGemm_Xdl_CShuffle
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock>;
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -435,7 +440,7 @@ struct DeviceGemm_Xdl_CShuffle
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
@@ -482,42 +487,22 @@ struct DeviceGemm_Xdl_CShuffle
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
true>;
|
||||
|
||||
if(nrepeat == 0)
|
||||
{
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -533,52 +518,32 @@ struct DeviceGemm_Xdl_CShuffle
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
false>;
|
||||
|
||||
if(nrepeat == 0)
|
||||
{
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -385,8 +385,11 @@ struct DeviceGemmXdlSplitK
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
ShowInfo(arg);
|
||||
|
||||
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
@@ -408,50 +411,30 @@ struct DeviceGemmXdlSplitK
|
||||
float ave_time = 0;
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(nrepeat > 0)
|
||||
{
|
||||
ShowInfo(arg);
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
// FIXME: this should be moved outside of DeviceOp
|
||||
hipGetErrorString(
|
||||
hipMemset(arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
|
||||
if(kbatch > 1 || nrepeat <= 0)
|
||||
{
|
||||
hipGetErrorString(
|
||||
hipMemset(arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
if(kbatch == 1)
|
||||
@@ -531,9 +514,10 @@ struct DeviceGemmXdlSplitK
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -391,8 +391,11 @@ struct DeviceGemmXdlSplitKCShuffle
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
ShowInfo(arg);
|
||||
|
||||
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
@@ -414,51 +417,29 @@ struct DeviceGemmXdlSplitKCShuffle
|
||||
float ave_time = 0;
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(nrepeat > 0)
|
||||
{
|
||||
ShowInfo(arg);
|
||||
ave_time =
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
hipGetErrorString(hipMemset(
|
||||
arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
|
||||
if(kbatch > 1 || nrepeat <= 0)
|
||||
{
|
||||
hipGetErrorString(hipMemset(
|
||||
arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
if(kbatch == 1)
|
||||
@@ -542,9 +523,10 @@ struct DeviceGemmXdlSplitKCShuffle
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -449,7 +449,7 @@ struct DeviceGroupedGemmXdl
|
||||
{
|
||||
using Argument = DeviceGroupedGemmXdl::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;
|
||||
|
||||
@@ -510,8 +510,8 @@ struct DeviceGroupedGemmXdl
|
||||
true,
|
||||
MaxGroupCount>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -534,8 +534,8 @@ struct DeviceGroupedGemmXdl
|
||||
false,
|
||||
MaxGroupCount>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -550,9 +550,10 @@ struct DeviceGroupedGemmXdl
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -204,7 +204,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
|
||||
OutDataType,
|
||||
@@ -241,8 +241,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
|
||||
|
||||
const index_t grid_size = (ReduceM / ReduceM_BlockTileSize);
|
||||
|
||||
return launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -257,9 +257,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
|
||||
arg.p_out_indices_dev_);
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -211,7 +211,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto in_grid_desc_m_k =
|
||||
DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
|
||||
@@ -253,8 +253,8 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
avg_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
avg_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -272,9 +272,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -182,7 +182,7 @@ struct DeviceReduceBlockWiseSecondCall
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_);
|
||||
@@ -224,8 +224,8 @@ struct DeviceReduceBlockWiseSecondCall
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
avg_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
avg_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -243,10 +243,11 @@ struct DeviceReduceBlockWiseSecondCall
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
};
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
|
||||
@@ -245,7 +245,7 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
|
||||
@@ -275,8 +275,6 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
|
||||
float avg_time = 0;
|
||||
|
||||
KernelTimer timer;
|
||||
|
||||
const auto kernel_pre = kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M>;
|
||||
const auto kernel_main = kernel_reduce_multiblock_atocmi_add<GridwiseReduce,
|
||||
InDataType,
|
||||
@@ -287,50 +285,38 @@ struct DeviceReduceMultiBlockAtomicAdd
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
printf("launch_and_time_kernel: grid_dim {%ld, 1, 1}, block_dim {%d, 1, 1} \n",
|
||||
arg.gridSize,
|
||||
BlockSize);
|
||||
printf("Warm up\n");
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_pre,
|
||||
dim3(arg.gridSize_pre),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
out_grid_desc_m,
|
||||
arg.out_dev_,
|
||||
static_cast<OutDataType>(0.0f));
|
||||
|
||||
for(int i = 0; i < nrepeat + 1; i++)
|
||||
{
|
||||
if(i == 1)
|
||||
timer.Start();
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_main,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
arg.in_elementwise_op_,
|
||||
arg.acc_elementwise_op_,
|
||||
arg.blkGroupSize,
|
||||
arg.kBlockTileIterations,
|
||||
arg.alpha_,
|
||||
arg.in_dev_,
|
||||
arg.out_dev_);
|
||||
|
||||
launch_kernel(kernel_pre,
|
||||
dim3(arg.gridSize_pre),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
out_grid_desc_m,
|
||||
arg.out_dev_,
|
||||
static_cast<OutDataType>(0.0f));
|
||||
return avg_time;
|
||||
}
|
||||
|
||||
launch_kernel(kernel_main,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
in_grid_desc_m_k,
|
||||
out_grid_desc_m,
|
||||
arg.in_elementwise_op_,
|
||||
arg.acc_elementwise_op_,
|
||||
arg.blkGroupSize,
|
||||
arg.kBlockTileIterations,
|
||||
arg.alpha_,
|
||||
arg.in_dev_,
|
||||
arg.out_dev_);
|
||||
};
|
||||
|
||||
timer.End();
|
||||
|
||||
avg_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
};
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
|
||||
@@ -273,7 +273,7 @@ struct DeviceReduceMultiBlockPartialReduce
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
|
||||
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
|
||||
@@ -313,8 +313,8 @@ struct DeviceReduceMultiBlockPartialReduce
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
avg_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
avg_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -331,10 +331,11 @@ struct DeviceReduceMultiBlockPartialReduce
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
};
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
|
||||
@@ -212,7 +212,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto in_grid_desc_m_k =
|
||||
DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
|
||||
@@ -254,8 +254,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
InElementwiseOperation,
|
||||
OutElementwiseOperation>;
|
||||
|
||||
avg_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
avg_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -272,10 +272,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
|
||||
return (avg_time);
|
||||
};
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
};
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -248,4 +249,116 @@ struct GridwiseGemmPipeline_v1<2>
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t NumPrefetch>
|
||||
struct GridwiseGemmPipelineInterwave_v1;
|
||||
|
||||
template <>
|
||||
struct GridwiseGemmPipelineInterwave_v1<1>
|
||||
{
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
|
||||
{
|
||||
return num_loop > 1;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
static __device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
// preload data into LDS
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
// block_sync_lds(); // moved into blockwise_gemm
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
++i;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Note: 2 stage prefetch not optimized for inter-wave loop scheduler
|
||||
template <>
|
||||
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t NumPrefetch, LoopScheduler LoopSched>
|
||||
constexpr auto GridwiseGemmPipeline_v1_Selector()
|
||||
{
|
||||
if constexpr(LoopSched == LoopScheduler::Default)
|
||||
{
|
||||
return GridwiseGemmPipeline_v1<NumPrefetch>{};
|
||||
}
|
||||
else if constexpr(LoopSched == LoopScheduler::Interwave)
|
||||
{
|
||||
return GridwiseGemmPipelineInterwave_v1<NumPrefetch>{};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -134,7 +134,8 @@ template <typename FloatAB,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
|
||||
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopScheduler LoopSched>
|
||||
struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -473,17 +474,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatGemmAcc,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack>{};
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatGemmAcc,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -502,25 +504,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
|
||||
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
|
||||
@@ -107,7 +107,8 @@ template <typename FloatAB,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -416,17 +417,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatGemmAcc,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack>{};
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatGemmAcc,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched>();
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -445,25 +447,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
|
||||
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user