mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
* Squashed commit of the following: commit 3e1a851dad834776efbe4fe365ac82c4ed312010 Author: Ding, Yi <yi.ding@amd.com> Date: Thu Oct 23 06:10:54 2025 +0000 Fix & clean after rebase commit 1edf485092f44411da9a1796a4a6b72d5cdb67c6 Author: Ding, Yi <yi.ding@amd.com> Date: Wed Oct 22 10:46:13 2025 +0000 Squashed commit of the following: commit 5276b28a51dac7b5d2106fbae8e78de190ee0de1 Author: mtgu0705 <mtgu@amd.com> Date: Mon Sep 22 02:04:27 2025 -0500 fix bandwidth calculation commit d645bb20c6d879154c30ecd82bbff4d2a9206750 Author: mtgu0705 <mtgu@amd.com> Date: Mon Sep 22 00:58:59 2025 -0500 updates commit 0fa7e6b88aaf81a36034aa7607746de295de4263 Author: mtgu0705 <mtgu@amd.com> Date: Fri Sep 19 00:39:46 2025 -0500 fix a bug, set the A DS_read preload size to 4 for MXFP4 commit 50cafa824e2267f2b2f0dfeeb93e69a673630c61 Author: mtgu0705 <mtgu@amd.com> Date: Thu Sep 18 01:19:03 2025 -0500 fix a_wrap preload issue for large MPerBlock. commit e6333bbbc6ef540e24f92095040085f1ed59041e Author: mtgu0705 <mtgu@amd.com> Date: Wed Sep 17 21:34:03 2025 -0500 optimized the VGPR repack issue for MXFP4 commit e99e4932c401b9f6d1893dd5044c2827d6b3f145 Author: Gino Lu <gino.lu@amd.com> Date: Wed Sep 17 04:19:44 2025 -0500 fix time error commit 4586ce6da7fba0514f2e01a8124c76b7d494e124 Author: mtgu0705 <mtgu@amd.com> Date: Wed Sep 17 03:58:00 2025 -0500 updated, function passed. 
commit c4f25e7579573db5681b9160f6bdb1349f3566f1 Author: mtgu0705 <mtgu@amd.com> Date: Tue Sep 16 22:21:39 2025 -0500 fix, function partially passed commit a51b56eb6b00b99a4e8d2802dbf5b5b5277b54d8 Author: mtgu0705 <mtgu@amd.com> Date: Tue Sep 16 03:01:12 2025 -0500 fix, reference function passed, next check kernel function commit 5b02643ebab18960e8f9ba66c6bd2f91774f9cae Author: Gino Lu <gino.lu@amd.com> Date: Tue Sep 16 02:29:01 2025 -0500 let pack/unpack return pk_fp4_t commit 76d37c5d4b17530e95c6fced31bff66a35d54b8f Author: mtgu0705 <mtgu@amd.com> Date: Mon Sep 15 20:50:26 2025 -0500 fix commit e5be3e162b9a20e5355bd556d2b27afb6d8bf085 Author: Gino Lu <gino.lu@amd.com> Date: Mon Sep 15 05:51:06 2025 -0500 fix bug commit 39a024efe4aa773df589712b1290803bb5ab5d1d Author: mtgu0705 <mtgu@amd.com> Date: Mon Sep 15 04:02:05 2025 -0500 fix core dump issue, function is not correct. commit 16c49d268cfe065b5112b960b2d852b26552686a Author: mtgu0705 <mtgu@amd.com> Date: Mon Sep 15 03:03:02 2025 -0500 updates, build pass commit fe7a961852dee6eff3be3cf1e0d0fabec5cd42ee Author: mtgu0705 <mtgu@amd.com> Date: Mon Sep 15 00:05:18 2025 -0500 updates commit aaf9fe8022a72df59e04e4d5886dca3ba9c23400 Author: Gino Lu <gino.lu@amd.com> Date: Sun Sep 14 23:40:28 2025 -0500 fix bug commit a3da89290e1553b85fbf1171c07e93ac0f5584db Author: Gino Lu <gino.lu@amd.com> Date: Fri Sep 12 03:28:50 2025 -0500 fix interface commit c5ff747e72d877461ba61dc19a0fe15527d3161e Author: Gino Lu <gino.lu@amd.com> Date: Fri Sep 12 02:53:50 2025 -0500 add interface in warp_gemm_impl commit 0a48d369e601cc798589fc59e0784bdbfc0a22f9 Author: mtgu0705 <mtgu@amd.com> Date: Wed Sep 10 05:03:08 2025 -0500 updates some fixes. 
commit aaa2beca30ff5546d171a2028d1894fd4e131d4e Author: mtgu0705 <mtgu@amd.com> Date: Tue Sep 9 04:37:42 2025 -0500 fix after merge ginolu/add_wgmfma_dispatcher commit bf87449b09cba690922b2f3f78ba39bf1b1e472e Merge: 05ab58e3d 991d7fdbb Author: mtgu0705 <mtgu@amd.com> Date: Mon Sep 8 22:09:15 2025 -0500 Merge remote-tracking branch 'origin/ginolu/add_wgmfma_dispatcher' into mtgu/cktile_mxfp4_flatmm_dev commit 05ab58e3de2b708aceda63d704089c0fa89437ae Author: mtgu0705 <mtgu@amd.com> Date: Mon Sep 8 21:42:47 2025 -0500 update mx flatmm tail pipeline commit 991d7fdbb726d65091a91b5cc2800f798a6661fc Merge: ad046084ab2f280046Author: Gino Lu <gino.lu@amd.com> Date: Mon Sep 8 19:10:23 2025 -0500 Merge branch 'develop' into ginolu/add_wgmfma_dispatcher commit ad046084a2f6e4ebf0cd8b47d0d72b74815061fa Author: Gino Lu <gino.lu@amd.com> Date: Mon Sep 8 19:09:55 2025 -0500 fix type error commit 42e16b43a035364a42789d7ce45a1e6a7d1d2609 Author: mtgu0705 <mtgu@amd.com> Date: Mon Sep 8 04:01:40 2025 -0500 update hotloop pipeline commit c2f69745346545087c8ce24acaba2961bb93ef0b Merge: adbeeb90b8b4be3a0eAuthor: Gino Lu <gino.lu@amd.com> Date: Fri Sep 5 04:22:26 2025 -0500 Merge branch 'develop' into ginolu/add_wgmfma_dispatcher commit adbeeb90be1533f8aeb8c1d5aea6470d45a455a0 Author: Gino Lu <gino.lu@amd.com> Date: Fri Sep 5 04:21:26 2025 -0500 fix clang format commit e2378ac393bb79ac80a8eef84677bffce86d9e0a Author: mtgu0705 <mtgu@amd.com> Date: Wed Sep 3 10:00:54 2025 -0500 some updates commit bdc18a2269db49ff88e1ef1db30f83ea430d7544 Merge: 6c5cea2b7feec59755Author: asleepzzz <hanwen.chang@amd.com> Date: Wed Sep 3 13:22:03 2025 +0800 Merge branch 'develop' into ginolu/add_wgmfma_dispatcher commit 6c5cea2b7a306f5d0ad346cb9baf6370ea2a73fe Author: Gino Lu <gino.lu@amd.com> Date: Mon Sep 1 02:11:02 2025 -0500 fix vec size error commit 76d1dfa352087dfd5867c8909b73726d3a1e853e Author: Gino Lu <gino.lu@amd.com> Date: Mon Sep 1 01:23:39 2025 -0500 fix format error commit 
a9061aaa1b4bfaa9db102c75b9d74863f39708a9 Author: mtgu0705 <mtgu@amd.com> Date: Sat Aug 30 03:19:07 2025 -0500 update codes commit 0caa184a271a8824ef40f87de456d0fa2500c8ad Author: mtgu0705 <mtgu@amd.com> Date: Fri Aug 29 11:27:33 2025 -0500 init ck_tile mxfp4 flatmm commit 5d46a6635f04bd69b76f7eda1438862e271b987a Author: Feng Shijie <Shijie.Feng@amd.com> Date: Thu Aug 28 08:02:50 2025 +0000 Add bias for f16xf4 moe_flatmm commit dd112dc302d17f541737671a3ac557d7c09ff969 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Wed Aug 27 13:39:47 2025 +0000 update case construction commit b1aca68a073d82c7b3c7bb53286e5f415999edc1 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Tue Aug 26 12:32:29 2025 +0000 support swiglu activaion and use rcpf to accelerate silu commit 49235bd42349a84fc2ebd7ad0b100cc2545bb80a Author: Gino Lu <gino.lu@amd.com> Date: Tue Aug 26 02:33:55 2025 -0500 first commit commit c169e39d6381b932cf7098cc118db29df91da1cb Author: root <root@smci355-ccs-aus-m02-25.cs-aus.dcgpu> Date: Fri Aug 22 04:01:59 2025 -0500 add line to last commit 318f9bf317306454941bbf394c1940023edcf0ac Author: root <root@smci355-ccs-aus-m02-25.cs-aus.dcgpu> Date: Fri Aug 22 03:20:46 2025 -0500 adjust A_LDS descriptor to avoid bankconflict commit 9d066120ed068d6d102da25d619e170a28a04d18 Author: root <root@smci355-ccs-aus-m02-25.cs-aus.dcgpu> Date: Thu Aug 21 09:46:52 2025 -0500 enable hotloop commit 61a895e6b821798970afffd0e9432a21e2f04df8 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Thu Aug 21 09:12:21 2025 +0000 support atomic_pk_add_bf16 on gfx950 commit 9f14864e45f21d8c1bc70a94988fb86c2c0017d8 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Thu Aug 21 06:58:55 2025 +0000 use int64_t as expert stride to avoid overflow commit e63af46b32e1139a1e59dee6f46b9971047c4026 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Wed Aug 20 13:53:32 2025 +0000 use v4i32 as the storage type for B to avoid repack operation commit 6cf0224dd8a229bf2be726ca861c736c9b5f5415 Author: Feng Shijie 
<Shijie.Feng@amd.com> Date: Wed Aug 20 06:40:03 2025 +0000 add pk_fp4_t and e8m0_t support for amd_buffer_load_impl commit 67a591f2240b0b035029edad904627f98b3839fd Author: Feng Shijie <Shijie.Feng@amd.com> Date: Wed Aug 20 04:39:14 2025 +0000 optimize cvt_pkf4_to_f16 implementation commit 51c7126e77e9b17af694eaa57040e487f9d443e8 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Tue Aug 19 14:56:46 2025 +0000 optimize A_LDS descriptor to avoid bankconflict commit c113160f326353290a2878d7b8febf7daed91d71 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Mon Aug 18 18:43:37 2025 +0000 fix gate-up when GU_NRepeat > 1 commit a45ca0e9934ca4bb9114f65621d5c9582d937a45 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Mon Aug 18 17:28:11 2025 +0000 add fp16xf4 moe commit dc8c8e484804f7bca10c8f0764540af3b5884e83 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Sun Aug 17 17:51:18 2025 +0000 rename example commit b177c967141cfdc401d3f36bf17830fe99893600 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Fri Aug 15 06:20:46 2025 +0000 remove additional check when e8m0->float commit d467f9688c3d35f391e15089135edb1ad1d38b05 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Thu Aug 14 09:34:12 2025 +0000 eliminate repeat dequant commit 1b20674b26ab3ce6bd2f710dd729fd4cc0f79428 Merge: faa3c0278 7d02625e7 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Wed Aug 13 16:51:49 2025 +0000 Merge remote-tracking branch 'origin/moe_flatmm' into feat-mixed_input_flatmm commit faa3c0278cf11b7105a4302dea3a4416520b2cc7 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Wed Aug 13 16:16:48 2025 +0000 update f16xMXF4 commit a2a2e1dab05501cc2136133236c01c08d51db4ea Author: Feng Shijie <Shijie.Feng@amd.com> Date: Wed Aug 13 10:48:53 2025 +0000 update scale-preshuffle for MXF4 commit eac9667feb899419dda1628164c092b969852660 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Mon Aug 11 11:24:34 2025 +0000 update commit 7d02625e7678882af653f52c2a4ddaf64568a41c Author: Feng Shijie <Shijie.Feng@amd.com> Date: Mon Aug 11 
08:38:23 2025 +0000 optimize gemm2 atomic_add pattern commit d5f3c3e3ec72d0e6739467c4dc0b4e209f6d1192 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Mon Aug 11 07:59:47 2025 +0000 update scale for mxfp4 commit 15db198084614466bd4cfd4943fcb549cab2069a Author: Feng Shijie <Shijie.Feng@amd.com> Date: Mon Aug 11 07:56:14 2025 +0000 update case construction commit 5dff349d82a5f70b6eea821d2622df51f90ef200 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Mon Aug 11 06:03:06 2025 +0000 update granularity control commit d32cdc52144f65ec473f4ec8e45ea23968811184 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Mon Aug 11 03:42:46 2025 +0000 fix TileConfig commit 26f38c5716304ee5f84e5c4f6f88144d9f3dddaf Author: Gino Lu <gino.lu@amd.com> Date: Thu Aug 7 21:37:28 2025 +0800 Add e8m0 scaled convert into CK_TILE (#2617) * first commit * remove redundent code * modify according to comments. * fix type_convert error with scaled_type_convert commit 419041478745f65dfec18859e75a13d975089519 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Fri Aug 8 20:19:16 2025 +0000 add mixed_prec fp16xfp4 commit 92e2a8b0308b9b107df9d2fd63a961efce706402 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Thu Aug 7 09:22:04 2025 +0000 debug mixed_prec flatmm commit dea3ce80496ebcb00512979f0c3bb897f25e11a5 Merge: fde443bc3 b4f45fe14 Author: lalala-sh <Jiaxing.Wen@amd.com> Date: Wed Aug 6 16:49:47 2025 +0800 Merge pull request #2626 from ROCm/felix/flatmm_fix_splitk fix split k commit d480e8150358cc4ef8b05e25afe299141fad4fde Author: Feng Shijie <Shijie.Feng@amd.com> Date: Wed Aug 6 08:33:33 2025 +0000 add moe_flatmm commit b4f45fe14d11569f34de40c8a205cd6760b61357 Author: coderfeli <coderfeli@163.com> Date: Wed Aug 6 02:45:31 2025 +0000 fix split k commit fde443bc38fe60e52195817ecb2c7b20d772eedb Author: Feng Shijie <Shijie.Feng@amd.com> Date: Mon Aug 4 07:16:36 2025 +0000 fix flatmm with scaling when WarpTileM == 32 commit 5a0667afa889a5af8c6b8509232eabd50cf5efef Author: Feng Shijie <Shijie.Feng@amd.com> Date: 
Fri Aug 1 11:01:23 2025 +0000 optimize scaling epilogue commit 5c3502bbf71833c6f6f7d4a1cc4f4fd93811f522 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Fri Aug 1 07:28:38 2025 +0000 fix wrong config for fp8 scaling commit eb2d0653cdb86603cb11539cbac466b6431b58b7 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Wed Jul 30 06:20:30 2025 +0000 prune debug message commit 0c089cb56343a39e02a1ee38e9cabeb71ba35e92 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Wed Jul 30 04:52:08 2025 +0000 fix compile error commit 61759ca30ce3787f70e228c3919b3e4d354016dd Author: Feng Shijie <Shijie.Feng@amd.com> Date: Tue Jul 29 15:42:58 2025 +0000 Add persistent option on flatmm for tuning commit b36dc5dd55f15fc1ce8eb21637bdec862e56a883 Author: AMD-dteng <dteng@amd.com> Date: Tue Jul 29 22:48:00 2025 +0800 update pipeline v1: add atomic IGLP schedule commit f886f26994454fc2b4fc3433c86bf699767a2a7c Author: lalala-sh <Jiaxing.Wen@amd.com> Date: Thu Jul 24 09:09:27 2025 +0000 fix error log throwing commit 4b4686ab144daa9061fbda17f3df4c17600c8e9a Author: Feng Shijie <Shijie.Feng@amd.com> Date: Mon Jul 28 08:24:51 2025 +0000 crz idea commit 7099af44a81be41431ba70ae60827b60116d02d2 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Sun Jul 27 11:57:38 2025 +0000 Add permuteN optimzization when NRepeat % 2 == 0 on flatmm commit b147524c92e69a267337c8e48b6e64bcb1483551 Author: sjfeng <j514681085@icloud.com> Date: Sun Jul 27 17:24:08 2025 +0800 try to remove c_shuffle_lds commit 2dd94f59d1a7740a5689e1713ed45588cd0d55dd Author: Feng Shijie <Shijie.Feng@amd.com> Date: Fri Jul 25 07:41:48 2025 +0000 fix loop-dim mismatch and improve c_shuffle alu parallelism commit 4e93f0c5e27806adc070e4caa81661069295751c Merge: 3f12ef5aa 0eb7455f1 Author: lalala-sh <Jiaxing.Wen@amd.com> Date: Thu Jul 24 08:46:51 2025 +0000 merge flatmm -scale commit 3f12ef5aa52ced1bff3bfb57b878358330e9e095 Author: lalala-sh <Jiaxing.Wen@amd.com> Date: Thu Jul 24 16:19:58 2025 +0800 revert delete of inc file commit 
08c3a0d184d7581dc5be364f5b36f16fb4a8d6fa Author: solin <bingzhou@amd.com> Date: Thu Jul 24 04:38:16 2025 +0000 reorg flatmm code commit 0eb7455f106604d5254ed16b0daeda68e2a148e3 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Wed Jul 23 19:12:31 2025 +0000 fix flatmm syntax error on gfx950 commit 695ff87e68fdcbe28452c1805cd4dbb643c45495 Author: Feng Shijie <Shijie.Feng@amd.com> Date: Wed Jul 23 19:04:22 2025 +0000 support flatmm scaling commit e3c29d9dea8758db96b998982ccc8bd1c4e8298d Author: valarLip <340077269@qq.com> Date: Wed Jul 23 08:44:12 2025 +0000 merge flatmm pipe v0 from dteng_flatmm_opt commit 425c366fa4c30426ff36cade89b39fd8cb7b9732 Author: lalala-sh <Jiaxing.Wen@amd.com> Date: Wed Jul 23 15:38:12 2025 +0800 build pass commit 6b377a9481535696de40f175d7e2159263d21bdc Author: lalala-sh <Jiaxing.Wen@amd.com> Date: Wed Jul 23 07:20:26 2025 +0000 fix bug commit b6dc58d1ea676fe480c0243ae098c875498f6d6a Author: lalala-sh <Jiaxing.Wen@amd.com> Date: Wed Jul 23 15:01:53 2025 +0800 sync commit 904359f401866ee810484e6b8f5b46d79d9e25c8 Author: valarLip <340077269@qq.com> Date: Tue Jul 22 08:09:35 2025 +0000 adaptive scheduler instead of Macro definition commit f29916c17228c17de9923aab62e7d72d7a30f4e9 Author: lalala-sh <Jiaxing.Wen@amd.com> Date: Thu Jul 17 08:40:35 2025 +0000 fix tail handler bug commit e2c60a90929fec955d91db909d50db538d58363b Author: lalala-sh <Jiaxing.Wen@amd.com> Date: Wed Jul 16 10:12:19 2025 +0000 merge from dteng_flatmm_opt --------- Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: AMD-dteng <dteng@amd.com> Co-authored-by: solin <bingzhou@amd.com> Co-authored-by: sjfeng <j514681085@icloud.com> Co-authored-by: valarLip <340077269@qq.com> Co-authored-by: asleepzzz <hanwen.chang@amd.com> Co-authored-by: Feng Shijie <Shijie.Feng@amd.com> Co-authored-by: coderfeli <coderfeli@163.com> Co-authored-by: Gino Lu <gino.lu@amd.com> Co-authored-by: mtgu0705 <mtgu@amd.com> * Fix crash on small M * Apply suggestion from @Copilot --------- 
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: AMD-dteng <dteng@amd.com> Co-authored-by: solin <bingzhou@amd.com> Co-authored-by: sjfeng <j514681085@icloud.com> Co-authored-by: valarLip <340077269@qq.com> Co-authored-by: asleepzzz <hanwen.chang@amd.com> Co-authored-by: Feng Shijie <Shijie.Feng@amd.com> Co-authored-by: coderfeli <coderfeli@163.com> Co-authored-by: Gino Lu <gino.lu@amd.com> Co-authored-by: mtgu0705 <mtgu@amd.com> [ROCm/composable_kernel commit:e135dd518d]
942 lines
38 KiB
C++
942 lines
38 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include <iostream>
|
|
#include <string>
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/ops/common.hpp"
|
|
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
|
|
|
namespace ck_tile {
|
|
struct FlatmmProblem
|
|
{
|
|
CK_TILE_HOST FlatmmProblem() = default;
|
|
CK_TILE_HOST FlatmmProblem(
|
|
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
|
|
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
|
|
{
|
|
}
|
|
|
|
index_t M;
|
|
index_t N;
|
|
index_t K;
|
|
index_t stride_A;
|
|
index_t stride_B;
|
|
index_t stride_C;
|
|
};
|
|
|
|
template <int SharedGranularityMN, int SharedGranularityK = 0>
|
|
struct FlatmmScalePointer
|
|
{
|
|
static constexpr int GranularityMN = SharedGranularityMN;
|
|
static constexpr int GranularityK = SharedGranularityK;
|
|
|
|
const float* ptr;
|
|
|
|
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
|
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
|
|
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
|
|
: ptr(ptr_)
|
|
{
|
|
}
|
|
|
|
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
|
|
{
|
|
FlatmmScalePointer ret;
|
|
if constexpr(GranularityMN == 0)
|
|
{
|
|
ret.ptr = ptr + offset / GranularityK;
|
|
}
|
|
else
|
|
{
|
|
ret.ptr = ptr + offset / GranularityMN / GranularityK;
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
|
|
};
|
|
|
|
template <int SharedGranularityMN>
|
|
struct FlatmmScalePointer<SharedGranularityMN, 0>
|
|
{
|
|
static constexpr int GranularityMN = SharedGranularityMN;
|
|
static constexpr int GranularityK = 0;
|
|
|
|
static_assert(GranularityMN != 0);
|
|
|
|
const float* ptr;
|
|
index_t length;
|
|
|
|
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
|
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
|
|
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_)
|
|
: ptr(ptr_), length(length_)
|
|
{
|
|
}
|
|
|
|
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
|
|
{
|
|
FlatmmScalePointer ret;
|
|
if constexpr(GranularityMN == 1)
|
|
{
|
|
ret.ptr = ptr + offset;
|
|
ret.length = length - offset;
|
|
}
|
|
else
|
|
{
|
|
ret.ptr = ptr + offset / GranularityMN;
|
|
ret.length = length - offset / GranularityMN;
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
CK_TILE_HOST_DEVICE float operator[](index_t i) const
|
|
{
|
|
// with additional oob check
|
|
if constexpr(GranularityMN == 1)
|
|
return i < length ? ptr[i] : 0;
|
|
else
|
|
return i / GranularityMN < length ? ptr[i / GranularityMN] : 0;
|
|
}
|
|
};
|
|
|
|
// shared granularityMN = -1 means no scale
|
|
template <>
|
|
struct FlatmmScalePointer<-1, 0>
|
|
{
|
|
static constexpr int GranularityMN = -1;
|
|
static constexpr int GranularityK = 0;
|
|
|
|
const float* ptr = nullptr;
|
|
|
|
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
|
|
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
|
|
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {}
|
|
|
|
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
|
|
{
|
|
return FlatmmScalePointer{};
|
|
}
|
|
CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
|
|
{
|
|
return 1; // alway return 1, it doesn't change the result
|
|
}
|
|
};
|
|
|
|
template <index_t NumDTensor = 0>
|
|
struct BaseFlatmmHostArgs
|
|
{
|
|
CK_TILE_HOST BaseFlatmmHostArgs() = default;
|
|
CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_,
|
|
const void* b_ptr_,
|
|
const std::array<const void*, NumDTensor>& ds_ptr_,
|
|
void* e_ptr_,
|
|
index_t k_batch_,
|
|
index_t M_,
|
|
index_t N_,
|
|
index_t K_,
|
|
index_t stride_A_,
|
|
index_t stride_B_,
|
|
const std::array<index_t, NumDTensor>& stride_Ds_,
|
|
index_t stride_E_)
|
|
: a_ptr(a_ptr_),
|
|
b_ptr(b_ptr_),
|
|
ds_ptr(ds_ptr_),
|
|
e_ptr(e_ptr_),
|
|
M(M_),
|
|
N(N_),
|
|
K(K_),
|
|
stride_A(stride_A_),
|
|
stride_B(stride_B_),
|
|
stride_Ds(stride_Ds_),
|
|
stride_E(stride_E_),
|
|
k_batch(k_batch_)
|
|
{
|
|
}
|
|
|
|
const void* a_ptr;
|
|
const void* b_ptr;
|
|
const std::array<const void*, NumDTensor> ds_ptr;
|
|
union
|
|
{
|
|
void* e_ptr;
|
|
void* c_ptr;
|
|
};
|
|
index_t M;
|
|
index_t N;
|
|
index_t K;
|
|
index_t stride_A;
|
|
index_t stride_B;
|
|
const std::array<index_t, NumDTensor> stride_Ds;
|
|
union
|
|
{
|
|
index_t stride_E;
|
|
index_t stride_C;
|
|
};
|
|
|
|
index_t k_batch;
|
|
};
|
|
template <class ScaleM = FlatmmScalePointer<-1>,
|
|
class ScaleN = FlatmmScalePointer<-1>,
|
|
index_t NumDTensor = 0>
|
|
struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<>
|
|
{
|
|
CK_TILE_HOST ScaleFlatmmHostArgs() = default;
|
|
CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_,
|
|
const void* b_shuffle_ptr_,
|
|
const std::array<const void*, NumDTensor>& ds_ptr_,
|
|
void* c_ptr_,
|
|
index_t k_batch_,
|
|
index_t M_,
|
|
index_t N_,
|
|
index_t K_,
|
|
index_t stride_A_,
|
|
index_t stride_B_,
|
|
const std::array<index_t, NumDTensor>& stride_Ds_,
|
|
index_t stride_C_,
|
|
ScaleM scale_m_ = nullptr,
|
|
ScaleN scale_n_ = nullptr)
|
|
: BaseFlatmmHostArgs(a_ptr_,
|
|
b_shuffle_ptr_,
|
|
ds_ptr_,
|
|
c_ptr_,
|
|
k_batch_,
|
|
M_,
|
|
N_,
|
|
K_,
|
|
stride_A_,
|
|
stride_B_,
|
|
stride_Ds_,
|
|
stride_C_),
|
|
scale_m(scale_m_),
|
|
scale_n(scale_n_)
|
|
{
|
|
}
|
|
ScaleM scale_m = nullptr;
|
|
ScaleN scale_n = nullptr;
|
|
};
|
|
|
|
template <int NumberTensor = 0>
|
|
using FlatmmHostArgs =
|
|
ScaleFlatmmHostArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, NumberTensor>;
|
|
|
|
template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
|
|
struct FlatmmKernelArgs
|
|
{
|
|
const void* a_ptr;
|
|
// const void* b_shuffle_ptr;
|
|
const void* b_ptr;
|
|
const std::array<const void*, NumDTensor> ds_ptr;
|
|
void* e_ptr;
|
|
index_t M;
|
|
index_t N;
|
|
index_t K;
|
|
index_t stride_A;
|
|
index_t stride_B;
|
|
std::array<index_t, NumDTensor> stride_Ds;
|
|
index_t stride_E;
|
|
index_t k_batch;
|
|
ScaleM scale_m_ptr = nullptr;
|
|
ScaleN scale_n_ptr = nullptr;
|
|
};
|
|
|
|
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
|
|
struct FlatmmKernel
|
|
{
|
|
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
|
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
|
|
using BlockGemmShape =
|
|
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
|
|
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
|
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
|
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
|
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
|
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
|
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
|
static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
|
|
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
|
|
|
|
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
|
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
|
// Below type is actually accumulation data type - the output of block GEMM.
|
|
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
|
|
|
static constexpr index_t NumDTensor = DsDataType::size();
|
|
|
|
static constexpr auto I0 = number<0>();
|
|
static constexpr auto I1 = number<1>();
|
|
static constexpr auto I2 = number<2>();
|
|
static constexpr auto I3 = number<3>();
|
|
|
|
static_assert(DsLayout::size() == DsDataType::size(),
|
|
"The size of DsLayout and DsDataType should be the same");
|
|
// using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
|
|
|
|
/// Kernel identifier: "gemm", the A/B precision tag and the pipeline's own name,
/// joined with '_'.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    // clang-format off
    return concat('_',
                  "gemm",
                  gemm_prec_str<ADataType, BDataType>,
                  FlatmmPipeline::GetName());
    // clang-format on
}
|
|
|
|
/// Launch grid for the non-persistent kernel.
/// x = TilePartitioner::GridSize(M, N), y = 1, z = KBatch (split-K batches).
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
    assert(!UsePersistentKernel);
    return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}

/// Launch grid derived from kernel arguments.
///
/// Non-persistent: same grid as GridSize(M, N, k_batch).
/// Persistent: x = min(#CUs * max resident workgroups per CU, #work tiles), measured on the
/// *current* HIP device. If any device/occupancy query fails (or reports zero occupancy),
/// the grid falls back to one workgroup per work tile instead of reading an uninitialized
/// hipDeviceProp_t or producing a zero-sized grid.
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr auto
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
{
    if constexpr(UsePersistentKernel)
    {
        // Persistent mode never splits K (IsSupportedArgument also rejects k_batch != 1).
        assert(kargs.k_batch == 1);

        const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);

        constexpr int block_size     = FlatmmKernel::BlockSize().x;
        constexpr int dync_smem_size = 0;

        hipDeviceProp_t prop;
        int deviceId             = 0;
        int maxActiveBlocksPerCU = 0;

        // Query the device that is current for this thread, which is also the one
        // hipOccupancyMaxActiveBlocksPerMultiprocessor inspects, rather than always
        // device 0. Short-circuiting guarantees `prop` is only read after it was filled.
        const bool occupancy_known =
            hipGetDevice(&deviceId) == hipSuccess &&
            hipGetDeviceProperties(&prop, deviceId) == hipSuccess &&
            hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &maxActiveBlocksPerCU,
                reinterpret_cast<void*>(
                    kentry<1, FlatmmKernel, FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
                block_size,
                dync_smem_size) == hipSuccess &&
            maxActiveBlocksPerCU > 0;

        if(!occupancy_known)
        {
            // One workgroup per work tile. The persistent path already uses this grid
            // whenever there are fewer tiles than resident workgroups.
            return dim3(total_work_tile_cnt, 1, kargs.k_batch);
        }

        const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
        return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
    }
    else
    {
        return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
    }
}
|
|
|
|
/// One-dimensional workgroup of kBlockSize threads (the pipeline's BlockSize).
CK_TILE_HOST static constexpr auto BlockSize()
{
    return dim3(kBlockSize);
}
|
|
|
|
template <class ScaleM, class ScaleN>
|
|
CK_TILE_HOST static constexpr FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>
|
|
MakeKernelArgs(const ScaleFlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
|
|
{
|
|
return {hostArgs.a_ptr,
|
|
hostArgs.b_ptr,
|
|
hostArgs.ds_ptr,
|
|
hostArgs.e_ptr,
|
|
hostArgs.M,
|
|
hostArgs.N,
|
|
hostArgs.K,
|
|
hostArgs.stride_A,
|
|
hostArgs.stride_B,
|
|
hostArgs.stride_Ds,
|
|
hostArgs.stride_E,
|
|
hostArgs.k_batch,
|
|
hostArgs.scale_m,
|
|
hostArgs.scale_n};
|
|
}
|
|
|
|
/// Shared-memory size of the "ping" buffer: large enough for both the main-loop pipeline
/// and the epilogue (presumably the epilogue reuses this buffer — confirm at the call site).
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
{
    const auto mainloop_smem = FlatmmPipeline::GetSmemSize();
    const auto epilogue_smem = EpiloguePipeline::GetSmemSize();
    return max(mainloop_smem, epilogue_smem);
}

/// Shared-memory size of the "pong" buffer: only the main-loop pipeline's requirement.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
{
    const auto mainloop_smem = FlatmmPipeline::GetSmemSize();
    return mainloop_smem;
}
|
|
|
|
struct SplitKBatchOffset
|
|
{
|
|
template <class KernelArgs>
|
|
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
|
|
{
|
|
constexpr auto N1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<1>{});
|
|
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
|
const index_t K_t = kargs.k_batch * K1;
|
|
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
|
|
|
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
|
{
|
|
a_k_split_offset = k_id * KRead;
|
|
}
|
|
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
|
{
|
|
a_k_split_offset = k_id * KRead * kargs.stride_A;
|
|
}
|
|
|
|
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
|
{
|
|
b_k_split_offset = k_id * KRead * kargs.stride_B * N1;
|
|
}
|
|
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
|
{
|
|
b_k_split_offset = k_id * KRead * N1;
|
|
}
|
|
|
|
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
|
{
|
|
splitted_k = KRead;
|
|
}
|
|
else
|
|
{
|
|
splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
|
|
}
|
|
}
|
|
|
|
index_t a_k_split_offset;
|
|
index_t b_k_split_offset;
|
|
index_t splitted_k;
|
|
};
|
|
|
|
template <class KernelArgs>
|
|
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
|
|
{
|
|
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
|
is_any_of<EDataType, fp16_t, bf16_t>::value)
|
|
{
|
|
if(kargs.k_batch != 1)
|
|
{
|
|
std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
|
|
return false;
|
|
}
|
|
}
|
|
if constexpr(UsePersistentKernel)
|
|
{
|
|
if(kargs.k_batch != 1)
|
|
{
|
|
std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
|
|
return false;
|
|
}
|
|
}
|
|
|
|
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
|
{
|
|
if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
|
|
{
|
|
std::cerr << "Can't support K that is not a multiple of KPerBlock"
|
|
" without padding!"
|
|
<< std::endl;
|
|
return false;
|
|
}
|
|
if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
|
|
{
|
|
std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
|
|
return false;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
|
|
{
|
|
std::cerr << "Can't support M that is not a multiple of MPerBlock"
|
|
" without padding!"
|
|
<< std::endl;
|
|
return false;
|
|
}
|
|
if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
|
|
{
|
|
std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
|
|
return false;
|
|
}
|
|
}
|
|
|
|
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
|
{
|
|
if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
|
|
{
|
|
std::cerr << "Can't support N that is not a multiple of NPerBlock"
|
|
" without padding!"
|
|
<< std::endl;
|
|
return false;
|
|
}
|
|
if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
|
|
{
|
|
std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
|
|
return false;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
|
|
{
|
|
std::cerr << "Can't support K that is not a multiple of KPerBlock"
|
|
" without padding!"
|
|
<< std::endl;
|
|
return false;
|
|
}
|
|
if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
|
|
{
|
|
std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
|
|
return false;
|
|
}
|
|
}
|
|
|
|
bool DTesnorIsValid = {true};
|
|
static_for<0, NumDTensor, 1>{}([&](auto index) {
|
|
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
|
if(std::is_same_v<DiLayout, ELayout> == false)
|
|
{
|
|
DTesnorIsValid = false;
|
|
}
|
|
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
|
{
|
|
if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
|
|
{
|
|
CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
|
|
"NPerBlock without padding!");
|
|
DTesnorIsValid = false;
|
|
}
|
|
if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
|
|
{
|
|
CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
|
|
DTesnorIsValid = false;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
|
|
{
|
|
CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
|
|
"MPerBlock without padding!");
|
|
|
|
DTesnorIsValid = false;
|
|
}
|
|
if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
|
|
{
|
|
CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
|
|
DTesnorIsValid = false;
|
|
}
|
|
}
|
|
});
|
|
|
|
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
|
{
|
|
if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
|
|
{
|
|
std::cerr << "Can't support N that is not a multiple of NPerBlock"
|
|
" without padding!"
|
|
<< std::endl;
|
|
return false;
|
|
}
|
|
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
|
|
{
|
|
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
|
|
return false;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
|
|
{
|
|
std::cerr << "Can't support M that is not a multiple of MPerBlock"
|
|
" without padding!"
|
|
<< std::endl;
|
|
return false;
|
|
}
|
|
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
|
|
{
|
|
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
|
|
return false;
|
|
}
|
|
}
|
|
return DTesnorIsValid;
|
|
}
|
|
|
|
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
|
|
CK_TILE_DEVICE static auto
|
|
MakeGemmTensorViews(const ADataType* a_ptr,
|
|
const BDataType* b_flat_ptr,
|
|
const std::array<const void*, NumDTensor>& ds_ptr,
|
|
EDataType* e_ptr,
|
|
const KernelArgs& kargs,
|
|
const SplitKBatchOffset& splitk_batch_offset)
|
|
{
|
|
const auto& a_tensor_view = [&]() {
|
|
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return make_naive_tensor_view<address_space_enum::global>(
|
|
a_ptr,
|
|
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
|
make_tuple(kargs.stride_A, 1),
|
|
number<FlatmmPipeline::GetVectorSizeA()>{},
|
|
number<1>{});
|
|
}
|
|
else
|
|
{
|
|
return make_naive_tensor_view<address_space_enum::global>(
|
|
a_ptr,
|
|
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
|
|
make_tuple(kargs.stride_A, 1),
|
|
number<FlatmmPipeline::GetVectorSizeA()>{},
|
|
number<1>{});
|
|
}
|
|
}();
|
|
|
|
index_t kFlatK =
|
|
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
|
|
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
|
const auto& b_flat_tensor_view = [&]() {
|
|
return make_naive_tensor_view<address_space_enum::global>(
|
|
b_flat_ptr,
|
|
make_tuple(kFlatN, kFlatK),
|
|
make_tuple(kFlatK, 1),
|
|
number<FlatmmPipeline::GetVectorSizeB()>{},
|
|
number<1>{});
|
|
}();
|
|
|
|
const auto& ds_tensor_view = generate_tuple(
|
|
[&](auto i) {
|
|
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
|
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
|
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return make_naive_tensor_view<address_space_enum::global>(
|
|
static_cast<const DDataType_*>(ds_ptr[i]),
|
|
make_tuple(kargs.M, kargs.N),
|
|
make_tuple(kargs.stride_Ds[i], 1),
|
|
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
|
number<1>{});
|
|
}
|
|
else
|
|
{
|
|
return make_naive_tensor_view<address_space_enum::global>(
|
|
static_cast<const DDataType_*>(ds_ptr[i]),
|
|
make_tuple(kargs.N, kargs.M),
|
|
make_tuple(kargs.stride_Ds[i], 1),
|
|
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
|
number<1>{});
|
|
}
|
|
},
|
|
number<NumDTensor>{});
|
|
|
|
// TODO: enable vector write for C in ColMajor
|
|
const auto& e_tensor_view = [&]() {
|
|
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
|
e_ptr,
|
|
make_tuple(kargs.M, kargs.N),
|
|
make_tuple(kargs.stride_E, 1),
|
|
number<EpiloguePipeline::GetVectorSizeC()>{},
|
|
number<1>{});
|
|
}
|
|
else
|
|
{
|
|
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
|
e_ptr,
|
|
make_tuple(kargs.N, kargs.M),
|
|
make_tuple(kargs.stride_E, 1),
|
|
number<1>{},
|
|
number<1>{});
|
|
}
|
|
}();
|
|
|
|
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
|
|
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
|
|
|
|
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
|
|
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
|
|
|
|
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
|
|
: 1; // per-token scale
|
|
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
|
|
: 1; // per-channel scale
|
|
|
|
static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
|
|
"only support per-tensor or per-row scaling");
|
|
static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
|
|
"only support per-tensor or per-column scaling");
|
|
|
|
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
|
|
kargs.scale_m_ptr.ptr,
|
|
make_tuple(
|
|
kargs.M / ScaleGranularityM,
|
|
ScaleGranularityKA == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKA),
|
|
make_tuple(scale_stride_m, 0),
|
|
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
|
|
number<1>{});
|
|
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
|
|
kargs.scale_n_ptr.ptr,
|
|
make_tuple(
|
|
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
|
|
kargs.N / ScaleGranularityN),
|
|
make_tuple(0, scale_stride_n),
|
|
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
|
|
number<1>{});
|
|
|
|
return make_tuple(a_tensor_view,
|
|
b_flat_tensor_view,
|
|
ds_tensor_view,
|
|
e_tensor_view,
|
|
scale_m_view,
|
|
scale_n_view);
|
|
}
|
|
|
|
template <typename TensorView>
|
|
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
|
{
|
|
const auto& a_pad_view = [&]() {
|
|
const auto& a_tensor_view = views.at(I0);
|
|
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return pad_tensor_view(a_tensor_view,
|
|
make_tuple(number<TilePartitioner::MPerBlock>{},
|
|
number<TilePartitioner::KPerBlock>{}),
|
|
sequence<false, FlatmmPipeline::kPadK>{});
|
|
}
|
|
else
|
|
{
|
|
return pad_tensor_view(a_tensor_view,
|
|
make_tuple(number<TilePartitioner::KPerBlock>{},
|
|
number<TilePartitioner::MPerBlock>{}),
|
|
sequence<false, FlatmmPipeline::kPadM>{});
|
|
}
|
|
}();
|
|
|
|
const auto& b_flat_tensor_view = views.at(I1);
|
|
|
|
const auto& ds_pad_view = generate_tuple(
|
|
[&](auto i) {
|
|
const auto& d_tensor_view = views.at(I2);
|
|
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
|
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return pad_tensor_view(d_tensor_view[i],
|
|
make_tuple(number<TilePartitioner::MPerBlock>{},
|
|
number<TilePartitioner::NPerBlock>{}),
|
|
sequence<false, FlatmmPipeline::kPadN>{});
|
|
}
|
|
else
|
|
{
|
|
return pad_tensor_view(d_tensor_view[i],
|
|
make_tuple(number<TilePartitioner::NPerBlock>{},
|
|
number<TilePartitioner::MPerBlock>{}),
|
|
sequence<false, FlatmmPipeline::kPadM>{});
|
|
}
|
|
},
|
|
number<NumDTensor>{});
|
|
|
|
// TODO vector write in for C in ColMajor
|
|
const auto& e_pad_view = [&]() {
|
|
const auto& e_tensor_view = views.at(I3);
|
|
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return pad_tensor_view(e_tensor_view,
|
|
make_tuple(number<TilePartitioner::MPerBlock>{},
|
|
number<TilePartitioner::NPerBlock>{}),
|
|
sequence<false, FlatmmPipeline::kPadN>{});
|
|
}
|
|
else
|
|
{
|
|
return pad_tensor_view(e_tensor_view,
|
|
make_tuple(number<TilePartitioner::MPerBlock>{},
|
|
number<TilePartitioner::NPerBlock>{}),
|
|
sequence<FlatmmPipeline::kPadM, false>{});
|
|
}
|
|
}();
|
|
|
|
return make_tuple(a_pad_view,
|
|
b_flat_tensor_view,
|
|
ds_pad_view,
|
|
e_pad_view,
|
|
views.at(number<4>{}),
|
|
views.at(number<5>{}));
|
|
}
|
|
|
|
template <typename PadView>
|
|
CK_TILE_DEVICE static auto
|
|
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
|
{
|
|
const auto& a_pad_view = views.at(I0);
|
|
const auto& b_flat_pad_view = views.at(I1);
|
|
const auto& ds_pad_view = views.at(I2);
|
|
const auto& e_pad_view = views.at(I3);
|
|
|
|
const auto& a_block_window = [&]() {
|
|
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return make_tile_window(a_pad_view,
|
|
make_tuple(number<TilePartitioner::MPerBlock>{},
|
|
number<TilePartitioner::KPerBlock>{}),
|
|
{i_m, 0});
|
|
}
|
|
else
|
|
{
|
|
return make_tile_window(a_pad_view,
|
|
make_tuple(number<TilePartitioner::KPerBlock>{},
|
|
number<TilePartitioner::MPerBlock>{}),
|
|
{0, i_m});
|
|
}
|
|
}();
|
|
|
|
const auto& b_flat_block_window =
|
|
make_tile_window(b_flat_pad_view,
|
|
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
|
number<FlatmmPipeline::flatKPerWarp>{}),
|
|
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
|
|
|
const auto ds_block_window = generate_tuple(
|
|
[&](auto i) {
|
|
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
|
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return make_tile_window(ds_pad_view[i],
|
|
make_tuple(number<TilePartitioner::MPerBlock>{},
|
|
number<TilePartitioner::NPerBlock>{}),
|
|
{i_m, i_n});
|
|
}
|
|
else
|
|
{
|
|
return make_tile_window(ds_pad_view[i],
|
|
make_tuple(number<TilePartitioner::NPerBlock>{},
|
|
number<TilePartitioner::MPerBlock>{}),
|
|
{i_n, i_m});
|
|
}
|
|
},
|
|
number<NumDTensor>{});
|
|
|
|
auto e_block_window = make_tile_window(
|
|
e_pad_view,
|
|
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
|
{i_m, i_n});
|
|
|
|
constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK;
|
|
constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK;
|
|
|
|
auto scale_m_window = make_tile_window(views.at(number<4>{}),
|
|
make_tuple(number<TilePartitioner::MPerBlock>{},
|
|
number < ScaleGranularityKA == 0
|
|
? TilePartitioner::NPerBlock
|
|
: TilePartitioner::KPerBlock > {}),
|
|
{i_m, 0});
|
|
auto scale_n_window = make_tile_window(views.at(number<5>{}),
|
|
make_tuple(number < ScaleGranularityKB == 0
|
|
? TilePartitioner::MPerBlock
|
|
: TilePartitioner::KPerBlock > {},
|
|
number<TilePartitioner::NPerBlock>{}),
|
|
{0, i_n});
|
|
|
|
return make_tuple(a_block_window,
|
|
b_flat_block_window,
|
|
ds_block_window,
|
|
e_block_window,
|
|
scale_m_window,
|
|
scale_n_window);
|
|
}
|
|
|
|
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
|
|
CK_TILE_DEVICE static void
|
|
RunFlatmm(const ADataType* a_ptr,
|
|
const BDataType* b_flat_ptr,
|
|
const std::array<const void*, NumDTensor>& ds_ptr,
|
|
EDataType* e_ptr,
|
|
void* smem_ptr_ping,
|
|
void* smem_ptr_pong,
|
|
const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
|
|
const SplitKBatchOffset& splitk_batch_offset,
|
|
const index_t block_idx_m,
|
|
const index_t block_idx_n)
|
|
{
|
|
// Create Gemm tensor views, pad views and tile windows
|
|
const auto& gemm_tensor_views_tuple =
|
|
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
|
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
|
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
|
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
|
|
|
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
|
|
|
// Run GEMM cooperatively by whole workgroup.
|
|
const auto& a_block_window = gemm_tile_windows.at(I0);
|
|
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
|
const auto& d_block_window = gemm_tile_windows.at(I2);
|
|
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
|
|
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
|
|
|
|
auto scale_m_window = gemm_tile_windows.at(number<4>{});
|
|
auto scale_n_window = gemm_tile_windows.at(number<5>{});
|
|
|
|
// Run Epilogue Pipeline
|
|
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
|
|
{
|
|
auto& c_block_window = gemm_tile_windows.at(I3);
|
|
EpiloguePipeline{}.template
|
|
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
|
c_block_window,
|
|
c_block_tile,
|
|
d_block_window,
|
|
smem_ptr_ping,
|
|
scale_m_window,
|
|
scale_n_window);
|
|
}
|
|
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
|
{
|
|
// Run Epilogue Pipeline
|
|
auto& c_block_window = gemm_tile_windows.at(I3);
|
|
EpiloguePipeline{}.template
|
|
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
|
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
|
}
|
|
}
|
|
|
|
template <class ScaleM, class ScaleN>
|
|
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
|
|
int partition_idx = blockIdx.x) const
|
|
{
|
|
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
|
|
|
do
|
|
{
|
|
const auto [iM, iN] =
|
|
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
|
|
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
|
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
|
|
|
const SplitKBatchOffset splitk_batch_offset(kargs);
|
|
// options
|
|
const ADataType* a_ptr =
|
|
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
|
const BDataType* b_flat_ptr =
|
|
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
|
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
|
|
|
// allocate LDS
|
|
__shared__ char smem_ptr_ping[GetSmemPingSize()];
|
|
__shared__ char smem_ptr_pong[GetSmemPongSize()];
|
|
|
|
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
|
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
|
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
|
{
|
|
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
|
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
|
|
b_flat_ptr,
|
|
kargs.ds_ptr,
|
|
e_ptr,
|
|
smem_ptr_ping,
|
|
smem_ptr_pong,
|
|
kargs,
|
|
splitk_batch_offset,
|
|
i_m,
|
|
i_n);
|
|
}
|
|
partition_idx += gridDim.x;
|
|
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
|
|
}
|
|
};
|
|
|
|
} // namespace ck_tile
|