Wmma support for multiple Ds based GEMMs (#2613)

* Fixed cmake errors related to  gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8"

* Fixed cmake build errors related to test_fp8

* Updates to support mixed precision


(cherry picked from commit e65d71180393e7b66169c56565a6bac740427de6)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Adding support for RRR, F8xF16xF16 gemm_universal_wmma - wip


(cherry picked from commit f8c06322df0abcbd5945a56cdf5bffe56480f9f0)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Added support for F8xF16xF16 to gemm_wmma_universal


(cherry picked from commit 15c851de6daa513a12c2e3af299bab0176175fb5)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Added support for F16xF8xF16 to gemm_wmma_universal

* Added support for BF16xI4xBF16 to gemm_wmma_universal


(cherry picked from commit c6a4a69d2d43d59bae8bdabfae80d648646f217e)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Added support for F16xI4xF16 to gemm_wmma_universal

* Fixed IsSupportedArgument to check ComputeTypeA, ComputeTypeB instead of ADataType, BDataType

* Added missing test class for FP16_KM_NK

* Pre-commit hooks fixes

* Added padding instances for f16xf16xf16

* Fixed cmake errors related to  gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8"


(cherry picked from commit 5bdc993dbf)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Fixed cmake build errors related to test_fp8


(cherry picked from commit 12176616b6)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Ammending changes for adding support for padding instances for f16xf16xf16

* Fixes for padding instances for f16xf16xf16

* Added padding instances for bf16xbf16, f8xf8

* Added packed instances for bf16xi4xbf16

* Added padding instances for f8xf16xf16

* Added padding instances for f16xf8xf16, f16xi4xf16

* Fixed typos for bf16xbf16xbf16 padding instances

* Fixed typos for padded instances

* Added tests for fp16, KM_KN and KM_NK

* Padding not supported for when BDataType is pk_i4_t. Added fix for correct check and removed padding instances.

* Fixed typos

* Updated the set of tests for FP16

* Updated the set of tests for FP16

* Fix typo

* Moved f16xi4 test under the correct data layout group

* example for gemm_universal_bf16

* Adding examples for gemm_wmma instances

* Added the  missing parameters

* Fixed review comments and added executable to cmakeLists

* Fixing clang format

* Fixing build erros

* Fixed compilation failure.

* Modified some code as per gemm_universal_examples

* Fixed the gemm specialization error

* Fixed the build errors.

* Fix strides of a/b_thread_desc

The descriptors are larger than needed (even though the compiler don't alloc registers for unused values).

* Load in M/NRepeat dims with thread copy's slice instead of a loop

* Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation

* Implement Intrawave and Interwave variants of pipeline v1

* Add instances for Interwave and Intrawave v1

* Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0

* Remove instances that are too slow (mostly because of register spilling)

* Add a workaround for fp8/bf8->f32 packed conversion issue

* Add instances for Interwave and Intrawave v1

* Enable profiling of mixed precision with f8 and int4 on WMMA

* Fix segfault in profiler when B is pk_i4_t

b_device_buf's size in bytes is larger than b_k_n_permute so b_device_buf.ToDevice reads out-of-bounds.

* Remove instances that are too slow (mostly because of register spilling)

* Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations

* Add test case for bf16_i4

* Add missing Regular tests

* Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS

They take more than 30 seconds

* Fix a bug that fp16_i4 validation passes only with PermuteB

A permutation required by conversion from pk_i4_t to half_t does not
depend on PermuteB, they can be used independently.

* Use PermuteB with f16_i4 in most instances (as xdl)

Some instances use PermuteB = false for checking correctness.
See also the previous commit.

* Fix cache flushing for pk_i4

* Add mixed precision examples

* Disable all tests and instances with f8 on gfx11

Even though f8_f16 and f16_f8 don't require f8 WMMA instructions,
gfx11 still lacks hardware instructions for fast f8->f32 conversion.

* Add FP16 KM_NK and KM_KN test suites for XDL

These tests were added to common .inc for better testing of WMMA instances

* Support multiple D in GridwiseGemm_wmma_cshuffle_v3

DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters.

* Use ThreadGroupTensorSliceTransfer_v7r3

* Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support

* Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma

* Implement DeviceGemmMultipleD_Wmma_CShuffleV3

* Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3

* Prepare gemma_add tests for adding wmma

* Add gemm_add_fastgelu instances and test

* Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API

ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use
DeviceGemmMultipleDSplitK instances there.

* removed unnecessary ck parts from compilation

* initial gemm_add_multiply instance implementations

* fixed profiler help message for gemm_add_multiply

* improved multiply_add profiler layout help

* fixed template arguments for test instances

* added test for gemm_add_multiply

* Support multiple D in GridwiseGemm_wmma_cshuffle_v3

DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters.

* Use ThreadGroupTensorSliceTransfer_v7r3

* Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support

* Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma

* Implement DeviceGemmMultipleD_Wmma_CShuffleV3

* Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3

* Prepare gemma_add tests for adding wmma

* Add gemm_add_fastgelu instances and test

* Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API

ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use
DeviceGemmMultipleDSplitK instances there.

* switched to splitK interface

* log print added to splitk benchmarks

* revert main cmake comments

* newline change reverted

* added add_fastgelu instances

* revert unintended change in xdl add_fastgelu

* created gemm_add_add_fastgelu instances

* created fastegelu instances

* added tests for all splitk fastgelus

* Added tests.

* multiply_add instances created

* updates to add_multiply splitk instances

* splitk xdl test fixes

* added wmma multiply_multiply instances

* fixed ONLY_XDL_AND_WMMA_KERNELS tag

* Added gemm_add examples for wmma v1 and v3

* fixed / workarounded i8 instances

* Modified the v3 code to added one fp16 bxdl instance.

* added bf16 xdl instance.

* adding gemm_add wmma_cshuffle and other support


(cherry picked from commit ec447e7f564095ea969eddc39ec77b843aa52976)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* add instances into camkelists


(cherry picked from commit 23bf2d2771c939ea3ca7f493433c55255bffd08e)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* This is work in progress, edited the template parameters in order to build

(cherry picked from commit b4fde8a3314cb44659c4bbda35f1a0133c63dc41)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* temp work saved, changed the BDataType to f16 or bf16 since wmma currently not support non-equal A and B datatype


(cherry picked from commit 22fbd68f1db458ab50780a394ee2544c7a1484d1)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* added datatype and use clang-format-12


(cherry picked from commit ae4e853682ef1bb27784b2f965b4a66b3751ceec)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* Fixing build errors

* Added instances for v3

* Adding instances and executables

* Code update of template parameters modified.

* Renamed file.

* Added tests.

* resolved error tests.

* Fixing build errors

* Updated comments

* removed the changes as per the MR review comment.

* Updated tests.

* fp8 instances - not tested

* Restored the Cmake file that was reverted by mistake during rebase.

* fixed wmma_op test

* Updated comments.

* Updated the template parameter description

* fixed rdna4 instances

* fixed back compatibility on gfx11

* cleanups

* fix ckProfiler

* one more cmake fix

* added fp8 instances

* Updated tests to ad BF16 instances as per review comment

* Added include file and cleaned up(as per review comment)

* Updated and optimized the example code for all types.

* Fixed clang format

* Resolve "Implement `device_gemm_bilinear` for RDNA4"

* test generalization to handle FP16 shuffle better

* added missing changes

* Added bf16 wmma instance for add_relu

* Added f16 wmma instance and corrected bf16 instance errors.

* Added instances to Cmake

* Modified the template parameters to make the instances work.

* Fixed typo in profiler

* Added v3 instances for gemm_add_relu

* addressed core review comments

* Added test for gemm_add_relu wmma instance

* Cleaned up the code.

* Added examples for gemm_add_relu

* Fixing typo to resolve build errors.

* Fixes applied to fix  the precision loss.

* fix billinear test after merge

* Removed the old wmma instances.

* Added wrapper and renamed the wmma_v3 instances

* Updated copyrights and added wrappers.

* Fixes applied according to review comments

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Robin Voetter <robin@streamhpc.com>

* Removed the old wmma instances.

* Updated wrapper for the v3 instances

* removed the old wmma examples

* Renamed the v3 instances

* Deleted the  gtest file added by mistake.

* Updated thge profiler with wrapper

* Fixed test errors.

* Fixed the review comments

* Fixed the if condition MACROS.

* REVERTED THE PROFILER CHANGES

* Revert "REVERTED THE PROFILER CHANGES"

This reverts commit 21cb98546c.

* Revert "Fixed test errors."

This reverts commit 13efcc6fe1.

* Revert "Updated thge profiler with wrapper"

This reverts commit 536f86661d.

* Added missing wrapper instances

* Updated copyrights.

* Fixed typo.

* Fixed copyrights.

* Updated copyrights.

* updated copyrights.

* comments on the atomics workaround

* fixed cmake comment

* Fix bug from merge

* clang-format-18

* Fix compilation error

* Fix linking error

* Fix bug in add and add_relu examples

* Fix error including file (typo)

* Quick fix to compile examples for different targets

* Fix for multi target

* implemented f16 and bf16 instances for gemm_silu

* addressed review comments

* addressed review comments

* Fix clang format

* Fix clang format

---------

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
Co-authored-by: apoorva <apoorva@streamhpc.com>
Co-authored-by: Anton Gorenko <anton@streamhpc.com>
Co-authored-by: Zoltan Lakatos <zoltan.lakatos@streamhpc.com>
Co-authored-by: Cenxuan <cenxuan@streamhpc.com>
Co-authored-by: Robin Voetter <robin@streamhpc.com>
Co-authored-by: Kiefer van Teutem <kiefer.van.teutem@streamhpc.com>
Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Enrico Degregori
2025-09-05 16:31:08 +02:00
committed by GitHub
parent 83f607e2a6
commit b740380906
112 changed files with 7877 additions and 715 deletions

View File

@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp"
@@ -22,9 +22,10 @@ namespace ck {
///
/// @par Overview
/// This GEMM kernel is carrying out following mathematical equation:
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
/// elementwise operations that could be applied on each tensor respectively.
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
// operations that could be applied on each tensor respectively. The CDE_op is an
// elementwise operation applied to the C and all D tensors.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through it's design
/// and versatilty.
@@ -36,18 +37,20 @@ namespace ck {
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam CLayout C tensor data layout.
/// @tparam DsLayout D tensors data layouts.
/// @tparam ELayout E tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during \"CShuffle\" data layout optimization.
/// @tparam CDataType C tensor data type.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
/// (after GEMM).
/// @tparam DsDataType D tensors data types.
/// @tparam EDataType E tensor data type.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
/// GEMM) and D input tensors.
/// @tparam GemmSpec Determines used "padding" version.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
@@ -105,11 +108,12 @@ namespace ck {
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
/// Used when storing data to output tensor.
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
/// Used when loading data from D tensors and storing data
/// to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
@@ -123,15 +127,17 @@ namespace ck {
/// in global memory (pre-shuffled).
template <typename ALayout,
typename BLayout,
typename CLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
@@ -161,8 +167,8 @@ template <typename ALayout,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
@@ -173,15 +179,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -211,8 +219,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -223,15 +231,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -261,8 +271,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -297,17 +307,22 @@ struct GridwiseGemm_wmma_cshuffle_v3
using Base::CalculateNPadded;
using Base::MakeAGridDescriptor_AK0_M_AK1;
using Base::MakeBGridDescriptor_BK0_N_BK1;
using Base::MakeCGridDescriptor_M_N;
using Base::MakeDEGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
using Base::NumDTensor;
using typename Base::DsGridPointer;
struct Problem
{
__host__ Problem(index_t M_,
@@ -315,14 +330,16 @@ struct GridwiseGemm_wmma_cshuffle_v3
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t KBatch_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
StrideDs{StrideDs_},
StrideE{StrideE_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
@@ -338,11 +355,19 @@ struct GridwiseGemm_wmma_cshuffle_v3
__host__ void Print() const
{
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
<< ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
<< ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << "}" << std::endl;
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
if constexpr(NumDTensor > 0)
{
std::cout << "SDs: { ";
static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
});
std::cout << " }, ";
}
std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
<< ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
<< ", " << "NBlock: " << NBlock << "}" << std::endl;
}
index_t M;
@@ -350,7 +375,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t KBatch;
index_t MPadded;
index_t NPadded;
@@ -367,21 +393,35 @@ struct GridwiseGemm_wmma_cshuffle_v3
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
p_ds_grid{},
p_e_grid{p_e_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
cde_element_op{cde_element_op_},
is_reduce(is_reduce_)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
});
}
__host__ __device__ inline bool IsReduceAdd() const
@@ -396,42 +436,49 @@ struct GridwiseGemm_wmma_cshuffle_v3
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CDEElementwiseOperation cde_element_op;
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
bool is_reduce;
};
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
a_k_split_offset = k_id * karg.KRead / APackedSize;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
b_k_split_offset = k_id * karg.KRead / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
b_k_split_offset = k_id * k0_offset / BPackedSize;
}
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
}
@@ -442,7 +489,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
if(karg.IsReduceAdd())
{
c_reduce_offset = blockIdx.z * karg.M * karg.N;
c_reduce_offset = k_id * karg.M * karg.N;
}
else
{
@@ -465,23 +512,32 @@ struct GridwiseGemm_wmma_cshuffle_v3
__device__ static index_t GetKBlockPerScale() { return 1; }
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const Problem& problem)
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
@@ -491,8 +547,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
@@ -508,17 +564,23 @@ struct GridwiseGemm_wmma_cshuffle_v3
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
EGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_b_grid,
p_c_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
@@ -528,17 +590,21 @@ struct GridwiseGemm_wmma_cshuffle_v3
// Wrapper function to have __global__ function in common
// between gemm_universal, b_scale, ab_scale, etc.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg)
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
{
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op);
}
};

