Mirror of https://github.com/ROCm/composable_kernel.git
Synced 2026-05-04 05:31:24 +00:00
Navi3 rel (#1176)
* wmma_op + unit test
* add arch limitation to wmma test
* change arch limitation
* Refactor + add all-type unit test (int4 compile failed)
* Add f32_16x16x16_bf16 unit test
* tempsave
* tempsave
* tempsave
* runtime bug, cannot find symbol
* workaround for incorrect HIP warpSize return value
* debugging
* tempsave
* Correctness OK, waiting for optimization
* Tidy up + format
* temp save
* temp save, reproduce the v_bfi_b32 issue
* add inline asm for wmma op test
* tidy up
* clean some debug-purpose code
* discard some code
* clang format
* clang format
* compiler issue fixed + increase tile size
* navi3x_multipleD + example
* temp save
* workable
* batched gemm [OK], group conv [debug]
* group conv: sanity check [OK], performance [bad]
* navi3x_groupconv_need_optimization
* create necessary files
* save progress
* Add inter-row thread transfer
* save progress
* save debugging progress
* sanity check pass
* fix a host tensor bug and clean up flash-attention code
* format
* cancel unnecessary change
* cancel unnecessary change
* cancel unnecessary change
* temp save, add asm backend flag to amd_wmma
* Mat-A LDS bypass sanity pass
* temp save
* gemm sanity fix
* Porting new blockwise gemm to flash attention
* Example branch provided to compiler team
* tempsave
* Fix a bug
* batched gemm ported
* conv A-skip-LDS ported
* Skip-B-LDS real gemm
* Skip-B-LDS gemm + MulD
* batched gemm, conv, skip B LDS
* format
* Attn, skip B LDS
* Change GridwiseOp name
* fix a typo-caused bug
* Skip-A-LDS sanity pass; Skip-B-LDS scratch occurred
* Bug found, caused by intra-row permute off
* bug found
* a fix
* disable buffer load due to incorrect 3rd dword
* update fmha config, no scratch generated
* update 3rd dword
* fmha config update
* FMHA, add support for gfx1101/gfx1102
* Merge origin dev (#2)
* [Navi3x] Fix Gridwise_multiple_d operation (#649)
* Add CMake option "USE_OPT_NAVI3X"
* fix bug
* standardize docs (#655)
* Separate bibtex requirement from rocm-docs-core (#656)
* separate bibtex requirement from rocm-docs-core
* point requirements to source rocm-docs-core repo
* Add CMake option "USE_OPT_NAVI3X" (#647)
* Add CMake option "USE_OPT_NAVI3X"
* remove navi3x opt compile option from cmake script
* Conv + quantization + tanh (#645)
* Rename file, prepare to support another activation
* Add comment for quantization
* Extract out_elementop
* Add tanh example
* Add conv + bias + tanh quantization instance
* Add missing parameter
* Refine cmake
* Add external API and client example
* Extract variable in example
* Fix the comment

---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>

* Add a denorm test fix (#603)
* Add type_convert implementations for bf16
* Add the fix for conv_fwd
* Add the fix for conv_bwd_data
* Add the fix for conv_bwd_weight
* Format
* Format
* Another format
* Add a macro to use the workaround on MI200 only
* Format

---------
Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>

* simplify karg in device/grid of split-k op (#644)
* simplify karg in device/grid split-k op
* fix mk_kn_mn instances
* add more instances
* use name from tensor layout
* fix 3rd dword of buffer source descriptor (#659)
* add fp64 instances (#658)

Co-authored-by: root <root@ctr-ubbsmc15.amd.com>

* Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665)

This reverts commit bb5530af91.

* Groupnorm + swish external API (#668)
* Rename to proper naming
* Add example of groupnorm + swish
* Extract duplicate code in example
* Add groupnorm + swish instances
* Refactor instance generation, split into multiple .cpp files
* Add external API and client example
* Refine profiler message
* Use ck math version of exp
* Refine problem size in example
* Add host version of exp
* add a macro to turn on/off the denorm fix (off by default) (#673)
* add a macro to turn off the denorm fix by default
* expose the macro

---------
Co-authored-by: root <root@ctr-ubbsmc15.amd.com>

* fixed quant example (#672)

Co-authored-by: root <root@ctr-ubbsmc15.amd.com>

* Add dependabot config and pin rocm-docs-core (#663)
* [gtest] suppress unsafe buffer warning (#670); ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912
* Add memory index guard in wmma device ops (#667)
* Add more macros to turn on/off the denorm fix (#678)

Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com>

* Fix a typo (#676)
* Add (#677)
* Allow using ROCm release candidate compilers (#679)
* enable use of ROCm 5.5 release candidate 4
* upgrade to ROCm 5.5 RC5
* try to fix the PUB_KEY error, remove the cmake-data package
* upgrade to latest cmake version
* use private dockerhub repo for ROCm 5.5 RC5
* add missing bracket
* add vector load check
* solve conflicts

---------
Co-authored-by: Sam Wu <sjwu@ualberta.ca>
Co-authored-by: Sam Wu <sam.wu2@amd.com>
Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: root <root@ctr-ubbsmc15.amd.com>
Co-authored-by: Jun Liu <Liu.Jun@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

* Disable SkipLDS & align AIT API (#3)
* fix layernorm, reduction ops (#4)
* Disable SkipLDS & align AIT API
* Update dependabot config (#682)

Co-authored-by: samjwu <samjwu@users.noreply.github.com>

* update attn API
* solve type_convert bug + enable

---------
Co-authored-by: Sam Wu <sjwu@ualberta.ca>
Co-authored-by: Sam Wu <sam.wu2@amd.com>
Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: root <root@ctr-ubbsmc15.amd.com>
Co-authored-by: Jun Liu <Liu.Jun@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: samjwu <samjwu@users.noreply.github.com>
Co-authored-by: haocwang <Haocong.WANG@amd.com>

* fix typo
* Fix attention with causal mask
* multiple fixes, try AIT compile
* Add A/B no-LDS pipeline
* Clang format; add gfx1101/gfx1102 support to the FMHA example
* cancel change of format script
* 1. Enable 2-stage global prefetch (may cause VGPR spilling); 2. enable FP16-accumulator blockwise_gemm
* clang-format
* 1. change blockwise gemm loop-over direction from kmn to mnk (~1% improvement); 2. change kernel timing mode to 50 warm-up + 50 timed repeats
* Update low-level abstraction of blockwise gemm wmma
* (2/5) bilinear gemm pass; perf bug: skip-A-LDS has lower performance than skip-B-LDS
* (3/5) batched gemm pass; perf bug: skip-A-LDS has lower performance than skip-B-LDS
* (4/5) grouped conv pass
* (5/5) attention pass; todo: debug LDS perf bug
* AIT attention API refactor (#8)
* sanity pass
* sanity pass 2
* confirm significant performance regression
* turn on all instances
* turn off instance format
* Fix bug & tuning & format
* DML meta, self_attn + cross_attn
* sanity pass
* remove useless flag
* update tile and problem size used in AIT attention
* bug fix in grouped conv supporting check
* deprecate inline asm wmma
* Bug fix: double LDS skip
* clang-format
* Fix errors in 1. example, fmha; 2. gridwise pipeline; 3. deviceop, fmha; change some containers from vector to array
* part 2 of previous commit
* clang format
* API fix of gridwise gemm pipeline
* separate array-based and vector-based attention tensor transformation
* fix gemm
* clang format
* add gemm fp16 instances
* Temp save
* fpAintB kernel compile pass
* Sanity pass
* Temp save
* debug code enabled
* Fp16AInt8B GEMM sanity
* MQA implementation
* GQA-4 example
* tempsave
* Compile pass
* New implementation of fp16A-int8B GEMM; achieves similar math throughput to native fp16 GEMM
* format
* Todo: fix gemm_bilinear_wmma instances compilation bug
* Solve a bug when K1 = 16
* remove unnecessary changes
* Remove tensor layout limitation on LDS usage in tensor contraction
* update self-attention and cross-attention
* fix a typo in a name
* Add arch limiter for fp8 gemm
* enable fp8 gemm_xdl for all gfx9 targets
* temporarily disable gemm_xdl_fp16_fp8 on MI100/200
* fix the cmake logic for gemm_xdl_fp16_fp8
* re-enable gemm_xdl_fp16_fp8 on MI100/200

---------
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: Sam Wu <sjwu@ualberta.ca>
Co-authored-by: Sam Wu <sam.wu2@amd.com>
Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: root <root@ctr-ubbsmc15.amd.com>
Co-authored-by: Jun Liu <Liu.Jun@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: samjwu <samjwu@users.noreply.github.com>
Co-authored-by: haocwang <Haocong.WANG@amd.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
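The core computation added by this PR's fpAintB path is a GEMM whose B operand is dequantized on the fly: DequantB(K, N) = int2fp(B(K, N)) * scale(1, N), then C(M, N) = A(M, K) * DequantB(K, N). As orientation, here is a minimal host-side reference sketch of that math (illustrative only; the function name and float-only types are assumptions, not code from this commit):

// Host-side reference sketch of the fpAintB GEMM (illustrative, not from the commit).
// DequantB(k, n) = int2fp(B(k, n)) * scale(n), then C = A * DequantB.
#include <cstdint>
#include <vector>

void reference_fp_a_int_b_gemm(const std::vector<float>& a,     // A[M x K], row-major
                               const std::vector<int8_t>& b,    // B[K x N], row-major
                               const std::vector<float>& scale, // scale[N], per-column
                               std::vector<float>& c,           // C[M x N], row-major
                               int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * (static_cast<float>(b[k * N + n]) * scale[n]);
            c[m * N + n] = acc;
        }
}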
File diff suppressed because it is too large
@@ -0,0 +1,223 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp"

namespace ck {

/**
 * @brief Blockwise data transfer with dequantization
 *
 * RunRead loads the low-precision data and the scale data.
 * RunWrite performs the dequantization.
 * Scale is assumed to be identical along the K-dimension.
 *
 * This version does the following to avoid scratch memory issues:
 * 1. Uses StaticallyIndexedArray instead of a C array for the thread buffer
 * 2. ThreadwiseTensorSliceTransfer_v3 does not keep a reference to the tensor descriptor
 * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct a new tensor coordinate
 */
template <typename ThreadGroup,
          typename SrcElementwiseOperation,
          typename ScaleElementwiseOperation,
          typename DstElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename BlockSliceLengths,
          typename BlockScaleSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename ScaleData,
          typename DstData,
          typename SrcDesc,
          typename ScaleDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t ScaleScalarPerVector,
          index_t DstScalarPerVector,
          index_t SrcScalarStrideInVector,
          index_t ScaleScalarStrideInVector,
          index_t DstScalarStrideInVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun,
          index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1_dequant
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
    static constexpr auto scale_thread_slice_lengths =
        BlockScaleSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_dequant(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
        const SrcElementwiseOperation& src_element_op,
        const ScaleDesc& scale_desc,
        const Index& scale_block_slice_origin,
        const ScaleElementwiseOperation& scale_element_op,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
        const DstElementwiseOperation& dst_element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
                               src_element_op,
                               scale_desc,
                               make_zero_multi_index<nDim>(),
                               scale_element_op,
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               dst_element_op)
    {
        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<ScaleDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{} &&
                is_same<BlockScaleSliceLengths,
                        decltype(scale_thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetScaleSliceOrigin(
                scale_desc, scale_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, index_t ThreadScratchId = 0>
    __device__ void RunRead(const SrcDesc& src_desc,
                            const SrcBuffer& src_buf,
                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
        }
    }

    // Under this assumption, the scale scratch count is always one
    template <typename ScaleBuffer>
    __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunScaleRead(scale_desc, scale_buf);
        }
    }

    template <typename DstBuffer, index_t ThreadScratchId = 0>
    __device__ void RunWrite(const DstDesc& dst_desc,
                             DstBuffer& dst_buf,
                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
        }
    }

    // Prefer not to use this API directly
    /*
    template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf,
                        Number<ThreadScratchId> thread_scratch_id)
    {
        RunRead(src_desc, src_buf, thread_scratch_id);
        RunWrite(dst_desc, dst_buf, thread_scratch_id);
    }
    */

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    // Under this assumption, the scale buffer does not need a move-slice-window method

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v3r1_dequant<decltype(thread_slice_lengths),
                                                   decltype(scale_thread_slice_lengths),
                                                   SrcElementwiseOperation,
                                                   ScaleElementwiseOperation,
                                                   DstElementwiseOperation,
                                                   DstInMemOp,
                                                   SrcData,
                                                   ScaleData,
                                                   DstData,
                                                   SrcDesc,
                                                   ScaleDesc,
                                                   DstDesc,
                                                   SrcDimAccessOrder,
                                                   DstDimAccessOrder,
                                                   SrcVectorDim,
                                                   DstVectorDim,
                                                   SrcScalarPerVector,
                                                   ScaleScalarPerVector,
                                                   DstScalarPerVector,
                                                   SrcScalarStrideInVector,
                                                   ScaleScalarStrideInVector,
                                                   DstScalarStrideInVector,
                                                   ThreadTransferSrcResetCoordinateAfterRun,
                                                   ThreadTransferDstResetCoordinateAfterRun,
                                                   NumThreadScratch>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
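The doc-comment above splits the transfer into two phases: RunRead stages the low-precision data and the scale into thread-private buffers, and RunWrite dequantizes while writing out. A standalone sketch of that phase split, assuming plain C++ types in place of CK's descriptor machinery (all names illustrative):

// Conceptual sketch of the read / dequantize-on-write split (illustrative only).
#include <array>
#include <cstdint>

template <int ThreadSliceSize>
struct ThreadwiseDequantTransferSketch
{
    std::array<int8_t, ThreadSliceSize> src_thread_buf_; // low-precision staging buffer
    float scale_thread_buf_;                             // one scale, identical along K

    // Phase 1: load quantized data and the scale (global memory -> registers)
    void RunRead(const int8_t* src, const float* scale)
    {
        for(int i = 0; i < ThreadSliceSize; ++i)
            src_thread_buf_[i] = src[i];
        scale_thread_buf_ = *scale;
    }

    // Phase 2: dequantize while storing to the destination
    void RunWrite(float* dst)
    {
        for(int i = 0; i < ThreadSliceSize; ++i)
            dst[i] = static_cast<float>(src_thread_buf_[i]) * scale_thread_buf_;
    }
};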
@@ -0,0 +1,46 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Dequantization of the input tensor cannot be decoupled from the gridwise GEMM pipeline,
// as the input tensor's thread buffer is declared inside the blockwise GEMM pipeline.

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemm_dequantB : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        const void* p_scale,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
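A client drives this interface through the usual CK device-op idiom. A sketch under the assumption of some concrete instance type DeviceOpT (for example, a DeviceFpAintBGemm_Wmma_CShuffle instantiation from this PR); the helper below is illustrative and assumes the CK headers are included:

// Usage sketch for a concrete DeviceGemm_dequantB instance (illustrative).
template <typename DeviceOpT, typename AElemOp, typename BElemOp, typename CElemOp>
float run_dequant_gemm(const void* p_a,     // fp16 activations, M x K (device memory)
                       const void* p_b,     // int8 weights,     K x N (device memory)
                       const void* p_scale, // fp16 scales,      1 x N (device memory)
                       void* p_c,           // fp16 output,      M x N (device memory)
                       ck::index_t M, ck::index_t N, ck::index_t K,
                       ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                       AElemOp a_op, BElemOp b_op, CElemOp c_op)
{
    DeviceOpT op;
    auto argument = op.MakeArgumentPointer(
        p_a, p_b, p_scale, p_c, M, N, K, StrideA, StrideB, StrideC, a_op, b_op, c_op);
    auto invoker = op.MakeInvokerPointer();

    if(!op.IsSupportedArgument(argument.get()))
        return -1.f; // problem shape/layout not supported by this instance

    // time_kernel = true makes Run return the measured kernel time in ms
    return invoker->Run(argument.get(), StreamConfig{nullptr, true});
}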
@@ -62,10 +62,10 @@ template <index_t NumDimG,
          index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
@@ -73,13 +73,14 @@ template <index_t NumDimG,
          TensorSpecialization ASpec,
          TensorSpecialization BSpec,
          TensorSpecialization DESpec,
          ck::index_t NumPrefetch,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t KPerBlock,
          ck::index_t K1,
          ck::index_t MPerWMMA,
          ck::index_t NPerWMMA,
          ck::index_t MPerWmma,
          ck::index_t NPerWmma,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
@@ -100,7 +101,6 @@ template <index_t NumDimG,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          ck::index_t NumPrefetch = 1,
          ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
@@ -123,15 +123,32 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    // K1 = Max Vector Access Pixels
    static constexpr auto K1Number = Number<K1>{};

    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
    static constexpr auto WmmaK = K1 == 16 ? 32 : 16;

    static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
    static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;

    // If true, LDS is used unconditionally
    static constexpr auto AEnableLds_manu = false;
    static constexpr auto BEnableLds_manu = false;

    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

    // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
    static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                    const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
               a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK);
@@ -158,36 +175,72 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        // lengths for K0, K1, ...
        const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);

        if constexpr(ASpec == TensorSpecialization::Packed)
        const auto a_grid_desc_m_k = [&]() {
            if constexpr(ASpec == TensorSpecialization::Packed)
            {
                auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
                auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
                const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
                    make_tuple(M, K),
                    make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
                               a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            }
            else
            {
                // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
                const auto a_grid_desc_ms_ks =
                    make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);

                // transformed tensor A[MRaw = M0 * M1 * M2 * ..., KRaw = K0 * K1 * K2 * ...]
                const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
                    a_grid_desc_ms_ks,
                    make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
                    make_tuple(mDimIds, kDimIds),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));

                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            }
        }();

        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);
        assert(K % K1 == 0);

        if constexpr(AEnableLds)
        {
            auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
            auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
            const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
                make_tuple(M, K),
                make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
                           a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
            return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            const index_t K0 = K / K1;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
            const auto a_grid_desc_ms_ks =
                make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
            constexpr auto A_KRow = 2;
            constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
            const auto A_KWmma = K / WmmaK;

            // transformed tensor A[MRaw = M0 * M1 * M2 * ..., KRaw = K0 * K1 * K2 * ...]
            const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
                a_grid_desc_ms_ks,
                make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
                make_tuple(mDimIds, kDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            const auto M0 = M / MPerBlock;
            //       0               1             0        1             2        3          4      5     6
            // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(
                               A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }

    // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
    static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
                                        const std::vector<index_t>& b_gs_ns_ks_strides_vec)
    static auto MakeBGridDescriptor(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
                                    const std::vector<index_t>& b_gs_ns_ks_strides_vec)
    {
        assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
               b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);
@@ -214,30 +267,66 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);

        if constexpr(BSpec == TensorSpecialization::Packed)
        const auto b_grid_desc_n_k = [&]() {
            if constexpr(BSpec == TensorSpecialization::Packed)
            {
                auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
                auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
                const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
                    make_tuple(N, K),
                    make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
                               b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            }
            else
            {
                // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
                const auto b_grid_desc_ns_ks =
                    make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);

                // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
                const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
                    b_grid_desc_ns_ks,
                    make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
                    make_tuple(nDimIds, kDimIds),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));

                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            }
        }();

        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);
        assert(K % K1 == 0);

        if constexpr(BEnableLds)
        {
            auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
            auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
            const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
                make_tuple(N, K),
                make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
                           b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
            return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            const index_t K0 = K / K1;

            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
            const auto b_grid_desc_ns_ks =
                make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
            constexpr auto B_KRow = 2;
            constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
            const auto B_KWmma = K / WmmaK;

            // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
            const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
                b_grid_desc_ns_ks,
                make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
                make_tuple(nDimIds, kDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            const auto N0 = N / NPerBlock;
            //       0               1             0        1             2        3          4      5     6
            // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(
                               B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }

@@ -393,8 +482,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
    }

    // Gridwise descriptors, mapping to the whole given problem.
    using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
    using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
    using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));

@@ -449,45 +536,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        EGridDesc_G_M_N e_grid_desc_g_m_n_;
    };

    // A desc for source in blockwise copy
    template <typename AGridDesc_M_K>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / K1;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    // B desc for source in blockwise copy
    template <typename BGridDesc_N_K>
    __host__ __device__ static constexpr auto
    MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / K1;

        return transform_tensor_descriptor(
            b_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{}));
    using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{}));
    using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({}, {}));
    using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {}));

    // GridwiseOp
    using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
    using GridwiseOp = GridwiseGemmMultipleD_Wmma<
        // DataType Family
        ADataType,
        BDataType,
@@ -496,8 +549,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        DsDataType,
        EDataType,
        // InMemory Data Descriptor
        AGridDesc_K0_M_K1,
        BGridDesc_K0_N_K1,
        AGridDesc,
        BGridDesc,
        DsGridDesc_M_N,
        EGridDesc_M_N,
        // ElementwiseOp Family
@@ -508,9 +561,9 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        // Tiling Family
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerWMMA,
        NPerWMMA,
        KPerBlock,
        MPerWmma,
        NPerWmma,
        K1,
        MRepeat,
        NRepeat,
@@ -523,6 +576,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        AEnableLds,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
@@ -531,6 +585,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BEnableLds,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
@@ -564,16 +619,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
              a_grid_desc_m_k_{},
              b_grid_desc_n_k_{},
              a_grid_desc_{},
              b_grid_desc_{},
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{},
              ds_grid_desc_g_m_n_{
                  DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
              e_grid_desc_g_m_n_{
                  DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
              a_grid_desc_k0_m_k1_{},
              b_grid_desc_k0_n_k1_{},
              ds_grid_desc_mblock_mperblock_nblock_nperblock{},
              e_grid_desc_mblock_mperblock_nblock_nperblock{},
              block_2_ctile_map_{},
@@ -600,10 +653,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
            });

            a_grid_desc_m_k_ =
                DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
            b_grid_desc_n_k_ =
                DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
            a_grid_desc_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
            b_grid_desc_ = DeviceOp::MakeBGridDescriptor(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);

            ds_grid_desc_m_n_ =
                DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
@@ -611,9 +662,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
            e_grid_desc_m_n_ =
                DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);

            a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
            b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);

            block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);

            ds_grid_desc_mblock_mperblock_nblock_nperblock =
@@ -644,16 +692,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        EDataType* p_e_grid_;

        // Tensor Descriptors
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        AGridDesc a_grid_desc_;
        BGridDesc b_grid_desc_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;
        DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
        EGridDesc_G_M_N e_grid_desc_g_m_n_;

        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;

        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock;
        typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -686,6 +731,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle

        // Batch Offset
        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;

        // for checking vector load/store
        // index_t MRaw_;
        // index_t NRaw_;
        // index_t KRaw_;
    };

    // Invoker
@@ -700,8 +750,17 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;

            const auto K =
                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
            const auto K = [&]() {
                if constexpr(AEnableLds)
                {
                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
                }
                else
                {
                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
                           arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
                }
            }();

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;
@@ -712,8 +771,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                    BDataType,
                    typename GridwiseOp::DsGridPointer,
                    EDataType,
                    DeviceOp::AGridDesc_K0_M_K1,
                    DeviceOp::BGridDesc_K0_N_K1,
                    DeviceOp::AGridDesc,
                    DeviceOp::BGridDesc,
                    typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    AElementwiseOperation,
@@ -733,8 +792,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                    arg.p_ds_grid_,
                    arg.p_e_grid_,
                    G,
                    arg.a_grid_desc_k0_m_k1_,
                    arg.b_grid_desc_k0_n_k1_,
                    arg.a_grid_desc_,
                    arg.b_grid_desc_,
                    arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                    arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                    arg.a_element_op_,
@@ -774,6 +833,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
            {
                printf("DeviceOp: Arch check failure\n");
                return false;
            }
        }
@@ -782,12 +842,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
            return false;
        }

        if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                      arg.b_grid_desc_k0_n_k1_,
        if(!GridwiseOp::CheckValidity(arg.a_grid_desc_,
                                      arg.b_grid_desc_,
                                      arg.ds_grid_desc_m_n_,
                                      arg.e_grid_desc_m_n_,
                                      arg.block_2_ctile_map_))
        {
            printf("GridwiseOp: Validity check failure\n");
            return false;
        }

@@ -800,16 +861,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        if constexpr(ABlockTransferSrcVectorDim == 1)
        {
            if(!(arg.a_mz_stride_ == 1 &&
                 arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
                 arg.a_grid_desc_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
            {
                printf("DeviceOp: Vector Access A-m check failure\n");
                return false;
            }
        }
        else
        {
            if(!(arg.a_kz_stride_ == 1 &&
                 arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
                 arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
            {
                printf("DeviceOp: Vector Access A-k check failure\n");
                return false;
            }
        }
@@ -818,16 +881,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        if constexpr(BBlockTransferSrcVectorDim == 1)
        {
            if(!(arg.b_nz_stride_ == 1 &&
                 arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
                 arg.b_grid_desc_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
            {
                printf("DeviceOp: Vector Access B-n check failure\n");
                return false;
            }
        }
        else
        {
            if(!(arg.b_kz_stride_ == 1 &&
                 arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
                 arg.b_grid_desc_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
            {
                printf("DeviceOp: Vector Access B-k check failure\n");
                return false;
            }
        }
@@ -841,6 +906,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                         CDEShuffleBlockTransferScalarPerVector_NPerBlock ==
                     0))
            {
                printf("DeviceOp: Vector Access D-n check failure\n");
                valid_d_access = false;
            }
        });
@@ -857,6 +923,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                 0) ||
             CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
        {
            printf("DeviceOp: Vector Access E-n check failure\n");
            return false;
        }

@@ -967,14 +1034,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << KPerBlock << ", "
            << K1 << ", "
            << MPerWMMA << ", "
            << NPerWMMA << ", "
            << MPerWmma << ", "
            << NPerWmma << ", "
            << MRepeat << ", "
            << NRepeat
            << ">"
            << " NumPrefetch: "
            << " AEnableLds: "
            << AEnableLds << ", "
            << "BEnableLds: "
            << BEnableLds << ", "
            << "NumPrefetch: "
            << NumPrefetch << ", "
            << "LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "
File diff suppressed because it is too large
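To make the AEnableLds/BEnableLds selection in the diff above concrete, here is a small constexpr walk-through with an assumed tile shape (the numbers are illustrative, not a shipped instance):

// Constexpr walk-through of the LDS-bypass decision (illustrative numbers).
constexpr int MPerBlock = 128, NPerBlock = 64, MPerWmma = 16, NPerWmma = 16;
constexpr int MRepeat = 4, NRepeat = 2, K1 = 8, NumPrefetch = 1;

constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 128 / 64 = 2
constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // 64 / 32 = 2
constexpr int WmmaK  = K1 == 16 ? 32 : 16;               // 16

// Per the code above, A may bypass LDS only when NWaves == 1 (apparently because
// there is then no cross-wave sharing of the A tile); here NWaves == 2.
constexpr bool AEnableLds_auto = NWaves == 1 ? false : true; // true
constexpr bool BEnableLds_auto = MWaves == 1 ? false : true; // true

// Manual override false, single prefetch stage: the auto decision stands.
constexpr bool AEnableLds = AEnableLds_auto || false || (NumPrefetch > 1); // true
static_assert(AEnableLds && BEnableLds_auto, "this tile keeps both operands in LDS");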
@@ -0,0 +1,714 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// 1. DequantB(K, N) = int2fp(B(K, N)) * scale(1, N)
|
||||
// 2. C(M, N) = A(M, K) * DequantB(K, N)
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ScaleDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumPrefetch,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::weight_only>
|
||||
struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
|
||||
static constexpr auto AEnableLds_auto =
|
||||
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
|
||||
static constexpr auto BEnableLds_auto =
|
||||
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
// LDS bypass feature not implemented for dequantization pipeline.
|
||||
static constexpr auto AEnableLds_manu = true;
|
||||
static constexpr auto BEnableLds_manu = true;
|
||||
|
||||
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
|
||||
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
using DeviceOp = DeviceFpAintBGemm_Wmma_CShuffle;
|
||||
|
||||
// Describe how data read from Global memory
|
||||
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto A_KRow = 2;
|
||||
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
|
||||
const auto A_KWmma = K / WmmaK;
|
||||
|
||||
const auto M0 = M / MPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_n_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_KRow = 2;
|
||||
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
|
||||
const auto B_KWmma = K / WmmaK;
|
||||
|
||||
const auto N0 = N / NPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB = 0)
|
||||
{
|
||||
// assume Scale is [1, N]
|
||||
const auto scale_grid_desc_n_k = [&]() {
|
||||
const auto scale_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(scale_grid_desc_nraw_kraw);
|
||||
}();
|
||||
|
||||
const auto N = scale_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = scale_grid_desc_n_k.GetLength(I1);
|
||||
// When K = 1, it might be scale tensor.
|
||||
assert(K % K1 == 0 && K != 1);
|
||||
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
scale_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, 1)), // Reduce K1 = 1
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_KRow = 2;
|
||||
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
|
||||
const auto B_KWmma = K / WmmaK;
|
||||
|
||||
const auto N0 = N / NPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
scale_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}

    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(I1, StrideC));
            }
        }();

        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
    }

    // Gridwise descriptors, mapping to the whole given problem.
    using AGridDesc     = decltype(MakeAGridDescriptor(1, 1, 1));
    using BGridDesc     = decltype(MakeBGridDescriptor(1, 1, 1));
    using ScaleGridDesc = decltype(MakeScaleGridDescriptor(1, 1, 0));
    using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm = GridwiseFpAintBGemm_Wmma<
        BlockSize,
        ADataType,
        BDataType,
        ScaleDataType,
        AccDataType,
        CShuffleDataType,
        CDataType,
        InMemoryDataOperationEnum::Set,
        AGridDesc,
        BGridDesc,
        ScaleGridDesc,
        CGridDesc_M_N,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        MPerWmma,
        NPerWmma,
        K1,
        MRepeat,
        NRepeat,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        AEnableLds,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BEnableLds,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        NumPrefetch,
        LoopSched,
        PipelineVer>;
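    // Reading of the type names above (not spelled out in the source): this device op
    // wires a floating-point-activation / integer-weight ("FpAintB") GEMM, where B is
    // the quantized weight and ScaleDataType carries the per-column [1, N]
    // dequantization scale consumed inside the gridwise kernel.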

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 const ScaleDataType* p_scale_grid,
                 CDataType* p_c_grid,
                 index_t M,
                 index_t N,
                 index_t K,
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideC,
                 index_t M01,
                 index_t N01,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_scale_grid_{p_scale_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc_{},
              b_grid_desc_{},
              scale_grid_desc_{},
              c_grid_desc_m_n_{},
              c_grid_desc_mblock_mperblock_nblock_nperblock{},
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op},
              MRaw_{M},
              NRaw_{N},
              KRaw_{K}
        {
            a_grid_desc_     = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
            b_grid_desc_     = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
            scale_grid_desc_ = DeviceOp::MakeScaleGridDescriptor(K, N, 0);
            c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC);

            block_2_ctile_map_ =
                GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);

            if(GridwiseGemm::CheckValidity(
                   a_grid_desc_, b_grid_desc_, c_grid_desc_m_n_, block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock =
                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);
            }
        }

        // private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        const ScaleDataType* p_scale_grid_;
        CDataType* p_c_grid_;
        AGridDesc a_grid_desc_;
        BGridDesc b_grid_desc_;
        ScaleGridDesc scale_grid_desc_;
        CGridDesc_M_N c_grid_desc_m_n_;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
        index_t M01_;
        index_t N01_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
        // for checking vector load/store
        index_t MRaw_;
        index_t NRaw_;
        index_t KRaw_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
                                            arg.b_grid_desc_,
                                            arg.c_grid_desc_m_n_,
                                            arg.block_2_ctile_map_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting");
            }

            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

            const auto K = [&]() {
                if constexpr(AEnableLds)
                {
                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
                }
                else
                {
                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
                           arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
                }
            }();
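            // The descriptor rank depends on the LDS mode, so the total K extent is
            // recovered differently: with LDS the A descriptor is (K0, M, K1) and
            // K = dim0 * dim2; in the bypass layout it is the 7-d form and
            // K = KWmma * K0PerWmma * KRow * K1 (dims 0, 3, 4, 6).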
            auto launch_kernel = [&](auto has_main_k_block_loop) {
                const auto kernel = kernel_fpAintB_gemm_wmma<
                    GridwiseGemm,
                    ADataType,
                    BDataType,
                    ScaleDataType,
                    CDataType,
                    remove_reference_t<DeviceOp::AGridDesc>,
                    remove_reference_t<DeviceOp::BGridDesc>,
                    remove_reference_t<DeviceOp::ScaleGridDesc>,
                    remove_reference_t<
                        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
                    remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                    has_main_k_block_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_scale_grid_,
                                              arg.p_c_grid_,
                                              arg.a_grid_desc_,
                                              arg.b_grid_desc_,
                                              arg.scale_grid_desc_,
                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.c_element_op_,
                                              arg.block_2_ctile_map_);
            };
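            // has_main_k_block_loop is passed as an integral_constant, so the main-loop
            // branch is resolved at compile time; both kernel instantiations below share
            // the launch code through the lambda.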

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(ck::is_navi3_supported())
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
                           is_same_v<AccDataType, int32_t>))
            {
                printf("DeviceOp err: AccDataType");
                return false;
            }
        }
        else
        {
            printf("DeviceOp err: Arch");
            return false;
        }

        // check vector load/store
        {
            using Row = ck::tensor_layout::gemm::RowMajor;
            using Col = ck::tensor_layout::gemm::ColumnMajor;

            // check vector load of A
            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
            {
                // FIXME: not rigorous
                if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }

            // check vector load of B
            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
            {
                // FIXME: not rigorous
                if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }

            // check vector store of C
            // only support RowMajor for now
            if constexpr(is_same_v<CLayout, Row>)
            {
                if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }
        }
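        // Pattern behind each check above: a vectorized transfer is only legal when the
        // raw length of the dimension being vectorized (the contiguous one for the given
        // layout) is divisible by the ScalarPerVector of that transfer.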

        return GridwiseGemm::CheckValidity(
            arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_m_n_, arg.block_2_ctile_map_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             const ScaleDataType* p_scale,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_scale,
                        p_c,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        1,
                        1,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      const void* p_scale,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<const ScaleDataType*>(p_scale),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          1,
                                          1,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{
            {PipelineVersion::v1, "v1"},
            {PipelineVersion::v2, "v2"},
            {PipelineVersion::weight_only, "weight_only"}};

        // clang-format off
        str << "DeviceFpAintBGemm_Wmma_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << K1 << ", "
            << MPerWmma << ", "
            << NPerWmma << ", "
            << MRepeat << ", "
            << NRepeat
            << ">"
            << " AEnableLds: "
            << AEnableLds << ", "
            << "BEnableLds: "
            << BEnableLds << ", "
            << "NumPrefetch: "
            << NumPrefetch << ", "
            << "LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: "
            << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
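// Illustrative usage sketch (not from the original source; the concrete template
// arguments and data types are assumed for the example):
//
//   using DeviceOp = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_CShuffle</*...*/>;
//   auto op       = DeviceOp{};
//   auto argument = op.MakeArgument(p_a, p_b, p_scale, p_c,
//                                   M, N, K, StrideA, StrideB, StrideC,
//                                   a_op, b_op, c_op);
//   if(op.IsSupportedArgument(argument))
//   {
//       auto invoker  = op.MakeInvoker();
//       float time_ms = invoker.Run(argument, StreamConfig{nullptr, true});
//   }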

@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

namespace ck {
namespace tensor_operation {
@@ -27,21 +28,22 @@ template <typename ALayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumPrefetch,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t KPerBlock,
          ck::index_t K1,
          ck::index_t MPerWMMA,
          ck::index_t NPerWMMA,
          ck::index_t MPerWmma,
          ck::index_t NPerWmma,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
@@ -62,7 +64,6 @@ template <typename ALayout,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          ck::index_t NumPrefetch = 1,
          ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
@@ -83,68 +84,139 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    // K1 = max vector access width, in elements
    static constexpr auto K1Number = Number<K1>{};

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};
    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
    static constexpr auto WmmaK  = K1 == 16 ? 32 : 16;
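    // WmmaK is the K depth consumed per WMMA stage; K1 == 16 selects the 32-deep
    // variant, otherwise 16. (This reading of the constant is our interpretation; the
    // original source does not spell it out.)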

    static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA)
    static constexpr auto AEnableLds_auto =
        (NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
    static constexpr auto BEnableLds_auto =
        (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;

    // If true, LDS is used unconditionally
    static constexpr auto AEnableLds_manu = false;
    static constexpr auto BEnableLds_manu = false;

    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
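    // Heuristic behind the _auto flags: A may skip LDS only when NWaves == 1 and A is
    // row-major (a wave then reads K-contiguous data for its whole tile); B symmetrically
    // requires MWaves == 1 and a column-major B. NumPrefetch > 1 always forces the LDS
    // path.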

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

    // Describe how data is read from global memory
    static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
                const auto a_grid_desc_mraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));

                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
                const auto a_grid_desc_mraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));

                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            }
        }();

        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);
        assert(K % K1 == 0);
        const index_t K0 = K / K1;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        if constexpr(AEnableLds)
        {
            const index_t K0 = K / K1;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            constexpr auto A_KRow      = 2;
            constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
            const auto A_KWmma         = K / WmmaK;

            const auto M0 = M / MPerBlock;
            // 0 1    0         1                2        3             4        5          6
            // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(
                               A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }
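    // Net effect of the two branches above: with LDS enabled the A tile is fetched as
    // (K0, M, K1) and re-laid-out on chip; with the bypass, the 7-d descriptor makes a
    // thread's global read land directly in WMMA register order, trading LDS traffic
    // for a more constrained global access pattern.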

    static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
    static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
        const auto b_grid_desc_n_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
                const auto b_grid_desc_nraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));

                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
                const auto b_grid_desc_nraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));

                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            }
        }();

        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);
        assert(K % K1 == 0);
        const index_t K0 = K / K1;

        return transform_tensor_descriptor(
            b_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        if constexpr(BEnableLds)
        {
            const index_t K0 = K / K1;

            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            constexpr auto B_KRow      = 2;
            constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
            const auto B_KWmma         = K / WmmaK;

            const auto N0 = N / NPerBlock;
            // 0 1    0         1                2        3             4        5          6
            // N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(
                               B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }

    template <typename ELayout_>
@@ -180,13 +252,13 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
    }

    // Gridwise descriptor, mapping to the whole given problem.
    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
    using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
    using AGridDesc     = decltype(MakeAGridDescriptor(1, 1, 1));
    using BGridDesc     = decltype(MakeBGridDescriptor(1, 1, 1));
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
    using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));

    // GridwiseOp
    using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
    using GridwiseOp = GridwiseGemmMultipleD_Wmma<
        // DataType Family
        ADataType,
        BDataType,
@@ -195,8 +267,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
        DsDataType,
        EDataType,
        // InMemory Data Descriptor
        AGridDesc_K0_M_K1,
        BGridDesc_K0_N_K1,
        AGridDesc,
        BGridDesc,
        DsGridDesc_M_N,
        EGridDesc_M_N,
        // ElementwiseOp Family
@@ -207,9 +279,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
        // Tiling Family
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerWMMA,
        NPerWMMA,
        KPerBlock,
        MPerWmma,
        NPerWmma,
        K1,
        MRepeat,
        NRepeat,
@@ -222,6 +294,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        AEnableLds,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
@@ -230,6 +303,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BEnableLds,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
@@ -262,8 +336,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
          p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
          p_ds_grid_{},
          p_e_grid_{static_cast<EDataType*>(p_e_grid)},
          a_grid_desc_k0_m_k1_{},
          b_grid_desc_k0_n_k1_{},
          a_grid_desc{},
          b_grid_desc{},
          ds_grid_desc_m_n_{},
          e_grid_desc_m_n_{},
          ds_grid_desc_mblock_mperblock_nblock_nperblock{},
@@ -278,8 +352,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
          NRaw_{N},
          KRaw_{K}
    {
        a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
        b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
        a_grid_desc = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
        b_grid_desc = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DLayout   = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
            using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
@@ -295,8 +369,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,

        block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);

        if(GridwiseOp::CheckValidity(a_grid_desc_k0_m_k1_,
                                     b_grid_desc_k0_n_k1_,
        if(GridwiseOp::CheckValidity(a_grid_desc,
                                     b_grid_desc,
                                     ds_grid_desc_m_n_,
                                     e_grid_desc_m_n_,
                                     block_2_ctile_map_))
@@ -318,8 +392,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
    EDataType* p_e_grid_;

    // Tensor Descriptors
    AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
    BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
    AGridDesc a_grid_desc;
    BGridDesc b_grid_desc;
    DsGridDesc_M_N ds_grid_desc_m_n_;
    EGridDesc_M_N e_grid_desc_m_n_;
    typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -352,24 +426,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,

    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
    {
#if 0
        {
            std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
                      << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
                      << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;

            std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
                      << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
                      << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;

            std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
                      << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
                      << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
        }
#endif

        if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                      arg.b_grid_desc_k0_n_k1_,
        if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
                                      arg.b_grid_desc,
                                      arg.ds_grid_desc_m_n_,
                                      arg.e_grid_desc_m_n_,
                                      arg.block_2_ctile_map_))
@@ -381,91 +439,64 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
        const index_t grid_size =
            arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

        const auto K =
            arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
        const auto K = [&]() {
            if constexpr(AEnableLds)
            {
                return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
            }
            else
            {
                return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
                       arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6);
            }
        }();

        float ave_time = 0;
        auto launch_kernel = [&](auto has_main_k_block_loop) {
            const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
                GridwiseOp,
                ADataType,
                BDataType,
                typename GridwiseOp::DsGridPointer,
                EDataType,
                remove_reference_t<typename DeviceOp::AGridDesc>,
                remove_reference_t<typename DeviceOp::BGridDesc>,
                remove_reference_t<
                    typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                remove_reference_t<
                    typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                AElementwiseOperation,
                BElementwiseOperation,
                CDEElementwiseOperation,
                remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
                has_main_k_block_loop>; // last template argument: with/without main K-block loop

            return launch_and_time_kernel(stream_config,
                                          kernel,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          arg.p_a_grid_,
                                          arg.p_b_grid_,
                                          arg.p_ds_grid_,
                                          arg.p_e_grid_,
                                          arg.a_grid_desc,
                                          arg.b_grid_desc,
                                          arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                          arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                          arg.a_element_op_,
                                          arg.b_element_op_,
                                          arg.cde_element_op_,
                                          arg.block_2_ctile_map_);
        };

        if(GridwiseOp::CalculateHasMainKBlockLoop(K))
        {
            const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
                GridwiseOp,
                ADataType,
                BDataType,
                typename GridwiseOp::DsGridPointer,
                EDataType,
                remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
                remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
                remove_reference_t<
                    typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                remove_reference_t<
                    typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                AElementwiseOperation,
                BElementwiseOperation,
                CDEElementwiseOperation,
                remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
                true>; // last template argument: with/without main K-block loop

            ave_time =
                launch_and_time_kernel(stream_config,
                                       kernel,
                                       dim3(grid_size),
                                       dim3(BlockSize),
                                       0,
                                       arg.p_a_grid_,
                                       arg.p_b_grid_,
                                       arg.p_ds_grid_,
                                       arg.p_e_grid_,
                                       arg.a_grid_desc_k0_m_k1_,
                                       arg.b_grid_desc_k0_n_k1_,
                                       arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                       arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                       arg.a_element_op_,
                                       arg.b_element_op_,
                                       arg.cde_element_op_,
                                       arg.block_2_ctile_map_);
            return launch_kernel(integral_constant<bool, true>{});
        }
        else
        {
            const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
                GridwiseOp,
                ADataType,
                BDataType,
                typename GridwiseOp::DsGridPointer,
                EDataType,
                remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
                remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
                remove_reference_t<
                    typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                remove_reference_t<
                    typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                AElementwiseOperation,
                BElementwiseOperation,
                CDEElementwiseOperation,
                remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
                false>;

            ave_time =
                launch_and_time_kernel(stream_config,
                                       kernel,
                                       dim3(grid_size),
                                       dim3(BlockSize),
                                       0,
                                       arg.p_a_grid_,
                                       arg.p_b_grid_,
                                       arg.p_ds_grid_,
                                       arg.p_e_grid_,
                                       arg.a_grid_desc_k0_m_k1_,
                                       arg.b_grid_desc_k0_n_k1_,
                                       arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                       arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                       arg.a_element_op_,
                                       arg.b_element_op_,
                                       arg.cde_element_op_,
                                       arg.block_2_ctile_map_);
            return launch_kernel(integral_constant<bool, false>{});
        }

        return ave_time;
    }

    // polymorphic
@@ -575,8 +606,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
            }
        }

        return GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                         arg.b_grid_desc_k0_n_k1_,
        return GridwiseOp::CheckValidity(arg.a_grid_desc,
                                         arg.b_grid_desc,
                                         arg.ds_grid_desc_m_n_,
                                         arg.e_grid_desc_m_n_,
                                         arg.block_2_ctile_map_);
@@ -681,14 +712,18 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << KPerBlock << ", "
            << K1 << ", "
            << MPerWMMA << ", "
            << NPerWMMA << ", "
            << MPerWmma << ", "
            << NPerWmma << ", "
            << MRepeat << ", "
            << NRepeat
            << ">"
            << " NumPrefetch: "
            << " AEnableLds: "
            << AEnableLds << ", "
            << "BEnableLds: "
            << BEnableLds << ", "
            << "NumPrefetch: "
            << NumPrefetch << ", "
            << "LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "

@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

namespace ck {
namespace tensor_operation {
@@ -33,13 +34,14 @@ template <typename ALayout,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumPrefetch,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t KPerBlock,
          ck::index_t K1,
          ck::index_t MPerWMMA,
          ck::index_t NPerWMMA,
          ck::index_t MPerWmma,
          ck::index_t NPerWmma,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
@@ -60,7 +62,6 @@ template <typename ALayout,
          index_t CShuffleNRepeatPerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          ck::index_t NumPrefetch = 1,
          ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
@@ -76,68 +77,138 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    // K1 = max vector access width, in elements
    static constexpr auto K1Number = Number<K1>{};

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};
    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
    static constexpr auto WmmaK  = K1 == 16 ? 32 : 16;

    static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA)
    static constexpr auto AEnableLds_auto =
        (NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
    static constexpr auto BEnableLds_auto =
        (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;

    // If true, LDS is used unconditionally
    static constexpr auto AEnableLds_manu = false;
    static constexpr auto BEnableLds_manu = false;

    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
    // Describe how data is read from global memory
    static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
                const auto a_grid_desc_mraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));

                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
                const auto a_grid_desc_mraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));

                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            }
        }();

        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);
        assert(K % K1 == 0);
        const index_t K0 = K / K1;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        if constexpr(AEnableLds)
        {
            const index_t K0 = K / K1;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            constexpr auto A_KRow      = 2;
            constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
            const auto A_KWmma         = K / WmmaK;

            const auto M0 = M / MPerBlock;
            // 0 1    0         1                2        3             4        5          6
            // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(
                               A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }

    static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
    static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
        const auto b_grid_desc_n_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
                const auto b_grid_desc_nraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));

                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
                const auto b_grid_desc_nraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));

                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            }
        }();

        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);
        assert(K % K1 == 0);
        const index_t K0 = K / K1;

        return transform_tensor_descriptor(
            b_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        if constexpr(BEnableLds)
        {
            const index_t K0 = K / K1;

            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            constexpr auto B_KRow      = 2;
            constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
            const auto B_KWmma         = K / WmmaK;

            const auto N0 = N / NPerBlock;
            // 0 1    0         1                2        3             4        5          6
            // N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(
                               B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }

    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
@@ -159,56 +230,58 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
    }

    // Gridwise descriptor, mapping to the whole given problem.
    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
    using AGridDesc     = decltype(MakeAGridDescriptor(1, 1, 1));
    using BGridDesc     = decltype(MakeBGridDescriptor(1, 1, 1));
    using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma<
        BlockSize,
        ADataType,
        BDataType,
        AccDataType,
        CShuffleDataType,
        CDataType,
        InMemoryDataOperationEnum::Set,
        AGridDesc_K0_M_K1,
        BGridDesc_K0_N_K1,
        CGridDesc_M_N,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerWMMA,
        NPerWMMA,
        K1,
        MRepeat,
        NRepeat,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        NumPrefetch,
        LoopSched,
        PipelineVer>;
    using GridwiseGemm =
        GridwiseGemm_Wmma<BlockSize,
                          ADataType,
                          BDataType,
                          AccDataType,
                          CShuffleDataType,
                          CDataType,
                          InMemoryDataOperationEnum::Set,
                          AGridDesc,
                          BGridDesc,
                          CGridDesc_M_N,
                          AElementwiseOperation,
                          BElementwiseOperation,
                          CElementwiseOperation,
                          MPerBlock,
                          NPerBlock,
                          KPerBlock,
                          MPerWmma,
                          NPerWmma,
                          K1,
                          MRepeat,
                          NRepeat,
                          ABlockTransferThreadClusterLengths_K0_M_K1,
                          ABlockTransferThreadClusterArrangeOrder,
                          ABlockTransferSrcAccessOrder,
                          ABlockTransferSrcVectorDim,
                          ABlockTransferSrcScalarPerVector,
                          ABlockTransferDstScalarPerVector_K1,
                          false, // AThreadTransferSrcResetCoordinateAfterRun,
                          AEnableLds,
                          ABlockLdsAddExtraM,
                          BBlockTransferThreadClusterLengths_K0_N_K1,
                          BBlockTransferThreadClusterArrangeOrder,
                          BBlockTransferSrcAccessOrder,
                          BBlockTransferSrcVectorDim,
                          BBlockTransferSrcScalarPerVector,
                          BBlockTransferDstScalarPerVector_K1,
                          false, // BThreadTransferSrcResetCoordinateAfterRun,
                          BEnableLds,
                          BBlockLdsAddExtraN,
                          CShuffleMRepeatPerShuffle,
                          CShuffleNRepeatPerShuffle,
                          CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                          CShuffleBlockTransferScalarPerVector_NPerBlock,
                          NumPrefetch,
                          LoopSched,
                          PipelineVer>;

    // Argument
    struct Argument : public BaseArgument
@@ -230,7 +303,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
        : p_a_grid_{p_a_grid},
          p_b_grid_{p_b_grid},
          p_c_grid_{p_c_grid},
          a_grid_desc_k0_m_k1_{},
          a_grid_desc_{},
          b_grid_desc_k0_n_k1_{},
          c_grid_desc_m_n_{},
          c_grid_desc_mblock_mperblock_nblock_nperblock{},
@@ -244,19 +317,15 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
          NRaw_{N},
          KRaw_{K}
    {
        a_grid_desc_k0_m_k1_ =
            DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
        b_grid_desc_k0_n_k1_ =
            DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
        c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
        a_grid_desc_         = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
        b_grid_desc_k0_n_k1_ = DeviceGemmWmma_CShuffle::MakeBGridDescriptor(K, N, StrideB);
        c_grid_desc_m_n_     = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);

        block_2_ctile_map_ =
            GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);

        if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
                                       b_grid_desc_k0_n_k1_,
                                       c_grid_desc_m_n_,
                                       block_2_ctile_map_))
        if(GridwiseGemm::CheckValidity(
               a_grid_desc_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, block_2_ctile_map_))
        {
            c_grid_desc_mblock_mperblock_nblock_nperblock =
                GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -268,8 +337,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
    const ADataType* p_a_grid_;
    const BDataType* p_b_grid_;
    CDataType* p_c_grid_;
    AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
    BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
    AGridDesc a_grid_desc_;
    BGridDesc b_grid_desc_k0_n_k1_;
    CGridDesc_M_N c_grid_desc_m_n_;
    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
        c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -292,23 +361,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,

    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
    {
#if 0
        {
            std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
                      << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
                      << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;

            std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
                      << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
                      << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;

            std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
                      << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
                      << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
        }
#endif

        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
                                        arg.b_grid_desc_k0_n_k1_,
                                        arg.c_grid_desc_m_n_,
                                        arg.block_2_ctile_map_))
@@ -320,79 +373,58 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
        const index_t grid_size =
            arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

        const auto K =
            arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
        const auto K = [&]() {
            if constexpr(AEnableLds)
            {
                return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
            }
            else
            {
                return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
                       arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
            }
        }();
        auto launch_kernel = [&](auto has_main_k_block_loop) {
            const auto kernel = kernel_gemm_wmma<
                GridwiseGemm,
                ADataType,
                BDataType,
                CDataType,
                remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
                remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc>,
                remove_reference_t<
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                AElementwiseOperation,
                BElementwiseOperation,
                CElementwiseOperation,
                remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                has_main_k_block_loop>;

            float ave_time = 0;
            return launch_and_time_kernel(stream_config,
                                          kernel,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          arg.p_a_grid_,
                                          arg.p_b_grid_,
                                          arg.p_c_grid_,
                                          arg.a_grid_desc_,
                                          arg.b_grid_desc_k0_n_k1_,
                                          arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                                          arg.a_element_op_,
                                          arg.b_element_op_,
                                          arg.c_element_op_,
                                          arg.block_2_ctile_map_);
        };

        if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
        {
            const auto kernel = kernel_gemm_wmma<
                GridwiseGemm,
                ADataType,
                BDataType,
                CDataType,
                remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
                remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
                remove_reference_t<
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                AElementwiseOperation,
                BElementwiseOperation,
                CElementwiseOperation,
                remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                true>; // last template argument: with/without main K-block loop

            ave_time = launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_c_grid_,
                                              arg.a_grid_desc_k0_m_k1_,
                                              arg.b_grid_desc_k0_n_k1_,
                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.c_element_op_,
                                              arg.block_2_ctile_map_);
            return launch_kernel(integral_constant<bool, true>{});
        }
        else
        {
            const auto kernel = kernel_gemm_wmma<
                GridwiseGemm,
                ADataType,
                BDataType,
                CDataType,
                remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
                remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
                remove_reference_t<
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                AElementwiseOperation,
                BElementwiseOperation,
                CElementwiseOperation,
                remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                false>;

            ave_time = launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_c_grid_,
                                              arg.a_grid_desc_k0_m_k1_,
                                              arg.b_grid_desc_k0_n_k1_,
                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.c_element_op_,
                                              arg.block_2_ctile_map_);
            return launch_kernel(integral_constant<bool, false>{});
        }

        return ave_time;
    }

    // polymorphic
@@ -413,13 +445,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
    {
        if(ck::is_navi3_supported())
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
                           is_same_v<AccDataType, int32_t>))
            {
                printf("DeviceOp err: AccDataType");
                return false;
            }
        }
        else
        {
            printf("DeviceOp err: Arch");
            return false;
        }

@@ -485,7 +520,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
            }
        }

        return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_,
                                           arg.b_grid_desc_k0_n_k1_,
                                           arg.c_grid_desc_m_n_,
                                           arg.block_2_ctile_map_);
@@ -581,14 +616,18 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << KPerBlock << ", "
            << K1 << ", "
            << MPerWMMA << ", "
            << NPerWMMA << ", "
            << MPerWmma << ", "
            << NPerWmma << ", "
            << MRepeat << ", "
            << NRepeat
            << ">"
            << " NumPrefetch: "
            << " AEnableLds: "
            << AEnableLds << ", "
            << "BEnableLds: "
            << BEnableLds << ", "
            << "NumPrefetch: "
            << NumPrefetch << ", "
            << "LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "

@@ -196,7 +196,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
    using EGridDesc_M_N = remove_cvref_t<tuple_element_t<3, ABDsEGridDesc>>;

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
    using GridwiseGemm = GridwiseGemmMultipleD_Wmma<
        // DataType Family
        ADataType,
        BDataType,
@@ -217,7 +217,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
        // Tiling Family
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        KPerBlock,
        MPerWMMA,
        NPerWMMA,
        K1,
@@ -232,6 +232,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        true,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
@@ -240,6 +241,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        true,
        BBlockLdsExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,

@@ -393,12 +393,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
    using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
    using CGridDesc_M_N     = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

    using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
    using CShuffleDataType = AccDataType;

    using GridwiseGemm = GridwiseGemmMultipleD_Wmma<
        // DataType Family
        ADataType,
        BDataType,
        AccDataType,
        CDataType,
        CShuffleDataType,
        Tuple<>,
        CDataType,
        // InMemory Data Descriptor
@@ -414,7 +416,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
        // Tiling Family
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        KPerBlock,
        MPerWMMA,
        NPerWMMA,
        K1,
@@ -429,6 +431,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false,
        true,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
@@ -437,6 +440,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false,
        true,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
|
||||
|
||||
@@ -52,22 +52,23 @@ template <index_t NDimSpatial,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWMMA,
|
||||
ck::index_t NPerWMMA,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
@@ -88,7 +89,6 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
index_t NumGemmKPrefetchStage = 1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
|
||||
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
@@ -109,11 +109,31 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr index_t KPerBlock = K0PerBlock * K1;
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = 16;
|
||||
|
||||
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
|
||||
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
static constexpr auto AEnableLds_manu = true;
|
||||
static constexpr auto BEnableLds_manu = true;
|
||||
|
||||
static constexpr auto AEnableLds =
|
||||
AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1);
|
||||
static constexpr auto BEnableLds =
|
||||
BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
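For intuition, a minimal sketch of how these flags resolve under hypothetical tile numbers (the values below are illustrative, not from this diff); only a side whose wave count collapses to 1 is a candidate for skipping LDS, and the manual flags or a multi-stage prefetch force LDS back on:

    // Hypothetical tile: resolves AEnableLds exactly as the expressions above do.
    constexpr int MPerBlock = 128, MRepeat = 4, MPerWmma = 16;
    constexpr int NPerBlock = 64, NRepeat = 4, NPerWmma = 16;
    constexpr int NumGemmKPrefetchStage = 1;
    constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 2
    constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // 1
    constexpr bool AEnableLds_auto = NWaves != 1;            // false: A could bypass LDS
    constexpr bool AEnableLds_manu = true;                   // but the manual override wins here
    constexpr bool AEnableLds =
        AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1);
    static_assert(AEnableLds, "with the manual flag set, A still stages through LDS");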

    static constexpr auto conv_to_gemm_transformer =
        TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
@@ -122,17 +142,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

    template <typename ALay>
    static auto
    MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                            const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                            const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                            const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                            const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                            const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                            const std::array<index_t, NDimSpatial>& conv_filter_strides,
                            const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                            const std::array<index_t, NDimSpatial>& input_left_pads,
                            const std::array<index_t, NDimSpatial>& input_right_pads)
    static auto MakeAGridDescriptor(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                                    const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                                    const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                                    const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                                    const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                                    const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                                    const std::array<index_t, NDimSpatial>& conv_filter_strides,
                                    const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                                    const std::array<index_t, NDimSpatial>& input_left_pads,
                                    const std::array<index_t, NDimSpatial>& input_right_pads)
    {
        const auto in_gemmmraw_gemmkraw_desc =
            conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
@@ -149,13 +168,44 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
        const auto in_gemmm_gemmk_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

        return in_gemmm_gemmk_desc;
        const auto M = in_gemmm_gemmk_desc.GetLength(I0);
        const auto K = in_gemmm_gemmk_desc.GetLength(I1);
        assert(K % K1 == 0);

        if constexpr(AEnableLds)
        {
            const index_t K0 = K / K1;

            return transform_tensor_descriptor(
                in_gemmm_gemmk_desc,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            constexpr auto A_KRow      = 2;
            constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
            const auto A_KWmma         = K / WmmaK;

            const auto M0 = M / MPerBlock;
            //     0   1        0                1          2           3          4         5        6
            // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
            return transform_tensor_descriptor(
                in_gemmm_gemmk_desc,
                make_tuple(make_unmerge_transform(make_tuple(
                               A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }
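The LDS branch above is the usual K = K0 * K1 split: element (m, k) of the padded M x K view is re-addressed as (k0, m, k1) with k0 = k / K1 and k1 = k % K1. A tiny host-side sketch of that index mapping, with hypothetical sizes:

    #include <cassert>

    // Hypothetical check of the (m, k) -> (k0, m, k1) unmerge used above, K1 = 8.
    int main()
    {
        constexpr int K1 = 8;
        const int k  = 21;
        const int k0 = k / K1; // 2
        const int k1 = k % K1; // 5
        assert(k == k0 * K1 + k1); // the unmerge is exact when K % K1 == 0
        return 0;
    }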

    template <typename BLay>
    static auto
    MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                            const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
    static auto MakeBGridDescriptor(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                                    const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
    {
        const auto wei_gemmnraw_gemmkraw_desc =
            conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
@@ -164,7 +214,39 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
        const auto wei_gemmn_gemmk_desc =
            matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);

        return wei_gemmn_gemmk_desc;
        const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
        const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
        assert(K % K1 == 0);

        if constexpr(BEnableLds)
        {
            const index_t K0 = K / K1;

            return transform_tensor_descriptor(
                wei_gemmn_gemmk_desc,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            constexpr auto B_KRow      = 2;
            constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
            const auto B_KWmma         = K / WmmaK;

            const auto N0 = N / NPerBlock;
            //     0   1        0                1          2           3          4         5        6
            // N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
            return transform_tensor_descriptor(
                wei_gemmn_gemmk_desc,
                make_tuple(make_unmerge_transform(make_tuple(
                               B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }

    template <typename ELay>
@@ -197,53 +279,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
    }

    // desc for problem definition
    using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
    using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
    using AGridDesc =
        decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}));
    using BGridDesc     = decltype(DeviceOp::MakeBGridDescriptor<BLayout>({}, {}));
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
    using EGridDesc_M_N  = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;

    // A desc for source in blockwise copy
    template <typename AGridDesc_M_K>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK1 = K1;
        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(a_grid_desc_m_k,
                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                      make_pass_through_transform(M)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    // B desc for source in blockwise copy
    template <typename BGridDesc_N_K>
    __host__ __device__ static constexpr auto
    MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK1 = K1;
        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(b_grid_desc_n_k,
                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                      make_pass_through_transform(N)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    using AGridDesc_AK0_M_AK1 = decltype(DeviceOp::MakeAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}));
    using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}));

    // GridwiseOp
    using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
    using GridwiseOp = GridwiseGemmMultipleD_Wmma<
        // DataType Family
        ADataType,
        BDataType,
@@ -252,8 +295,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
        DsDataType,
        EDataType,
        // InMemory Data Descriptor
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        AGridDesc,
        BGridDesc,
        DsGridDesc_M_N,
        EGridDesc_M_N,
        // ElementwiseOp Family
@@ -264,9 +307,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
        // Tiling Family
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerWMMA,
        NPerWMMA,
        KPerBlock,
        MPerWmma,
        NPerWmma,
        K1,
        MRepeat,
        NRepeat,
@@ -279,6 +322,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        AEnableLds,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
@@ -287,6 +331,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BEnableLds,
        BBlockLdsExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,

@@ -327,23 +372,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e)},
              num_group_{a_g_n_c_wis_lengths[0]},
              a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
                                                                          a_g_n_c_wis_strides,
                                                                          b_g_k_c_xs_lengths,
                                                                          b_g_k_c_xs_strides,
                                                                          e_g_n_k_wos_lengths,
                                                                          e_g_n_k_wos_strides,
                                                                          conv_filter_strides,
                                                                          conv_filter_dilations,
                                                                          input_left_pads,
                                                                          input_right_pads)},
              b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
                                                                          b_g_k_c_xs_strides)},
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
                                                                          e_g_n_k_wos_strides)},
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
              a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths,
                                                                  a_g_n_c_wis_strides,
                                                                  b_g_k_c_xs_lengths,
                                                                  b_g_k_c_xs_strides,
                                                                  e_g_n_k_wos_lengths,
                                                                  e_g_n_k_wos_strides,
                                                                  conv_filter_strides,
                                                                  conv_filter_dilations,
                                                                  input_left_pads,
                                                                  input_right_pads)},
              b_grid_desc_{
                  DeviceOp::MakeBGridDescriptor<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
              ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
@@ -395,8 +438,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle

        void Print() const
        {
            std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
            std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
            std::cout << "A[M, K]: " << a_grid_desc_ << std::endl;
            std::cout << "B[N, K]: " << b_grid_desc_ << std::endl;
            static_for<0, NumDTensor, 1>{}(
                [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
            std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
@@ -411,14 +454,12 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle

        // tensor descriptors for problem definition
        index_t num_group_;
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        AGridDesc a_grid_desc_;
        BGridDesc b_grid_desc_;
        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock_;
        typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -465,8 +506,17 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;

            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
            const auto K = [&]() {
                if constexpr(AEnableLds)
                {
                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
                }
                else
                {
                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
                           arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
                }
            }();
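In the LDS path the A descriptor is (AK0, M, AK1), so K is the product of lengths I0 and I2; in the bypass path the descriptor is seven-dimensional and K is the product of its unmerged K components, K = KWmma * K0PerWmma * KRow * K1 (lengths I0, I3, I4, I6). As a worked example with hypothetical values KWmma = 4, K0PerWmma = 1, KRow = 2, K1 = 8, this recovers K = 4 * 1 * 2 * 8 = 64.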

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;
@@ -480,8 +530,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    DeviceOp::AGridDesc,
                    DeviceOp::BGridDesc,
                    typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
@@ -501,8 +551,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
                    arg.b_element_op_,
                    arg.cde_element_op_,
                    arg.a_g_n_c_wis_lengths_[0], // Group count
                    arg.a_grid_desc_ak0_m_ak1_,
                    arg.b_grid_desc_bk0_n_bk1_,
                    arg.a_grid_desc_,
                    arg.b_grid_desc_,
                    arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.block_2_etile_map_,
@@ -670,8 +720,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
        }

        // check Gridwise GEMM
        return GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                         arg.b_grid_desc_bk0_n_bk1_,
        return GridwiseOp::CheckValidity(arg.a_grid_desc_,
                                         arg.b_grid_desc_,
                                         arg.ds_grid_desc_m_n_,
                                         arg.e_grid_desc_m_n_,
                                         arg.block_2_etile_map_);
@@ -790,9 +840,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
            << KPerBlock << ", "
            << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
            << K1 << ", "
            << MPerWmma << ", "
            << NPerWmma << ", "
            << MRepeat << ", "
            << NRepeat
            << ">"
            << " AEnableLds: "
            << AEnableLds << ", "
            << "BEnableLds: "
            << BEnableLds << ", "
            << "ABlockTransferSrcScalarPerVector: "
            << ABlockTransferSrcScalarPerVector << ", "
            << BBlockTransferSrcScalarPerVector
            << ">";
            << "BBlockTransferSrcScalarPerVector: "
            << BBlockTransferSrcScalarPerVector;
        // clang-format on

        return str.str();

File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -53,7 +53,10 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
    C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
    __host__ __device__ C0MatrixMask_impl(index_t NRaw)
        : NRaw_(NRaw), predicate_(MaskOutPredicate{})
    {
    }

    __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
    {

@@ -123,6 +123,12 @@ struct PassThrough
        y = type_convert<bhalf_t>(x);
    }

    template <>
    __host__ __device__ void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
    {
        y = x;
    }

    template <>
    __host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
    {
@@ -663,6 +669,76 @@ struct Elu
    const float alpha_;
};

// support fastconvert of int8 to fp16

template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
};

template <>
struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
{
    using InputArray  = vector_type<uint8_t, 4>;
    using OutputArray = vector_type<ck::half_t, 4>;

    __device__ static OutputArray convert(InputArray const& Input)
    {
        OutputArray Output;

        uint32_t* half_2          = reinterpret_cast<uint32_t*>(&Output);
        uint32_t const uint8_4    = reinterpret_cast<uint32_t const&>(Input);

        static constexpr uint32_t byte_selector_01 = 0x05010500;
        static constexpr uint32_t byte_selector_23 = 0x05030502;
        static constexpr uint32_t fp16_adder       = 0x64646464;
        half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
        half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);

        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
        asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
                     : "=v"(half_2[0])
                     : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
        asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
                     : "=v"(half_2[1])
                     : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));

        return Output;
    }

    __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
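The trick above appears to work as follows: the 0x64 perm byte splices each uint8 b into the mantissa of an fp16 value whose exponent field encodes 1024, yielding 1024 + b per lane; the packed add then subtracts the 0x6480 magic constant (1152.0 in fp16, negated via neg_lo/neg_hi), leaving b - 128 without any per-lane convert instructions. A hedged usage sketch (assumes a CK device-code context; the function name is hypothetical):

    // Hypothetical device function: dequantize four packed uint8 values to fp16.
    __device__ void demo_fast_convert(const ck::vector_type<uint8_t, 4>& u8x4,
                                      ck::vector_type<ck::half_t, 4>& f16x4)
    {
        ck::tensor_operation::element_wise::FastNumericArrayConverter<uint8_t, ck::half_t, 4>
            converter;
        f16x4 = converter(u8x4); // two v_perm picks + two packed fp16 adds
    }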

template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
{
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

    using InputArray  = vector_type<uint8_t, N>;
    using OutputArray = vector_type<ck::half_t, N>;

    __device__ static OutputArray convert(InputArray const& Input)
    {
        FastNumericArrayConverter<uint8_t, ck::half_t, 4> converter;

        OutputArray Output;

        using Vec_InputArray  = vector_type<uint8_t, 4>;
        using Vec_OutputArray = vector_type<ck::half_t, 4>;

        Vec_OutputArray* half_4_ptr       = reinterpret_cast<Vec_OutputArray*>(&Output);
        Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);

        static_for<0, N / VEC_WIDTH, 1>{}(
            [&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); });

        return Output;
    }

    __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck

@@ -116,7 +116,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemm0KPrefetchStage>;
    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemm0KPrefetchStage, true, true>;

    // ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
    static constexpr auto MakeD0sGridPointer()

File diff suppressed because it is too large
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp (new file, 1046 lines)
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -17,18 +17,21 @@ enum struct PipelineVersion
    v2,
    // v3 is only used in the Stream-K implementation.
    v4,
    weight_only,
};

template <PipelineVersion PipelineVer,
          index_t NumPrefetch     = 1,
          LoopScheduler LoopSched = LoopScheduler::Default>
          LoopScheduler LoopSched = LoopScheduler::Default,
          bool AEnableLds         = true,
          bool BEnableLds         = true>
constexpr auto GridwiseGemmPipeline_Selector()
{
    if constexpr(PipelineVer == PipelineVersion::v1)
    {
        if constexpr(LoopSched == LoopScheduler::Default)
        {
            return GridwiseGemmPipeline_v1<NumPrefetch>{};
            return GridwiseGemmPipeline_v1<NumPrefetch, AEnableLds, BEnableLds>{};
        }
        else if constexpr(LoopSched == LoopScheduler::Interwave)
        {
@@ -43,6 +46,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
    {
        return GridwiseGemmPipeline_v4<NumPrefetch>{};
    }
    else if constexpr(PipelineVer == PipelineVersion::weight_only)
    {
        return GridwiseGemmPipeline_v1_WeightOnly<NumPrefetch, AEnableLds, BEnableLds>{};
    }
    else
    {
        std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
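With the extended signature, a caller can pin the LDS behaviour per operand when instantiating a pipeline. A minimal sketch of that usage, assuming the CK headers above (the alias name is hypothetical):

    // Hypothetical alias: v1 pipeline, single prefetch, B operand kept in VGPRs.
    using PipeSkipBLds = ck::remove_cvref_t<decltype(
        ck::GridwiseGemmPipeline_Selector<ck::PipelineVersion::v1,
                                          /*NumPrefetch=*/1,
                                          ck::LoopScheduler::Default,
                                          /*AEnableLds=*/true,
                                          /*BEnableLds=*/false>())>;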

@@ -9,12 +9,12 @@

namespace ck {

template <index_t NumPrefetch>
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1;

// 1-stage prefetch
template <>
struct GridwiseGemmPipeline_v1<1>
struct GridwiseGemmPipeline_v1<1, true, true>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
@@ -108,7 +108,7 @@ struct GridwiseGemmPipeline_v1<1>

// 2-stage prefetch
template <>
struct GridwiseGemmPipeline_v1<2>
struct GridwiseGemmPipeline_v1<2, true, true>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
@@ -254,6 +254,406 @@ struct GridwiseGemmPipeline_v1<2>
    }
};

template <>
struct GridwiseGemmPipeline_v1<1, false, true>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
        auto a_block_buf_switch           = a_block_buf;

        // preload data into LDS
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
        a_blockwise_copy.Run(
            a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                a_blockwise_copy.Run(
                    a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);

                block_sync_lds();

                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_block_buf = a_block_buf_switch;
                ++i;
            } while(i < (num_loop - 1));
        }

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();
        }
    }
};
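The bypassed A operand above is double-buffered in registers: while the wave computes on a_block_buf, the next tile is loaded into a_block_buf_switch, and the buffers swap at the end of each iteration. A host-side analogue of that ping-pong, illustrative only:

    #include <array>
    #include <cstddef>
    #include <vector>

    using Tile = std::array<float, 16>; // stands in for the per-thread A buffer

    float consume(const Tile& t) // stands in for blockwise_gemm.Run
    {
        float s = 0.f;
        for(float v : t)
            s += v;
        return s;
    }

    float pipelined_sum(const std::vector<Tile>& tiles)
    {
        Tile cur  = tiles[0]; // preload, as before the main body above
        float acc = 0.f;
        for(std::size_t i = 1; i < tiles.size(); ++i)
        {
            Tile next = tiles[i]; // "a_block_buf_switch": load next while computing
            acc += consume(cur);
            cur = next;           // swap buffers
        }
        acc += consume(cur);      // tail iteration
        return acc;
    }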

template <>
struct GridwiseGemmPipeline_v1<1, true, false>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
        auto b_block_buf_switch           = b_block_buf;

        // preload data into LDS
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.Run(
            b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                b_blockwise_copy.Run(
                    b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);

                block_sync_lds();

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

                b_block_buf = b_block_buf_switch;
                ++i;
            } while(i < (num_loop - 1));
        }

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();
        }
    }
};

template <>
struct GridwiseGemmPipeline_v1<1, false, false>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
        constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
        auto b_block_buf_switch           = b_block_buf;
        auto a_block_buf_switch           = a_block_buf;

        // preload data into LDS
        a_blockwise_copy.Run(
            a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
        b_blockwise_copy.Run(
            b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                a_blockwise_copy.Run(
                    a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
                b_blockwise_copy.Run(
                    b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);

                block_sync_lds();

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                a_block_buf = a_block_buf_switch;
                b_block_buf = b_block_buf_switch;
                ++i;
            } while(i < (num_loop - 1));
        }

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();
        }
    }
};

template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1_WeightOnly;

template <>
struct GridwiseGemmPipeline_v1_WeightOnly<1, true, true>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename ScaleGridDesc,
              typename ScaleGridBuffer,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffer& b_block_buf,
                               const BBlockTransferStep& b_block_copy_step,
                               const ScaleGridDesc& scale_grid_desc,
                               const ScaleGridBuffer& scale_grid_buf,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        // Global Prefetch Stage 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
        // Scale read once
        b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        // Dequantization fused in blockwise_copy
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

                block_sync_lds();

                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                ++i;
            } while(i < (num_loop - 1));
        }

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }
};
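The weight-only variant reads the dequantization scale once up front (RunScaleRead) and relies on b_blockwise_copy.RunWrite to dequantize the int8 weights as they land in LDS, so the inner loop stays identical to the plain fp16 pipeline. A hedged sketch of that fuse-on-store idea (a hypothetical helper, not the CK blockwise copy):

    // Hypothetical device helper: int8 weights scaled to fp16 on the way into LDS.
    __device__ void dequant_store(const int8_t* weights, ck::half_t scale,
                                  ck::half_t* lds_dst, int n)
    {
        for(int i = 0; i < n; ++i)
            lds_dst[i] = ck::type_convert<ck::half_t>(weights[i]) * scale;
    }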

template <index_t NumPrefetch>
struct GridwiseGemmPipelineInterwave_v1;

@@ -349,7 +749,7 @@ struct GridwiseGemmPipelineInterwave_v1<1>

// Note: 2 stage prefetch not optimized for inter-wave loop scheduler
template <>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true>
{
};

@@ -359,7 +759,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector()
{
    if constexpr(LoopSched == LoopScheduler::Default)
    {
        return GridwiseGemmPipeline_v1<NumPrefetch>{};
        return GridwiseGemmPipeline_v1<NumPrefetch, true, true>{};
    }
    else if constexpr(LoopSched == LoopScheduler::Interwave)
    {

@@ -93,7 +93,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage, true, true>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {

@@ -18,11 +18,11 @@
namespace ck {

template <typename GridwiseGemm,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AGridDesc,
          typename BGridDesc,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
@@ -33,31 +33,27 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_wmma(
            const FloatA* __restrict__ p_a_grid,
            const FloatB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            // const
            // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
            // c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const Block2CTileMap block_2_ctile_map)
        kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
                         const BDataType* __restrict__ p_b_grid,
                         CDataType* __restrict__ p_c_grid,
                         const AGridDesc a_grid_desc,
                         const BGridDesc b_grid_desc,
                         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                             c_grid_desc_mblock_mperblock_nblock_nperblock,
                         const AElementwiseOperation a_element_op,
                         const BElementwiseOperation b_element_op,
                         const CElementwiseOperation c_element_op,
                         const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_c_grid,
                                                  p_shared,
                                                  a_grid_desc_k0_m_k1,
                                                  b_grid_desc_k0_n_k1,
                                                  a_grid_desc,
                                                  b_grid_desc,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  a_element_op,
                                                  b_element_op,
@@ -67,8 +63,8 @@ __global__ void
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
    ignore = a_grid_desc_k0_m_k1;
    ignore = b_grid_desc_k0_n_k1;
    ignore = a_grid_desc;
    ignore = b_grid_desc;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = a_element_op;
    ignore = b_element_op;
@@ -78,21 +74,21 @@ __global__ void
}

template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatAcc,
          typename FloatCShuffle,
          typename FloatC,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename CDataType,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename AGridDesc,
          typename BGridDesc,
          typename CGridDesc_M_N,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t K1Value,
@@ -105,6 +101,7 @@ template <index_t BlockSize,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool AEnableLds,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
@@ -113,6 +110,7 @@ template <index_t BlockSize,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BEnableLds,
          bool BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
@@ -121,7 +119,7 @@ template <index_t BlockSize,
          index_t NumGemmKPrefetchStage = 1,
          LoopScheduler LoopSched       = make_default_loop_scheduler(),
          PipelineVersion PipelineVer   = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
struct GridwiseGemm_Wmma
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
@@ -132,103 +130,277 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    // FIX ME: To be deprecated
    static constexpr auto K1 = Number<K1Value>{};

    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
    static constexpr auto WmmaK  = K1 == 16 ? 32 : 16;

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    using GridwiseGemmPipe = remove_cvref_t<
        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
    using GridwiseGemmPipe =
        remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
                                                              NumGemmKPrefetchStage,
                                                              LoopSched,
                                                              AEnableLds,
                                                              BEnableLds>())>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    // Describe how data is stored into the (LDS/VGPR) buffer from global memory
    __host__ __device__ static constexpr auto MakeABlockDescriptor()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
        constexpr auto a_block_desc = [&]() {
            if constexpr(AEnableLds)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
                // K0->M->K1 Per Block
                constexpr auto K0PerBlock    = KPerBlock / K1;
                constexpr auto max_lds_align = K1;

                if constexpr(ABlockLdsExtraM)
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                        make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
                }
                else
                {
                    return make_naive_tensor_descriptor_aligned(
                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
                }
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
                constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
                constexpr auto K0PerWmma     = WmmaK / 2 / K1;
                // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
                return make_naive_tensor_descriptor(
                    make_tuple(Number<KWmmaPerblock>{},
                               Number<MRepeat>{},
                               I1,
                               Number<K0PerWmma>{},
                               I1,
                               I1,
                               K1),
                    make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * K1,
                               Number<K0PerWmma>{} * K1,
                               Number<K0PerWmma>{} * K1,
                               K1,
                               K1,
                               K1,
                               I1));
            }
        }();

        return a_block_desc_k0perblock_mperblock_k1;
        return a_block_desc;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
    __host__ __device__ static constexpr auto MakeBBlockDescriptor()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
        constexpr auto b_block_desc = [&]() {
            if constexpr(BEnableLds)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
                // K0->N->K1 Per Block
                constexpr auto K0PerBlock    = KPerBlock / K1;
                constexpr auto max_lds_align = K1;

                if constexpr(BBlockLdsExtraN)
                {
                    return make_naive_tensor_descriptor(
                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                        make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
                }
                else
                {
                    return make_naive_tensor_descriptor_aligned(
                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
                }
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
                constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
                constexpr auto K0PerWmma     = WmmaK / 2 / K1;
                // KWmma->NRepeat->NWave->K0PerWmma->KRow->NPerWmma->K1 Per Thread
                return make_naive_tensor_descriptor(
                    make_tuple(Number<KWmmaPerblock>{},
                               Number<NRepeat>{},
                               I1,
                               Number<K0PerWmma>{},
                               I1,
                               I1,
                               K1),
                    make_tuple(Number<NRepeat>{} * Number<K0PerWmma>{} * K1,
                               Number<K0PerWmma>{} * K1,
                               Number<K0PerWmma>{} * K1,
                               K1,
                               K1,
                               K1,
                               I1));
            }
        }();

        return b_block_desc_k0perblock_nperblock_k1;
        return b_block_desc;
    }

    __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
    {
        constexpr auto a_block_copy_step = [&]() {
            if constexpr(AEnableLds)
            {
                constexpr auto K0PerBlock = KPerBlock / K1;

                return make_multi_index(K0PerBlock, 0, 0);
            }
            else
            {
                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;

                return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
            }
        }();

        return a_block_copy_step;
    }

    __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
    {
        constexpr auto b_block_copy_step = [&]() {
            if constexpr(BEnableLds)
            {
                constexpr auto K0PerBlock = KPerBlock / K1;

                return make_multi_index(K0PerBlock, 0, 0);
            }
            else
            {
                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;

                return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
            }
        }();

        return b_block_copy_step;
    }

    // Describe how data is read from the (LDS/VGPR) buffer
    template <typename ABlockDesc_>
    __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
    {

        constexpr auto a_wave_desc = [&]() {
            if constexpr(AEnableLds)
            {
                // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
                constexpr auto A_K0   = ABlockDesc_{}.GetLength(I0);
                constexpr auto A_K1   = ABlockDesc_{}.GetLength(I2);
                constexpr auto A_KRow = I1;
                return transform_tensor_descriptor(
                    ABlockDesc_{},
                    make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
                               make_unmerge_transform(make_tuple(
                                   Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                               make_pass_through_transform(Number<A_K1>{})),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                    make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
            }
            else
            {
                // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
                constexpr auto KWmma     = ABlockDesc_{}.GetLength(I0);
                constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
                constexpr auto A_KRow    = ABlockDesc_{}.GetLength(I4);
                constexpr auto A_K1      = ABlockDesc_{}.GetLength(I6);

                // Err: the merge transform causes a non-constexpr issue

                // return transform_tensor_descriptor(
                //     ABlockDesc_{},
                //     make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
                //                make_pass_through_transform(Number<MRepeat>{}),
                //                make_pass_through_transform(I1),
                //                make_pass_through_transform(I1),
                //                make_pass_through_transform(Number<A_K1>{})),
                //     make_tuple(Sequence<0, 3>{},
                //                Sequence<1>{},
                //                Sequence<2>{},
                //                Sequence<4>{},
                //                Sequence<5>{}),
                //     make_tuple(
                //         Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
                //         Sequence<4>{}));

                // Workaround, Freeze transform
                return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
                                                                      Number<MRepeat>{},
                                                                      I1,
                                                                      Number<A_KRow>{},
                                                                      I1,
                                                                      Number<A_K1>{}));
            }
        }();

        return a_wave_desc;
    }
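The commented-out transform above documents the intent: fold KWmma and K0PerWmma into a single K0 dimension with a merge transform. Because merge introduces index division/modulo that is not compile-time evaluable here, the lambda instead freezes the layout as a packed descriptor with the same extents, keeping the unit (I1) wave and lane dimensions so downstream indexing sees the expected rank:

    KWmma x MRepeat x 1 x K0PerWmma x KRow x 1 x K1  ->  (KWmma*K0PerWmma) x MRepeat x 1 x KRow x 1 x K1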

    template <typename BBlockDesc_>
    __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
    {
        constexpr auto b_wave_desc = [&]() {
            if constexpr(BEnableLds)
            {
                // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
                constexpr auto B_K0   = BBlockDesc_{}.GetLength(I0);
                constexpr auto B_K1   = BBlockDesc_{}.GetLength(I2);
                constexpr auto B_KRow = I1;
                return transform_tensor_descriptor(
                    BBlockDesc_{},
                    make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
                               make_unmerge_transform(make_tuple(
                                   Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                               make_pass_through_transform(Number<B_K1>{})),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                    make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
            }
            else
            {
                // KWmma_NRepeat_NWave_K0PerWmma_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
                constexpr auto KWmma     = BBlockDesc_{}.GetLength(I0);
                constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
                constexpr auto B_KRow    = BBlockDesc_{}.GetLength(I4);
                constexpr auto B_K1      = BBlockDesc_{}.GetLength(I6);

                // Workaround, Freeze transform
                return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
                                                                      Number<NRepeat>{},
                                                                      I1,
                                                                      Number<B_KRow>{},
                                                                      I1,
                                                                      Number<B_K1>{}));
            }
        }();

        return b_wave_desc;
    }

    __host__ __device__ static constexpr auto
    // *Caution Here repeat is shuffle repeat
    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);

        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
                           Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
                           I1,
                           Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
                           Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));

        return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_k0perblock_mperblock_k1 =
            GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

        constexpr auto b_block_desc_k0perblock_nperblock_k1 =
            GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        constexpr auto max_lds_align = K1;

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned * sizeof(FloatA) +
                b_block_space_size_aligned * sizeof(FloatB));
    }
|
||||
|
||||
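    // Illustrative padding behaviour of the alignment rule above (K1 = 8
    // assumed): integer_least_multiple rounds the A-tile element count up to
    // the next multiple of K1 so the B tile starts on an aligned LDS offset.
    static_assert((1001 + 8 - 1) / 8 * 8 == 1008, "e.g. 1001 elements pad to 1008");
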
    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
    template <typename Block2CTileMap>
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                  const CGridDesc_M_N& c_grid_desc_m_n,
                  const Block2CTileMap& block_2_ctile_map)
    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
                                                            const BGridDesc& b_grid_desc,
                                                            const CGridDesc_M_N& c_grid_desc_m_n,
                                                            const Block2CTileMap& block_2_ctile_map)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 need to be known at compile-time");
@@ -237,23 +409,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                          (NPerBlock % (NRepeat * NPerWmma)) == 0,
                      "Invalid tuning param!");

        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        const auto GetAProblemsizeMK = [&]() {
            if constexpr(AEnableLds)
            {
                return make_tuple(a_grid_desc.GetLength(I1),
                                  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
            }
            else
            {
                return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
                                      a_grid_desc.GetLength(I5),
                                  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
                                      a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
            }
        };

        const auto GetBProblemsizeNK = [&]() {
            if constexpr(BEnableLds)
            {
                return make_tuple(b_grid_desc.GetLength(I1),
                                  b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
            }
            else
            {
                return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
                                      b_grid_desc.GetLength(I5),
                                  b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
                                      b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
            }
        };

        const auto M = GetAProblemsizeMK()[I0];
        const auto N = GetBProblemsizeNK()[I0];
        const auto K = GetAProblemsizeMK()[I1];

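        // Readback of the skip-LDS lengths above (illustrative): M multiplies
        // the MRepeat (I1), MWave (I2) and MPerWmma (I5) extents, while K
        // multiplies the KWmma (I0), K0PerWmma (I3), KRow (I4) and K1 (I6)
        // extents, matching the 7-D KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1
        // ordering used by the frozen wave descriptors.
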
        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
             K == GetBProblemsizeNK()[I1]))
        {
            printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n",
                   GetAProblemsizeMK()[I0],
                   GetAProblemsizeMK()[I1],
                   GetBProblemsizeNK()[I0],
                   GetBProblemsizeNK()[I1],
                   c_grid_desc_m_n.GetLength(I0),
                   c_grid_desc_m_n.GetLength(I1));
            printf("GridwiseOp err: ProblemSize check");
            return false;
        }

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
        {
            printf("GridwiseOp err: ProblemSize division");
            return false;
        }

        // check gridwise gemm pipeline
        const auto num_k_loop = K0 / K0PerBlock;
        const auto num_k_loop = K / KPerBlock;

        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            printf("GridwiseOp err: Pipeline not support this k_loop");
            return false;
        }

@@ -265,8 +480,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        constexpr long_index_t TwoGB = (long_index_t{1} << 31);

        if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB &&
             b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB))
        if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
             b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
        {
            return false;
        }
@@ -275,7 +490,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
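        // Note: long_index_t{1} << 31 is 2'147'483'648 bytes, i.e. a 2 GiB cap
        // per tensor, matching the 31-bit extent a buffer resource is assumed
        // to be able to address.
        static_assert((1LL << 31) == 2147483648LL, "2 GiB element-space cap");
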
    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / (K0PerBlock * K1);
        const index_t num_loop = K / KPerBlock;

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }
@@ -313,13 +528,44 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

    struct SharedMemTrait
    {
        // LDS allocation for A and B: be careful of alignment

        static constexpr auto max_lds_align = K1;

        static constexpr auto a_block_space_size_aligned =
            AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
                                                      max_lds_align)
                       : 0;
        static constexpr auto b_block_space_size_aligned =
            BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
                                                      max_lds_align)
                       : 0;

        static constexpr auto a_block_space_offset = 0;
        static constexpr auto b_block_space_offset = a_block_space_size_aligned;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_space_size =
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
                .GetElementSpaceSize();

        static constexpr auto c_shuffle_block_space_offset = 0;

        static constexpr auto lds_size =
            math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
                      a_block_space_size_aligned * sizeof(ADataType) +
                          b_block_space_size_aligned * sizeof(BDataType));
    };

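    // Worked example of the lds_size rule above (assumed fp16 A/B tiles of
    // 2048 elements each and an fp32 C-shuffle tile of 4096 elements): the
    // A+B stage needs 2048*2 + 2048*2 = 8192 bytes, the shuffle stage
    // 4096*4 = 16384 bytes, and both reuse the same allocation.
    static_assert(4096 * 4 > 2048 * 2 + 2048 * 2, "illustrative LDS budget: shuffle stage dominates");
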
    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
                               const FloatB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
    __device__ static void Run(const ADataType* __restrict__ p_a_grid,
                               const BDataType* __restrict__ p_b_grid,
                               CDataType* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                               const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                               const AGridDesc& a_grid_desc,
                               const BGridDesc& b_grid_desc,
                               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const AElementwiseOperation& a_element_op,
@@ -331,9 +577,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        /*******************************************************************************/
        // Memory buffer zone.
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
            p_a_grid, a_grid_desc.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
            p_b_grid, b_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

@@ -351,24 +597,41 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        /*******************************************************************************/
        // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of BlockWise_Copy
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        constexpr auto max_lds_align = K1;
        constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
        constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
        // BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destination of BlockWise_Copy
        const auto K = [&]() {
            if constexpr(AEnableLds)
            {
                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
            }
            else
            {
                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
                       a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
            }
        }();

        constexpr auto a_block_desc = MakeABlockDescriptor();
        constexpr auto b_block_desc = MakeBBlockDescriptor();

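        // Illustrative K readback (assumed KPerBlock = 64, K1 = 8): the LDS
        // path gives K = K0 * K1, while the skip-LDS path gives
        // K = KWmma * K0PerWmma * KRow * K1 = 4 * 1 * 2 * 8 = 64.
        static_assert(4 * 1 * 2 * 8 == 64, "illustrative skip-LDS K reconstruction");
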
        auto a_block_trait = [&]() {
            // A matrix blockwise copy
            if constexpr(AEnableLds)
            {
                constexpr auto K0PerBlock = KPerBlock / K1;
                auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<ADataType*>(p_shared),
                    SharedMemTrait::a_block_space_size_aligned);

                auto a_blockwise_copy =
                    ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                        /* typename SrcElementwiseOperation, */ AElementwiseOperation,
                        /* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
                        /* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
                        /* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
                        /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
                        /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
                        /* typename SrcData, */ FloatA,
                        /* typename DstData, */ FloatA,
                        /* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
                        /* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
                        /* typename SrcData, */ ADataType,
                        /* typename DstData, */ ADataType,
                        /* typename SrcDesc, */ decltype(a_grid_desc),
                        /* typename DstDesc, */ decltype(a_block_desc),
                        /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
                        /* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
                        /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
@@ -378,99 +641,197 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                        /* index_t SrcScalarStrideInVector, */ 1,
                        /* index_t DstScalarStrideInVector, */ 1,
                        /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
                        /* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
                    a_grid_desc_k0_m_k1,
                        /* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
                        NumGemmKPrefetchStage>(
                        a_grid_desc,
                        make_multi_index(0, m_block_data_idx_on_grid, 0),
                        a_element_op,
                        a_block_desc_k0perblock_mperblock_k1,
                        a_block_desc,
                        make_multi_index(0, 0, 0),
                        ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<K0PerBlock, NPerBlock, K1>,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                FloatB,
                                                FloatB,
                                                decltype(b_grid_desc_k0_n_k1),
                                                decltype(b_block_desc_k0perblock_nperblock_k1),
                                                BBlockTransferSrcAccessOrder,
                                                Sequence<0, 1, 2>,
                                                BBlockTransferSrcVectorDim,
                                                2,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                true>(
                b_grid_desc_k0_n_k1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_k0perblock_nperblock_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});
                return make_tuple(a_block_buf, a_blockwise_copy);
            }
            else
            {
                // Thread-wise copy
                // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
                constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
                auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
                    a_block_desc.GetElementSpaceSize());

                // Limitation: NumDim of Src and Dst descriptor should be identical
                auto a_blockwise_copy =
                    ThreadwiseTensorSliceTransfer_v2<ADataType,
                                                     ADataType,
                                                     decltype(a_grid_desc),
                                                     decltype(a_block_desc),
                                                     Sequence<Number<KWmmaPerBlock>{},
                                                              Number<MRepeat>{},
                                                              I1,
                                                              Number<K0PerWmma>{},
                                                              I1,
                                                              I1,
                                                              Number<K1Value>{}>,
                                                     Sequence<0, 1, 2, 3, 4, 5, 6>,
                                                     6,
                                                     ABlockTransferSrcScalarPerVector,
                                                     AThreadTransferSrcResetCoordinateAfterRun,
                                                     true>(
                        a_grid_desc,
                        make_multi_index(0,
                                         m_block_data_idx_on_grid / (MWaves * MPerWmma),
                                         get_thread_local_1d_id() / 32,
                                         0,
                                         (get_thread_local_1d_id() % 32) / 16,
                                         get_thread_local_1d_id() % 16,
                                         0));

                return make_tuple(a_block_buf, a_blockwise_copy);
            }
        };

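        // Lane decomposition used by the skip-LDS origin above (wave32
        // assumed): tid / 32 selects the wave, (tid % 32) / 16 selects which
        // of the two 16-lane rows inside the wave, and tid % 16 is the lane
        // within the row. E.g. tid = 45 -> wave 1, row 0, lane 13.
        static_assert(45 / 32 == 1 && (45 % 32) / 16 == 0 && 45 % 16 == 13, "illustrative lane split");
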
        auto b_block_trait = [&]() {
            if constexpr(BEnableLds)
            {
                constexpr auto K0PerBlock = KPerBlock / K1;
                auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                    static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
                    SharedMemTrait::b_block_space_size_aligned);

                auto b_blockwise_copy =
                    ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                        BElementwiseOperation,
                                                        ck::tensor_operation::element_wise::PassThrough,
                                                        InMemoryDataOperationEnum::Set,
                                                        Sequence<K0PerBlock, NPerBlock, K1>,
                                                        BBlockTransferThreadClusterLengths_K0_N_K1,
                                                        BBlockTransferThreadClusterArrangeOrder,
                                                        BDataType,
                                                        BDataType,
                                                        decltype(b_grid_desc),
                                                        decltype(b_block_desc),
                                                        BBlockTransferSrcAccessOrder,
                                                        Sequence<0, 1, 2>,
                                                        BBlockTransferSrcVectorDim,
                                                        2,
                                                        BBlockTransferSrcScalarPerVector,
                                                        BBlockTransferDstScalarPerVector_K1,
                                                        1,
                                                        1,
                                                        BThreadTransferSrcResetCoordinateAfterRun,
                                                        true,
                                                        NumGemmKPrefetchStage>(
                        b_grid_desc,
                        make_multi_index(0, n_block_data_idx_on_grid, 0),
                        b_element_op,
                        b_block_desc,
                        make_multi_index(0, 0, 0),
                        ck::tensor_operation::element_wise::PassThrough{});

                return make_tuple(b_block_buf, b_blockwise_copy);
            }
            else
            {
                // Thread-wise copy
                // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
                constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
                auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
                    b_block_desc.GetElementSpaceSize());

                // Limitation: NumDim of Src and Dst descriptor should be identical
                auto b_blockwise_copy =
                    ThreadwiseTensorSliceTransfer_v2<BDataType,
                                                     BDataType,
                                                     decltype(b_grid_desc),
                                                     decltype(b_block_desc),
                                                     Sequence<Number<KWmmaPerBlock>{},
                                                              Number<NRepeat>{},
                                                              I1,
                                                              Number<K0PerWmma>{},
                                                              I1,
                                                              I1,
                                                              Number<K1Value>{}>,
                                                     Sequence<0, 1, 2, 3, 4, 5, 6>,
                                                     6,
                                                     BBlockTransferSrcScalarPerVector,
                                                     BThreadTransferSrcResetCoordinateAfterRun,
                                                     true>(
                        b_grid_desc,
                        make_multi_index(0,
                                         n_block_data_idx_on_grid / (NWaves * NPerWmma),
                                         get_thread_local_1d_id() / 32,
                                         0,
                                         (get_thread_local_1d_id() % 32) / 16,
                                         get_thread_local_1d_id() % 16,
                                         0));

                return make_tuple(b_block_buf, b_blockwise_copy);
            }
        };

        auto a_block_buf      = a_block_trait()[I0];
        auto a_blockwise_copy = a_block_trait()[I1];

        auto b_block_buf      = b_block_trait()[I0];
        auto b_blockwise_copy = b_block_trait()[I1];
        /*******************************************************************************/
        // GEMM
        constexpr auto WmmaK = 16;
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

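        // Illustrative KPack (assumed K1 = 8): integer_least_multiple(8, 16)
        // = 16, so KPack spans exactly one WMMA K-step of WmmaK = 16.
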
        auto blockwise_gemm =
            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
                                                                  FloatA,
                                                                  FloatB,
                                                                  FloatAcc,
                                                                  decltype(a_block_desc_k0perblock_mperblock_k1),
                                                                  decltype(b_block_desc_k0perblock_nperblock_k1),
                                                                  MPerWmma,
                                                                  NPerWmma,
                                                                  MRepeat,
                                                                  NRepeat,
                                                                  KPack>{};
            BlockwiseGemmWMMA<BlockSize,
                              ADataType,
                              BDataType,
                              AccDataType,
                              decltype(MakeAWaveDescriptor(a_block_desc)),
                              decltype(MakeBWaveDescriptor(b_block_desc)),
                              MPerBlock,
                              NPerBlock,
                              KPerBlock,
                              MPerWmma,
                              NPerWmma,
                              MRepeat,
                              NRepeat,
                              KPack,
                              AEnableLds,
                              BEnableLds>{};

        // Prepare Register for C matrix
        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        /*******************************************************************************/
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
        // LDS allocation for A and B: be careful of alignment
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatA*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());

        /*******************************************************************************/
        // Shift Per SUB_K
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
        constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();

        // gridwise GEMM pipeline
        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
                                                          a_block_desc_k0perblock_mperblock_k1,
        const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
                                                          a_block_desc,
                                                          a_blockwise_copy,
                                                          a_grid_buf,
                                                          a_block_buf,
                                                          a_block_slice_copy_step,
                                                          b_grid_desc_k0_n_k1,
                                                          b_block_desc_k0perblock_nperblock_k1,
                                                          b_grid_desc,
                                                          b_block_desc,
                                                          b_blockwise_copy,
                                                          b_grid_buf,
                                                          b_block_buf,
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
                                                          K0BlockMainLoop);
                                                          KBlockMainLoop);
        /*******************************************************************************/
        // write out to C, implement shuffle
        {
            // C mapping in single thread.
            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            // This API provides all the dimensions (sizes) you need
            // C mapping in single block
            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
                blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

@@ -485,8 +846,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatCShuffle*>(p_shared),
                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
                static_cast<CShuffleDataType*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
                SharedMemTrait::c_shuffle_block_space_size);

            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
@@ -532,8 +893,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatCShuffle,
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   CShuffleDataType,
                                                   decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                                                   decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                                                   ck::tensor_operation::element_wise::PassThrough,
@@ -571,8 +932,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                             CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                FloatCShuffle,        // typename SrcData,
                FloatC,               // typename DstData,
                CShuffleDataType,     // typename SrcData,
                CDataType,            // typename DstData,
                decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
@@ -636,6 +997,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);

@@ -1333,4 +1333,139 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
    ElementwiseOperation element_op_;
};

// Specialized for WMMA
// A single Wave32 is composed of two rows
// Data exchange is allowed between these two rows
// This RowLane's Dst buffer will be filled from two Src buffers
// SrcA: from the specific thread buffer held by this RowLane on this row
// SrcB: from the specific thread buffer held by this RowLane on the other row
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t DstVectorDim,
          index_t DstScalarPerVector,
          uint32_t LowEightRowlaneIdx,
          uint32_t HighEightRowLaneIdx,
          bool IntraRowSwizzlePerm,
          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");

        static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
                      "wrong! Not divisible");
        ignore = src_idx;
    }

    template <typename SrcSliceOriginIdx,
              typename DstSliceOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf) const
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");

        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
                      "wrong! SliceOrigin need to known at compile-time");

        static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
                      "wrong! Buffer need to be StaticBuffer");

        // SrcDesc and src_slice_origin_idx are known at compile-time
        constexpr auto src_desc             = remove_cvref_t<SrcDesc>{};
        constexpr auto dst_desc             = remove_cvref_t<DstDesc>{};
        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
        constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});

        // scalar per access on each dim
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;

        static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
                      "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, num_access, 1>{}([&](auto idx_1d) {
            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);

            // copy data from src_buf into dst_vector
            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                // src_desc error, non constexpr, caused by merge transform
                constexpr index_t src_offset = src_desc.CalculateOffset(
                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

                constexpr index_t dst_offset = dst_desc.CalculateOffset(
                    dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

                SrcData v_this_row, v_theother_row;
                // int-typed temp value required by the intrinsic
                int temp = 0;

                // apply element-wise operation
                element_op_(v_this_row, src_buf[Number<src_offset>{}]);

                // apply intra-row permute.
                if constexpr(IntraRowSwizzlePerm)
                {
                    temp = __builtin_amdgcn_permlane16(
                        temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
                    v_this_row = type_convert_sp<SrcData>(temp);
                }

                // apply inter-row permute.
                temp = __builtin_amdgcn_permlanex16(temp,
                                                    type_convert_sp<int>(v_this_row),
                                                    LowEightRowlaneIdx,
                                                    HighEightRowLaneIdx,
                                                    1,
                                                    0);
                v_theother_row = type_convert_sp<SrcData>(temp);

                if(get_thread_local_1d_id() % 32 < 16)
                {
                    // apply type convert
                    dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
                    dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
                        type_convert_sp<DstData>(v_theother_row);
                }
                else
                {
                    // apply type convert
                    dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
                        type_convert_sp<DstData>(v_this_row);
                    dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_theother_row);
                }
            });
        });
    }
    ElementwiseOperation element_op_{};
};

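// Illustrative decoding of the intra-row swizzle selectors above (assumption:
// each nibble of a permlane16 selector names a source lane, lowest nibble
// first): 0xb3a29180 maps lanes 0..7 to sources 0,8,1,9,2,10,3,11 and
// 0xf7e6d5c4 maps lanes 8..15 to sources 4,12,5,13,6,14,7,15, i.e. an
// interleave of the two eight-lane halves of a 16-lane row.
static_assert((0xb3a29180u & 0xF) == 0 && ((0xb3a29180u >> 4) & 0xF) == 8,
              "lowest two selector nibbles: lane 0 <- 0, lane 1 <- 8");
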
} // namespace ck

[File diff suppressed because it is too large]
@@ -89,6 +89,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 4;
    static constexpr index_t acc_pack_number = 1;
    // * Thread mapping inside wave: num_thread_per_subgroups is always along the N direction
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

@@ -100,7 +101,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
    // * num_acc_vgprs_per_wave along M direction
    // * num_subgroups along M direction
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

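// Worked example of the accumulator-VGPR formula above (wave32, fp32 acc
// assumed): 16 * 16 * 4 * 1 / 32 / 4 = 8 VGPRs per wave; for packed 16-bit
// accumulators (acc_data_size = 2, acc_pack_number = 2) the product is
// unchanged, so the count stays 8.
static_assert(16 * 16 * 4 * 1 / 32 / 4 == 16 * 16 * 2 * 2 / 32 / 4,
              "illustrative acc VGPR count is invariant under packing");
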
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
@@ -129,6 +130,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 4;
    static constexpr index_t acc_pack_number = 1;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent property
@@ -136,7 +138,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
@@ -153,7 +155,6 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
    }
};

#ifdef CK_UNPACKED_ACC_DESC_LOGIC
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
                 WaveSize,
@@ -166,6 +167,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 2;
    static constexpr index_t acc_pack_number = 2;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent property
@@ -173,28 +175,22 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma,
              index_t NPerWmma,
              index_t Opsel,
              class FloatA,
              class FloatB,
              class FloatC>
    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
            intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
            intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
        }
    }
};

template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
                 WaveSize,
@@ -207,6 +203,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 2;
    static constexpr index_t acc_pack_number = 2;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent property
@@ -214,7 +211,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma,
@@ -227,17 +224,15 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
            intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
            intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
        }
    }
};

#endif

template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
                 WaveSize,
@@ -250,6 +245,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 4;
    static constexpr index_t acc_pack_number = 1;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent property
@@ -257,7 +253,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

template <index_t MPerWmma,
@@ -346,7 +342,7 @@ struct WmmaSelector
        static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");

        static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
                              selected_wmma.acc_data_size ==
                              selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
                          selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
                      "WRONG! Invalid Number of Accumulator Register");
    }
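// Illustrative consistency check mirroring the static_assert above (wave32,
// fp32 acc assumed): 32 * 8 * 4 * 1 == 16 * 16 * 4 == 1024.
static_assert(32 * 8 * 4 * 1 == 16 * 16 * 4, "illustrative accumulator register balance");
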
@@ -358,7 +354,8 @@ template <typename src_type_a,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t KPack,
          bool TransposeC = false>
          bool TransposeC = false,
          bool AssemblyBackend = false>
struct WmmaGemm
{
    static constexpr auto I0 = Number<0>{};
@@ -369,14 +366,14 @@ struct WmmaGemm
    static constexpr auto I5 = Number<5>{};

    using CIndex   = MultiIndex<2>;
    using CIndex4D = MultiIndex<4>;
    using CIndex3D = MultiIndex<3>;

    __host__ __device__ constexpr WmmaGemm()
    {
        static_assert(NPerWmma == 16 && MPerWmma == 16,
                      "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");

        static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
        static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma");
    }

    // WMMA output supporting C = A * B
@@ -421,9 +418,49 @@ struct WmmaGemm
                       Sequence<5>{}));
    }

    // Transposed WMMA Output C' = B' * A'
    template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
    __host__ __device__ static constexpr auto
    MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
        const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
    {
        const auto MBlockxRepeat =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
        const auto NBlockxRepeat =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
        const auto MWave =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
        const auto NWave =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);

        return transform_tensor_descriptor(
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
            make_tuple(
                make_pass_through_transform(MBlockxRepeat),
                make_pass_through_transform(MWave),
                make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{}),
                make_pass_through_transform(NBlockxRepeat),
                make_pass_through_transform(NWave),
                make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
                                                  Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5, 6>{}));
    }

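    // Illustrative split performed by the unmerge above (wave32 assumed): the
    // trailing per-WMMA extent 16 becomes num_subgroups * num_acc_vgprs_per_wave
    // = 2 * 8, yielding the NSubGroup and NAccVgprs dimensions of the output view.
    static_assert(2 * 8 == 16, "illustrative subgroup x acc-VGPR factorization");
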
    __device__ static constexpr index_t GetRegSizePerWmma()
    {
        return wmma_instr.num_acc_vgprs_per_wave;
        return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
    }

    __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
@@ -449,14 +486,16 @@ struct WmmaGemm
            ,
            "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
            "(int8, int32) or (int4, int32)!");
        if constexpr(!TransposeC)
        {
            wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
        }
        else
        {
            wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
        }
        static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
            if constexpr(!TransposeC)
            {
                wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
            }
            else
            {
                wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
            }
        });
    }

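    // Illustrative trip count for the static_for above (assumed KPack = 32):
    // KPack / k_per_wmma = 32 / 16 = 2 WMMA issues per call, each consuming
    // one k_per_wmma slice of the A/B wave buffers.
    static_assert(32 / 16 == 2, "illustrative WMMA issues per KPack");
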
    __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
@@ -477,12 +516,12 @@ struct WmmaGemm

    __host__ __device__ static auto CalculateAThreadOriginDataIndex()
    {
        return GetSwizzledLaneIdLow();
        return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
    }

    __host__ __device__ static auto CalculateBThreadOriginDataIndex()
    {
        return GetLaneIdUnderSubGroup();
        return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
    }

    __device__ static CIndex GetBeginOfThreadBlk()
@@ -493,6 +532,14 @@ struct WmmaGemm
        return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
    }

    __device__ static CIndex3D GetBeginOfThreadBlk3D()
    {
        index_t n_offset = GetLaneIdUnderSubGroup();
        index_t m_offset = GetSubGroupId();

        return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0};
    }

    static constexpr auto wmma =
        WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
    static constexpr auto wmma_instr = wmma.selected_wmma;
@@ -500,7 +547,10 @@ struct WmmaGemm
    __host__ __device__ static constexpr auto
    GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
    {
        return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
        return make_tuple(I1,
                          I1,
                          Number<wmma_instr.num_acc_vgprs_per_wave>{},
                          Number<wmma_instr.acc_pack_number>{});
    }
};
