mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Navi3 rel (#1176)
* wmma_op + unit test * add arch limitation to wmma test * change arch limitation * Refactor + Add all type unit test(int4 compile failed) * Add f32_16x16x16_bf16 unit test * tempsave * tempsave * tempsave * runtime bug, cannot find symbol * workaround for incorrect HIP warpSize return value * debugging * tempsave * Correctness OK, waiting for optimization * Tidy up + format * temp save * temp save, reproduce the v_bfi_b32 issue * add inline asm for wmmaop test * tidy up * clean some debug purpose code * discard some codes * clang format * clang format * compiler issue fixed + increase tile size * navi3x_multipleD+example * temp save * workable * batchedgemm[OK], groupconv[debug] * groupconv: Sanity check[OK], Performance[Bad] * navi3x_groupconv_need_optimization * create necessary files * save progress * Add Inter-Row thread transfer * save progress * save debugging progress * sanity check pass * fix a host tensor bug and clean up flash-attn code * format * cancel unnecessary change * cancel unnecessary change * cancel unnecessary change * temp save, add asm backend flag to amd_wmma * Mat-A LDS Bypass sanity pass * temp save * gemm sanity fix * Porting new blockwise gemm to flash attention * Example branch provide to compiler team * tempsave * Fix a bug * batched gemm ported * conv A-skip lds ported * Skip B-Lds real gemm * Skip B Lds Gemm + MulD * batched gemm, conv, skip b lds * format * Attn, skip b lds * Change GridwiseOp nam * fix a typo caused bug * Skip A_Lds sanity pass, Skip B_Lds scratch occured * Bug found, intra-row permute off caused * bug found * a fix * disable buffer load due to incorrect 3rd dword * update fmha config, no scratch generated * update 3rd dword * fmha config update * FMHA, add support to gfx1101/gfx1102 * Merge origin dev (#2) * [Navi3x] Fix Gridwise_multiple_d operation (#649) * Add CMake Option "USE_OPT_NAVI3X" * fix bug * standardize docs (#655) * Separate bibtex requirement from rocm-docs-core (#656) * separate bibtex requirement from rocm-docs-core * point requirements to source rocm-docs-core repo * Add CMake Option "USE_OPT_NAVI3X" (#647) * Add CMake Option "USE_OPT_NAVI3X" * remove navi3x opt compile option from cmake script * Conv + quantization + tanh (#645) * Rename file. Prepare to support another activation * Add comment for quantization * Extract out_elementop * Add tanh example * Add conv + bias + tanh quantization instance * Add missing parameter * Refine cmake * Add external api and client example * Extract variable in example * Fix the comment --------- Co-authored-by: zjing14 <zhangjing14@gmail.com> * Add a denorm test fix (#603) * Add type_convert implementations for bf16 * Add the fix for conv_fwd * Add the fix for conv_bwd_data * Add the fix for conv_bwd_weight * Format * Format * Another format * Add a macro to use workaround on MI200 only * Format --------- Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> * simplify karg in device/grid of split-k op (#644) * simplify karg in device/grid split-k op * fix mk_kn_mn instances * add more instances * use name from tensor layout * fix 3rd dword of buffer source descriptor (#659) * add fp64 instances (#658) Co-authored-by: root <root@ctr-ubbsmc15.amd.com> * Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665) This reverts commitbb5530af91. * Groupnorm + swish external api (#668) * Rename to proper naming * Add example of groupnorm + swish * Extract duplicate code in example * Add groupnorm + swish instances * Ractor instance generation, split into multiple cpp file * Add external api and client example * Refine profiler message * Use ck math version of exp * Refine problem size in example * Add host version of exp * add a marco to turn on/off denorm fix (off by default) (#673) * add a marco to turn off denorm fix by default * expose the marco --------- Co-authored-by: root <root@ctr-ubbsmc15.amd.com> * fixed quant example (#672) Co-authored-by: root <root@ctr-ubbsmc15.amd.com> * Add dependabot config and pin rocm-docs-core (#663) * [gtest] suppress unsafe buffer warn (#670) ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912 * Add memory index guard in wmma device ops (#667) * Add more macros to turn on/off denorm fix (#678) Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> * Fix a typo (#676) * Add (#677) * Allow using ROCm release candidate compilers. (#679) * enable use of rocm5.5 release candidate 4 * upgrade to ROCM5.5 RC5 * try fix the PUB_KEY error, remove the cmake-data package * upgrade to latest cmake version * use private dockerhub repo for rocm5.5 rc5 * add missing bracket * add vector load check * solve conflicts --------- Co-authored-by: Sam Wu <sjwu@ualberta.ca> Co-authored-by: Sam Wu <sam.wu2@amd.com> Co-authored-by: rocking5566 <ChunYu.Lai@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: root <root@ctr-ubbsmc15.amd.com> Co-authored-by: Jun Liu <Liu.Jun@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> * Disable SkipLDS & Align AIT api (#3) * fix layernorm, reduction Ops (#4) * [Navi3x] Fix Gridwise_multiple_d operation (#649) * Add CMake Option "USE_OPT_NAVI3X" * fix bug * standardize docs (#655) * Separate bibtex requirement from rocm-docs-core (#656) * separate bibtex requirement from rocm-docs-core * point requirements to source rocm-docs-core repo * Add CMake Option "USE_OPT_NAVI3X" (#647) * Add CMake Option "USE_OPT_NAVI3X" * remove navi3x opt compile option from cmake script * Conv + quantization + tanh (#645) * Rename file. Prepare to support another activation * Add comment for quantization * Extract out_elementop * Add tanh example * Add conv + bias + tanh quantization instance * Add missing parameter * Refine cmake * Add external api and client example * Extract variable in example * Fix the comment --------- Co-authored-by: zjing14 <zhangjing14@gmail.com> * Add a denorm test fix (#603) * Add type_convert implementations for bf16 * Add the fix for conv_fwd * Add the fix for conv_bwd_data * Add the fix for conv_bwd_weight * Format * Format * Another format * Add a macro to use workaround on MI200 only * Format --------- Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> * simplify karg in device/grid of split-k op (#644) * simplify karg in device/grid split-k op * fix mk_kn_mn instances * add more instances * use name from tensor layout * fix 3rd dword of buffer source descriptor (#659) * add fp64 instances (#658) Co-authored-by: root <root@ctr-ubbsmc15.amd.com> * Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665) This reverts commitbb5530af91. * Groupnorm + swish external api (#668) * Rename to proper naming * Add example of groupnorm + swish * Extract duplicate code in example * Add groupnorm + swish instances * Ractor instance generation, split into multiple cpp file * Add external api and client example * Refine profiler message * Use ck math version of exp * Refine problem size in example * Add host version of exp * add a marco to turn on/off denorm fix (off by default) (#673) * add a marco to turn off denorm fix by default * expose the marco --------- Co-authored-by: root <root@ctr-ubbsmc15.amd.com> * fixed quant example (#672) Co-authored-by: root <root@ctr-ubbsmc15.amd.com> * Add dependabot config and pin rocm-docs-core (#663) * [gtest] suppress unsafe buffer warn (#670) ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912 * Add memory index guard in wmma device ops (#667) * Add more macros to turn on/off denorm fix (#678) Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> * Fix a typo (#676) * Add (#677) * Allow using ROCm release candidate compilers. (#679) * enable use of rocm5.5 release candidate 4 * upgrade to ROCM5.5 RC5 * try fix the PUB_KEY error, remove the cmake-data package * upgrade to latest cmake version * use private dockerhub repo for rocm5.5 rc5 * add missing bracket * Disable SkipLDS & Align AIT api * Update dependabot config (#682) Co-authored-by: samjwu <samjwu@users.noreply.github.com> * update attn api * solve type_convert bug + enable --------- Co-authored-by: Sam Wu <sjwu@ualberta.ca> Co-authored-by: Sam Wu <sam.wu2@amd.com> Co-authored-by: rocking5566 <ChunYu.Lai@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com> Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: root <root@ctr-ubbsmc15.amd.com> Co-authored-by: Jun Liu <Liu.Jun@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: samjwu <samjwu@users.noreply.github.com> Co-authored-by: haocwang <Haocong.WANG@amd.com> * fix typo * Fix attention with causal mask * multiple fix, try ait compile * Add A/B not use LDS pipeline * Clang format, Add gfx1101, gfx1102 support of FMHA example * cancel change of format script * 1. Enable 2-stage global Prefetch ( May cause VGPR spilling) 2. Enable FP16 accumulator blockwise_gemm * clang-format * 1. change blockwise gemm loopover direction from kmn to mnk ( ~1% improvement) 2. change kernel timing mode to 50 warmup + 50 timed repeat * Update low level abstration of blockwise gemm wmma * (2/5) bilinear gemm pass, perf bug: skip a lds has lower performance than skip b lds * (3/5) batched gemm pass, perf bug: skip a lds has lower performance than skip b lds * (4/5) grouped conv pass * (5/5) attention pass, todo: debug lds perf bug * AIT Attention API refactor (#8) * sanity pass * sanity pass 2 * confirm significant performance regression. * turn on all instances * turn off instance format * Fix bug & tunning & format * DML meta, self_attn+cross_attn * sanity pass * remove useless flag * update tile and problem size used in AIT attention * bug fix in grouped conv supporting check * deprecate inline asm wmma * Bug fix: double lds skip * clang-format * Fix errors in 1. example, fmha 2. gridwise pipeline 3. deviceop, fmha, change some containers from vector to array * part2 of previous commit * clang format * API fix of gridwisegemmpipeline * separate array base and vector base attention tensor transformation * fix gemm * clang format * add gemm fp16 instances * Temp save * fpAintB kernel compile pass * Sanity pass. * Temp save * debug code enabled * Fp16AInt8B_GEMM sanity * MQA implementation * GQA-4 example * tempsave * Compile pass * New implementation of fp16Aint8B Gemm, Acheieve similar math throughput with native fp16 Gemm * format * Todo: fix gemm_bilinear_wmma instances compilation bug * Solve a bug when K1=16 * remove unnecessary changes * Remove tensor layout limitation to LDS usage in tesnor contraction * update self-attention and cross-attention * fix a typo of name * Add arch limiter for fp8 gemm * enable fp8 gemm_xdl for all gfx9 targets * temporarily disable gemm_xdl_fp16_fp8 on MI100/200 * fix the cmake logic for gemm_xdl_fp16_fp8 * re-enable the gemm_xdl_fp16_fp8 on MI100/200 --------- Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: Sam Wu <sjwu@ualberta.ca> Co-authored-by: Sam Wu <sam.wu2@amd.com> Co-authored-by: rocking5566 <ChunYu.Lai@amd.com> Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: root <root@ctr-ubbsmc15.amd.com> Co-authored-by: Jun Liu <Liu.Jun@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: samjwu <samjwu@users.noreply.github.com> Co-authored-by: haocwang <Haocong.WANG@amd.com> Co-authored-by: illsilin <Illia.Silin@amd.com>
This commit is contained in:
@@ -116,7 +116,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemm0KPrefetchStage>;
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemm0KPrefetchStage, true, true>;
|
||||
|
||||
// ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
|
||||
static constexpr auto MakeD0sGridPointer()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1046
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
Normal file
1046
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -17,18 +17,21 @@ enum struct PipelineVersion
|
||||
v2,
|
||||
// v3 is only used in the Stream-K implementation.
|
||||
v4,
|
||||
weight_only,
|
||||
};
|
||||
|
||||
template <PipelineVersion PipelineVer,
|
||||
index_t NumPrefetch = 1,
|
||||
LoopScheduler LoopSched = LoopScheduler::Default>
|
||||
LoopScheduler LoopSched = LoopScheduler::Default,
|
||||
bool AEnableLds = true,
|
||||
bool BEnableLds = true>
|
||||
constexpr auto GridwiseGemmPipeline_Selector()
|
||||
{
|
||||
if constexpr(PipelineVer == PipelineVersion::v1)
|
||||
{
|
||||
if constexpr(LoopSched == LoopScheduler::Default)
|
||||
{
|
||||
return GridwiseGemmPipeline_v1<NumPrefetch>{};
|
||||
return GridwiseGemmPipeline_v1<NumPrefetch, AEnableLds, BEnableLds>{};
|
||||
}
|
||||
else if constexpr(LoopSched == LoopScheduler::Interwave)
|
||||
{
|
||||
@@ -43,6 +46,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
|
||||
{
|
||||
return GridwiseGemmPipeline_v4<NumPrefetch>{};
|
||||
}
|
||||
else if constexpr(PipelineVer == PipelineVersion::weight_only)
|
||||
{
|
||||
return GridwiseGemmPipeline_v1_WeightOnly<NumPrefetch, AEnableLds, BEnableLds>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
|
||||
|
||||
@@ -9,12 +9,12 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t NumPrefetch>
|
||||
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
|
||||
struct GridwiseGemmPipeline_v1;
|
||||
|
||||
// 1-stage prefetch
|
||||
template <>
|
||||
struct GridwiseGemmPipeline_v1<1>
|
||||
struct GridwiseGemmPipeline_v1<1, true, true>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -108,7 +108,7 @@ struct GridwiseGemmPipeline_v1<1>
|
||||
|
||||
// 2-stage prefetch
|
||||
template <>
|
||||
struct GridwiseGemmPipeline_v1<2>
|
||||
struct GridwiseGemmPipeline_v1<2, true, true>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -254,6 +254,406 @@ struct GridwiseGemmPipeline_v1<2>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GridwiseGemmPipeline_v1<1, false, true>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
|
||||
{
|
||||
return num_loop > 1;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
__device__ static void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
|
||||
auto a_block_buf_switch = a_block_buf;
|
||||
|
||||
// preload data into LDS
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
a_blockwise_copy.Run(
|
||||
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.Run(
|
||||
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
a_block_buf = a_block_buf_switch;
|
||||
++i;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GridwiseGemmPipeline_v1<1, true, false>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
|
||||
{
|
||||
return num_loop > 1;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
__device__ static void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
|
||||
auto b_block_buf_switch = b_block_buf;
|
||||
|
||||
// preload data into LDS
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.Run(
|
||||
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
b_blockwise_copy.Run(
|
||||
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
b_block_buf = b_block_buf_switch;
|
||||
++i;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GridwiseGemmPipeline_v1<1, false, false>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
|
||||
{
|
||||
return num_loop > 1;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
__device__ static void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
|
||||
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
|
||||
auto b_block_buf_switch = b_block_buf;
|
||||
auto a_block_buf_switch = a_block_buf;
|
||||
|
||||
// preload data into LDS
|
||||
a_blockwise_copy.Run(
|
||||
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
|
||||
b_blockwise_copy.Run(
|
||||
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.Run(
|
||||
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
|
||||
b_blockwise_copy.Run(
|
||||
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_block_buf = a_block_buf_switch;
|
||||
b_block_buf = b_block_buf_switch;
|
||||
++i;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
|
||||
struct GridwiseGemmPipeline_v1_WeightOnly;
|
||||
|
||||
template <>
|
||||
struct GridwiseGemmPipeline_v1_WeightOnly<1, true, true>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
|
||||
{
|
||||
return num_loop > 1;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename ScaleGridDesc,
|
||||
typename ScaleGridBuffer,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
__device__ static void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const ScaleGridDesc& scale_grid_desc,
|
||||
const ScaleGridBuffer& scale_grid_buf,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
// Global Prefetch Stage 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
// Scale read once
|
||||
b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
// Dequantization fused in blockwise_copy
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
++i;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t NumPrefetch>
|
||||
struct GridwiseGemmPipelineInterwave_v1;
|
||||
|
||||
@@ -349,7 +749,7 @@ struct GridwiseGemmPipelineInterwave_v1<1>
|
||||
|
||||
// Note: 2 stage prefetch not optimized for inter-wave loop scheduler
|
||||
template <>
|
||||
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
|
||||
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true>
|
||||
{
|
||||
};
|
||||
|
||||
@@ -359,7 +759,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector()
|
||||
{
|
||||
if constexpr(LoopSched == LoopScheduler::Default)
|
||||
{
|
||||
return GridwiseGemmPipeline_v1<NumPrefetch>{};
|
||||
return GridwiseGemmPipeline_v1<NumPrefetch, true, true>{};
|
||||
}
|
||||
else if constexpr(LoopSched == LoopScheduler::Interwave)
|
||||
{
|
||||
|
||||
@@ -93,7 +93,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
|
||||
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage, true, true>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
|
||||
@@ -18,11 +18,11 @@
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AGridDesc,
|
||||
typename BGridDesc,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
@@ -33,31 +33,27 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_wmma(
|
||||
const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
// const
|
||||
// CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
|
||||
// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
CDataType* __restrict__ p_c_grid,
|
||||
const AGridDesc a_grid_desc,
|
||||
const BGridDesc b_grid_desc,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
a_grid_desc,
|
||||
b_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
@@ -67,8 +63,8 @@ __global__ void
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_grid_desc_k0_m_k1;
|
||||
ignore = b_grid_desc_k0_n_k1;
|
||||
ignore = a_grid_desc;
|
||||
ignore = b_grid_desc;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
@@ -78,21 +74,21 @@ __global__ void
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatAcc,
|
||||
typename FloatCShuffle,
|
||||
typename FloatC,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename CDataType,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename AGridDesc,
|
||||
typename BGridDesc,
|
||||
typename CGridDesc_M_N,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t K0PerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t K1Value,
|
||||
@@ -105,6 +101,7 @@ template <index_t BlockSize,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool AEnableLds,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
@@ -113,6 +110,7 @@ template <index_t BlockSize,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool BEnableLds,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
@@ -121,7 +119,7 @@ template <index_t BlockSize,
|
||||
index_t NumGemmKPrefetchStage = 1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
struct GridwiseGemm_Wmma
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -132,103 +130,277 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
// FIX ME: To be deprecated
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
using GridwiseGemmPipe =
|
||||
remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
|
||||
NumGemmKPrefetchStage,
|
||||
LoopSched,
|
||||
AEnableLds,
|
||||
BEnableLds>())>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
|
||||
// Describe how data store to (LDS/VGPR) buffer from Global memory
|
||||
__host__ __device__ static constexpr auto MakeABlockDescriptor()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
constexpr auto a_block_desc = [&]() {
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
// K0->M->K1 Per Block
|
||||
constexpr auto K0PerBlock = KPerBlock / K1;
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
||||
constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
||||
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KWmmaPerblock>{},
|
||||
Number<MRepeat>{},
|
||||
I1,
|
||||
Number<K0PerWmma>{},
|
||||
I1,
|
||||
I1,
|
||||
K1),
|
||||
make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * K1,
|
||||
Number<K0PerWmma>{} * K1,
|
||||
Number<K0PerWmma>{} * K1,
|
||||
K1,
|
||||
K1,
|
||||
K1,
|
||||
I1));
|
||||
}
|
||||
}();
|
||||
|
||||
return a_block_desc_k0perblock_mperblock_k1;
|
||||
return a_block_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
|
||||
__host__ __device__ static constexpr auto MakeBBlockDescriptor()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
constexpr auto b_block_desc = [&]() {
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
// K0->N->K1 Per Block
|
||||
constexpr auto K0PerBlock = KPerBlock / K1;
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
||||
constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
||||
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KWmmaPerblock>{},
|
||||
Number<NRepeat>{},
|
||||
I1,
|
||||
Number<K0PerWmma>{},
|
||||
I1,
|
||||
I1,
|
||||
K1),
|
||||
make_tuple(Number<NRepeat>{} * Number<K0PerWmma>{} * K1,
|
||||
Number<K0PerWmma>{} * K1,
|
||||
Number<K0PerWmma>{} * K1,
|
||||
K1,
|
||||
K1,
|
||||
K1,
|
||||
I1));
|
||||
}
|
||||
}();
|
||||
|
||||
return b_block_desc_k0perblock_nperblock_k1;
|
||||
return b_block_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
|
||||
{
|
||||
constexpr auto a_block_copy_step = [&]() {
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
constexpr auto K0PerBlock = KPerBlock / K1;
|
||||
|
||||
return make_multi_index(K0PerBlock, 0, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
|
||||
|
||||
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
|
||||
}
|
||||
}();
|
||||
|
||||
return a_block_copy_step;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
|
||||
{
|
||||
constexpr auto b_block_copy_step = [&]() {
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
constexpr auto K0PerBlock = KPerBlock / K1;
|
||||
|
||||
return make_multi_index(K0PerBlock, 0, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
|
||||
|
||||
return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
|
||||
}
|
||||
}();
|
||||
|
||||
return b_block_copy_step;
|
||||
}
|
||||
|
||||
// Describe how data read from (LDS/VGPR) buffer
|
||||
template <typename ABlockDesc_>
|
||||
__host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
|
||||
{
|
||||
|
||||
constexpr auto a_wave_desc = [&]() {
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
constexpr auto A_KRow = I1;
|
||||
return transform_tensor_descriptor(
|
||||
ABlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
|
||||
make_pass_through_transform(Number<A_K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
|
||||
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
|
||||
constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
|
||||
|
||||
// Err: merge transform cause non-constexpr issue
|
||||
|
||||
// return transform_tensor_descriptor(
|
||||
// ABlockDesc_{},
|
||||
// make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
|
||||
// make_pass_through_transform(Number<MRepeat>{}),
|
||||
// make_pass_through_transform(I1),
|
||||
// make_pass_through_transform(I1),
|
||||
// make_pass_through_transform(Number<A_K1>{})),
|
||||
// make_tuple(Sequence<0, 3>{},
|
||||
// Sequence<1>{},
|
||||
// Sequence<2>{},
|
||||
// Sequence<4>{},
|
||||
// Sequence<5>{}),
|
||||
// make_tuple(
|
||||
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
|
||||
// Sequence<4>{}));
|
||||
|
||||
// Workaround, Freeze transform
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
|
||||
Number<MRepeat>{},
|
||||
I1,
|
||||
Number<A_KRow>{},
|
||||
I1,
|
||||
Number<A_K1>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
return a_wave_desc;
|
||||
}
|
||||
|
||||
template <typename BBlockDesc_>
|
||||
__host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
|
||||
{
|
||||
constexpr auto b_wave_desc = [&]() {
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
constexpr auto B_KRow = I1;
|
||||
return transform_tensor_descriptor(
|
||||
BBlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
|
||||
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
|
||||
constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
|
||||
|
||||
// Workaround, Freeze transform
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
|
||||
Number<NRepeat>{},
|
||||
I1,
|
||||
Number<B_KRow>{},
|
||||
I1,
|
||||
Number<B_K1>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
return b_wave_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
// *Caution Here repeat is shuffle repeat
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
|
||||
{
|
||||
constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
|
||||
constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
|
||||
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
|
||||
I1,
|
||||
Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
|
||||
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
|
||||
|
||||
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_desc_k0perblock_mperblock_k1 =
|
||||
GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
|
||||
|
||||
constexpr auto b_block_desc_k0perblock_nperblock_k1 =
|
||||
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
|
||||
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
|
||||
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return (a_block_space_size_aligned * sizeof(FloatA) +
|
||||
b_block_space_size_aligned * sizeof(FloatB));
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename Block2CTileMap>
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
|
||||
"wrong! K1 need to be known at compile-time");
|
||||
@@ -237,23 +409,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
(NPerBlock % (NRepeat * NPerWmma)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
const auto GetAProblemsizeMK = [&]() {
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
return make_tuple(a_grid_desc.GetLength(I1),
|
||||
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
|
||||
a_grid_desc.GetLength(I5),
|
||||
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
|
||||
a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
|
||||
}
|
||||
};
|
||||
|
||||
const auto GetBProblemsizeNK = [&]() {
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
return make_tuple(b_grid_desc.GetLength(I1),
|
||||
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
|
||||
b_grid_desc.GetLength(I5),
|
||||
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
|
||||
b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
|
||||
}
|
||||
};
|
||||
|
||||
const auto M = GetAProblemsizeMK()[I0];
|
||||
const auto N = GetBProblemsizeNK()[I0];
|
||||
const auto K = GetAProblemsizeMK()[I1];
|
||||
|
||||
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
|
||||
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
|
||||
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
|
||||
K == GetBProblemsizeNK()[I1]))
|
||||
{
|
||||
printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n",
|
||||
GetAProblemsizeMK()[I0],
|
||||
GetAProblemsizeMK()[I1],
|
||||
GetBProblemsizeNK()[I0],
|
||||
GetBProblemsizeNK()[I1],
|
||||
c_grid_desc_m_n.GetLength(I0),
|
||||
c_grid_desc_m_n.GetLength(I1));
|
||||
printf("GridwiseOp err: ProblemSize check");
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
|
||||
{
|
||||
printf("GridwiseOp err: ProblemSize division");
|
||||
return false;
|
||||
}
|
||||
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = K0 / K0PerBlock;
|
||||
const auto num_k_loop = K / KPerBlock;
|
||||
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
printf("GridwiseOp err: Pipeline not support this k_loop");
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -265,8 +480,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB &&
|
||||
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB))
|
||||
if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -275,7 +490,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / (K0PerBlock * K1);
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
|
||||
}
|
||||
@@ -313,13 +528,44 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
||||
|
||||
struct SharedMemTrait
|
||||
{
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
|
||||
static constexpr auto max_lds_align = K1;
|
||||
|
||||
static constexpr auto a_block_space_size_aligned =
|
||||
AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
|
||||
max_lds_align)
|
||||
: 0;
|
||||
static constexpr auto b_block_space_size_aligned =
|
||||
BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
|
||||
max_lds_align)
|
||||
: 0;
|
||||
|
||||
static constexpr auto a_block_space_offset = 0;
|
||||
static constexpr auto b_block_space_offset = a_block_space_size_aligned;
|
||||
|
||||
// LDS allocation for C shuffle in LDS
|
||||
static constexpr auto c_shuffle_block_space_size =
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
|
||||
.GetElementSpaceSize();
|
||||
|
||||
static constexpr auto c_shuffle_block_space_offset = 0;
|
||||
|
||||
static constexpr auto lds_size =
|
||||
math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
|
||||
a_block_space_size_aligned * sizeof(ADataType) +
|
||||
b_block_space_size_aligned * sizeof(BDataType));
|
||||
};
|
||||
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
CDataType* __restrict__ p_c_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const AGridDesc& a_grid_desc,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
@@ -331,9 +577,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
/*******************************************************************************/
|
||||
// Memory buffer zone.
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
|
||||
p_a_grid, a_grid_desc.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
|
||||
p_b_grid, b_grid_desc.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
@@ -351,24 +597,41 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
||||
|
||||
/*******************************************************************************/
|
||||
// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
constexpr auto max_lds_align = K1;
|
||||
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
|
||||
constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
|
||||
// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy
|
||||
const auto K = [&](){
|
||||
if constexpr(AEnableLds){
|
||||
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
|
||||
}
|
||||
else{
|
||||
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3)
|
||||
* a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr auto a_block_desc = MakeABlockDescriptor();
|
||||
constexpr auto b_block_desc = MakeBBlockDescriptor();
|
||||
|
||||
auto a_block_trait = [&](){
|
||||
// A matrix blockwise copy
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
constexpr auto K0PerBlock = KPerBlock/ K1;
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared),
|
||||
SharedMemTrait::a_block_space_size_aligned);
|
||||
|
||||
auto a_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
/* typename SrcElementwiseOperation, */ AElementwiseOperation,
|
||||
/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
|
||||
/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
|
||||
/* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
|
||||
/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
|
||||
/* typename SrcData, */ FloatA,
|
||||
/* typename DstData, */ FloatA,
|
||||
/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
|
||||
/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
|
||||
/* typename SrcData, */ ADataType,
|
||||
/* typename DstData, */ ADataType,
|
||||
/* typename SrcDesc, */ decltype(a_grid_desc),
|
||||
/* typename DstDesc, */ decltype(a_block_desc),
|
||||
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
|
||||
/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
|
||||
/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
|
||||
@@ -378,99 +641,197 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
/* index_t SrcScalarStrideInVector, */ 1,
|
||||
/* index_t DstScalarStrideInVector, */ 1,
|
||||
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
|
||||
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
|
||||
a_grid_desc_k0_m_k1,
|
||||
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
|
||||
NumGemmKPrefetchStage>(
|
||||
a_grid_desc,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_k0perblock_mperblock_k1,
|
||||
a_block_desc,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatB,
|
||||
FloatB,
|
||||
decltype(b_grid_desc_k0_n_k1),
|
||||
decltype(b_block_desc_k0perblock_nperblock_k1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_grid_desc_k0_n_k1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_k0perblock_nperblock_k1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
return make_tuple(a_block_buf, a_blockwise_copy);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Thread-wise copy
|
||||
// KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
|
||||
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
|
||||
constexpr auto K0PerWmma = WmmaK/2/K1Value;
|
||||
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
|
||||
a_block_desc.GetElementSpaceSize());
|
||||
|
||||
// Limitation: NumDim of Src and Dst descriptor should be identical
|
||||
auto a_blockwise_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<ADataType,
|
||||
ADataType,
|
||||
decltype(a_grid_desc),
|
||||
decltype(a_block_desc),
|
||||
Sequence<Number<KWmmaPerBlock>{},
|
||||
Number<MRepeat>{},
|
||||
I1,
|
||||
Number<K0PerWmma>{},
|
||||
I1,
|
||||
I1,
|
||||
Number<K1Value>{}>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
a_grid_desc,
|
||||
make_multi_index(0,
|
||||
m_block_data_idx_on_grid/(MWaves * MPerWmma),
|
||||
get_thread_local_1d_id() / 32,
|
||||
0,
|
||||
(get_thread_local_1d_id() % 32 )/ 16,
|
||||
get_thread_local_1d_id() % 16,
|
||||
0));
|
||||
|
||||
return make_tuple(a_block_buf, a_blockwise_copy);
|
||||
}
|
||||
};
|
||||
|
||||
auto b_block_trait = [&](){
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
constexpr auto K0PerBlock = KPerBlock/ K1;
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
|
||||
SharedMemTrait::b_block_space_size_aligned);
|
||||
|
||||
auto b_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc),
|
||||
decltype(b_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
NumGemmKPrefetchStage>(
|
||||
b_grid_desc,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
return make_tuple(b_block_buf, b_blockwise_copy);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Thread-wise copy
|
||||
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
|
||||
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
|
||||
constexpr auto K0PerWmma = WmmaK/2/K1Value;
|
||||
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_block_desc.GetElementSpaceSize());
|
||||
|
||||
// Limitation: NumDim of Src and Dst descriptor should be identical
|
||||
auto b_blockwise_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc),
|
||||
decltype(b_block_desc),
|
||||
Sequence<Number<KWmmaPerBlock>{},
|
||||
Number<NRepeat>{},
|
||||
I1,
|
||||
Number<K0PerWmma>{},
|
||||
I1,
|
||||
I1,
|
||||
Number<K1Value>{}>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_grid_desc,
|
||||
make_multi_index(0,
|
||||
n_block_data_idx_on_grid/(NWaves * NPerWmma),
|
||||
get_thread_local_1d_id() / 32,
|
||||
0,
|
||||
(get_thread_local_1d_id() % 32 )/ 16,
|
||||
get_thread_local_1d_id() % 16,
|
||||
0));
|
||||
|
||||
return make_tuple(b_block_buf, b_blockwise_copy);
|
||||
}
|
||||
};
|
||||
|
||||
auto a_block_buf = a_block_trait()[I0];
|
||||
auto a_blockwise_copy = a_block_trait()[I1];
|
||||
|
||||
auto b_block_buf = b_block_trait()[I0];
|
||||
auto b_blockwise_copy = b_block_trait()[I1];
|
||||
/*******************************************************************************/
|
||||
// GEMM
|
||||
constexpr auto WmmaK = 16;
|
||||
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
|
||||
FloatA,
|
||||
FloatB,
|
||||
FloatAcc,
|
||||
decltype(a_block_desc_k0perblock_mperblock_k1),
|
||||
decltype(b_block_desc_k0perblock_nperblock_k1),
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
BlockwiseGemmWMMA<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
decltype(MakeAWaveDescriptor(a_block_desc)),
|
||||
decltype(MakeBWaveDescriptor(b_block_desc)),
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
AEnableLds,
|
||||
BEnableLds>{};
|
||||
|
||||
// Prepare Register for C matrix
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
/*******************************************************************************/
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatA*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
|
||||
|
||||
/*******************************************************************************/
|
||||
// Shift Per SUB_K
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
|
||||
constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
|
||||
|
||||
// gridwise GEMM pipeline
|
||||
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
|
||||
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
|
||||
a_block_desc_k0perblock_mperblock_k1,
|
||||
const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
|
||||
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
|
||||
a_block_desc,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_k0_n_k1,
|
||||
b_block_desc_k0perblock_nperblock_k1,
|
||||
b_grid_desc,
|
||||
b_block_desc,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
K0BlockMainLoop);
|
||||
KBlockMainLoop);
|
||||
/*******************************************************************************/
|
||||
// write out to C, implement shuffle
|
||||
{
|
||||
// C mapping in single thread.
|
||||
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
// This API Provide All dimension (size) you need
|
||||
// C mapping in single block
|
||||
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
|
||||
blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
@@ -485,8 +846,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatCShuffle*>(p_shared),
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
|
||||
static_cast<CShuffleDataType*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
|
||||
SharedMemTrait::c_shuffle_block_space_size);
|
||||
|
||||
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
@@ -532,8 +893,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
|
||||
FloatCShuffle,
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
|
||||
decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -571,8 +932,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
FloatCShuffle, // typename SrcData,
|
||||
FloatC, // typename DstData,
|
||||
CShuffleDataType, // typename SrcData,
|
||||
CDataType, // typename DstData,
|
||||
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
@@ -636,6 +997,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
|
||||
// move on C
|
||||
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
|
||||
Reference in New Issue
Block a user