View File

@@ -20,15 +20,17 @@ namespace ck {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockN, // scale N
@@ -60,11 +62,11 @@ template <typename ALayout,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
@@ -72,15 +74,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -110,8 +114,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -124,15 +128,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -162,8 +168,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -198,17 +204,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Base::CalculateNPadded;
using Base::MakeAGridDescriptor_AK0_M_AK1;
using Base::MakeBGridDescriptor_BK0_N_BK1;
using Base::MakeCGridDescriptor_M_N;
using Base::MakeDEGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
using Base::NumDTensor;
using typename Base::DsGridPointer;
struct Problem
{
__host__ Problem(index_t M_,
@@ -216,7 +227,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},
@@ -224,7 +236,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
StrideDs{StrideDs_},
StrideE{StrideE_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
@@ -241,11 +254,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__host__ void Print() const
{
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
<< ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded
<< ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
if constexpr(NumDTensor > 0)
{
std::cout << "SDs: { ";
static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
});
std::cout << " }, ";
}
std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", "
<< "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead
<< ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0
<< ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}"
<< std::endl;
}
index_t M;
@@ -253,7 +275,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;
@@ -271,30 +294,38 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleB_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_},
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, StrideScaleB_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
p_ds_grid{},
p_e_grid{p_e_grid_},
p_b_scale_grid{p_b_scale_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
c_element_op{c_element_op_},
cde_element_op{cde_element_op_},
is_reduce(is_reduce_)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
});
}
__host__ __device__ inline bool IsReduceAdd() const
@@ -309,57 +340,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const BScaleType* p_b_scale_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CElementwiseOperation c_element_op;
const CDEElementwiseOperation cde_element_op;
bool is_reduce;
};
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
a_k_split_offset = k_id * karg.KRead / APackedSize;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
b_k_split_offset = k_id * karg.KRead / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
b_k_split_offset = k_id * k0_offset / BPackedSize;
}
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
}
@@ -370,7 +402,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
if(karg.IsReduceAdd())
{
c_reduce_offset = blockIdx.z * karg.M * karg.N;
c_reduce_offset = k_id * karg.M * karg.N;
}
else
{
@@ -454,24 +486,33 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
const BScaleType* p_b_scale_grid,
void* p_shared,
const Problem& problem)
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
// B Scale grid
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
@@ -487,8 +528,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
@@ -503,17 +544,23 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
EGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_b_grid,
p_c_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
@@ -523,18 +570,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
// NOTE: Wrapper function to have __global__ function in common
// between gemm_universal, b_scale, ab_scale, etc.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg)
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
{
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_shared,
karg);
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op);
}
};

View File

@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -19,7 +19,7 @@ namespace ck {
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
@@ -31,17 +31,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<c_data_type, ck::half_t> ||
std::is_same_v<c_data_type, ck::bhalf_t>)))
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<e_data_type, ck::half_t> ||
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg);
#if defined(__gfx11__)
@@ -54,15 +54,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
template <typename ALayout,
typename BLayout,
typename CLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
@@ -92,8 +94,8 @@ template <typename ALayout,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
@@ -112,6 +114,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto EShuffleBlockTransferScalarPerVector =
CDEShuffleBlockTransferScalarPerVectors{}[I0];
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
@@ -430,17 +435,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return MakeWmmaTileDescriptor<NRepeat, NWaves, NPerWmma>(BBlockDesc_BK0_N_BK1{});
}
template <typename DELayout>
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideDE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideDE));
}
}();
@@ -493,6 +499,44 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
#endif
}
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
using DsGridPointer = decltype(MakeDsGridPointer());
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeDEGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
},
Number<NumDTensor>{});
}
template <typename DsGridDesc>
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
return generate_tuple(
[&](auto i) {
return MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n[i], MBlock, NBlock);
},
Number<NumDTensor>{});
}
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
@@ -805,18 +849,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
NRepeat,
KPack>())>;
template <typename CGridDesc>
__host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
template <typename DEGridDesc>
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
de_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
return de_grid_desc_mblock_mperblock_nblock_nperblock;
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
@@ -950,56 +994,51 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
if(karg.N % EShuffleBlockTransferScalarPerVector != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
"EShuffleBlockTransferScalarPerVector ("
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
else
{
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
if(karg.M % EShuffleBlockTransferScalarPerVector != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
"EShuffleBlockTransferScalarPerVector ("
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
is_same<remove_cvref_t<CDataType>, float>::value ||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
if constexpr(!(is_same<remove_cvref_t<EDataType>, half_t>::value ||
is_same<remove_cvref_t<EDataType>, float>::value ||
is_same<remove_cvref_t<EDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<EDataType>, int32_t>::value))
{
if(!karg.IsReduceAdd())
if(karg.IsAtomicAdd() && karg.KBatch > 1)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
if(karg.KBatch > 1)
{
return false;
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported for this "
<< "destination type (EDataType) " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
}
return false;
}
}
@@ -1062,19 +1101,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename BScaleStruct,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DsGridPointer p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
const index_t& block_m_id,
const index_t& block_n_id,
const index_t& num_k_block_per_scale,
@@ -1084,12 +1130,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
@@ -1330,31 +1379,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
m_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor buffers
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
Number<NumDTensor>{}));
// blockwise copy which loads C from LDS, D from global, applies elementwise
// operation and stores result E to global
auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
ThisThreadBlock, // ThreadGroup
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation, // ElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
3, // SrcVectorDim,
3, // DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
cde_element_op};
// space filling curve for local reg & global memory
// space filling curve for threadwise C in VGPR
@@ -1370,7 +1446,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
MAccVgprs>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
constexpr auto sfc_cde_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
@@ -1380,7 +1456,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
@@ -1397,20 +1473,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
// each block loads its C data from LDS, D from global, applies elementwise
// operation and stores result E to global
cde_shuffle_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_global_step);
});
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on E
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
}
});
}