mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Wmma support for multiple Ds based GEMMs (#2613)
* Fixed cmake errors related to gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8" * Fixed cmake build errors related to test_fp8 * Updates to support mixed precision (cherry picked from commit e65d71180393e7b66169c56565a6bac740427de6) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Adding support for RRR, F8xF16xF16 gemm_universal_wmma - wip (cherry picked from commit f8c06322df0abcbd5945a56cdf5bffe56480f9f0) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Added support for F8xF16xF16 to gemm_wmma_universal (cherry picked from commit 15c851de6daa513a12c2e3af299bab0176175fb5) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Added support for F16xF8xF16 to gemm_wmma_universal * Added support for BF16xI4xBF16 to gemm_wmma_universal (cherry picked from commit c6a4a69d2d43d59bae8bdabfae80d648646f217e) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Added support for F16xI4xF16 to gemm_wmma_universal * Fixed IsSupportedArgument to check ComputeTypeA, ComputeTypeB instead of ADataType, BDataType * Added missing test class for FP16_KM_NK * Pre-commit hooks fixes * Added padding instances for f16xf16xf16 * Fixed cmake errors related to gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8" (cherry picked from commit5bdc993dbf) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Fixed cmake build errors related to test_fp8 (cherry picked from commit12176616b6) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Ammending changes for adding support for padding instances for f16xf16xf16 * Fixes for padding instances for f16xf16xf16 * Added padding instances for bf16xbf16, f8xf8 * Added packed instances for bf16xi4xbf16 * Added padding instances for f8xf16xf16 * Added padding instances for f16xf8xf16, f16xi4xf16 * Fixed typos for bf16xbf16xbf16 padding instances * Fixed typos for padded instances * Added tests for fp16, KM_KN and KM_NK * Padding not supported for when BDataType is pk_i4_t. Added fix for correct check and removed padding instances. * Fixed typos * Updated the set of tests for FP16 * Updated the set of tests for FP16 * Fix typo * Moved f16xi4 test under the correct data layout group * example for gemm_universal_bf16 * Adding examples for gemm_wmma instances * Added the missing parameters * Fixed review comments and added executable to cmakeLists * Fixing clang format * Fixing build erros * Fixed compilation failure. * Modified some code as per gemm_universal_examples * Fixed the gemm specialization error * Fixed the build errors. * Fix strides of a/b_thread_desc The descriptors are larger than needed (even though the compiler don't alloc registers for unused values). * Load in M/NRepeat dims with thread copy's slice instead of a loop * Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation * Implement Intrawave and Interwave variants of pipeline v1 * Add instances for Interwave and Intrawave v1 * Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0 * Remove instances that are too slow (mostly because of register spilling) * Add a workaround for fp8/bf8->f32 packed conversion issue * Add instances for Interwave and Intrawave v1 * Enable profiling of mixed precision with f8 and int4 on WMMA * Fix segfault in profiler when B is pk_i4_t b_device_buf's size in bytes is larger than b_k_n_permute so b_device_buf.ToDevice reads out-of-bounds. * Remove instances that are too slow (mostly because of register spilling) * Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations * Add test case for bf16_i4 * Add missing Regular tests * Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS They take more than 30 seconds * Fix a bug that fp16_i4 validation passes only with PermuteB A permutation required by conversion from pk_i4_t to half_t does not depend on PermuteB, they can be used independently. * Use PermuteB with f16_i4 in most instances (as xdl) Some instances use PermuteB = false for checking correctness. See also the previous commit. * Fix cache flushing for pk_i4 * Add mixed precision examples * Disable all tests and instances with f8 on gfx11 Even though f8_f16 and f16_f8 don't require f8 WMMA instructions, gfx11 still lacks hardware instructions for fast f8->f32 conversion. * Add FP16 KM_NK and KM_KN test suites for XDL These tests were added to common .inc for better testing of WMMA instances * Support multiple D in GridwiseGemm_wmma_cshuffle_v3 DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters. * Use ThreadGroupTensorSliceTransfer_v7r3 * Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support * Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma * Implement DeviceGemmMultipleD_Wmma_CShuffleV3 * Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3 * Prepare gemma_add tests for adding wmma * Add gemm_add_fastgelu instances and test * Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use DeviceGemmMultipleDSplitK instances there. * removed unnecessary ck parts from compilation * initial gemm_add_multiply instance implementations * fixed profiler help message for gemm_add_multiply * improved multiply_add profiler layout help * fixed template arguments for test instances * added test for gemm_add_multiply * Support multiple D in GridwiseGemm_wmma_cshuffle_v3 DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters. * Use ThreadGroupTensorSliceTransfer_v7r3 * Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support * Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma * Implement DeviceGemmMultipleD_Wmma_CShuffleV3 * Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3 * Prepare gemma_add tests for adding wmma * Add gemm_add_fastgelu instances and test * Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use DeviceGemmMultipleDSplitK instances there. * switched to splitK interface * log print added to splitk benchmarks * revert main cmake comments * newline change reverted * added add_fastgelu instances * revert unintended change in xdl add_fastgelu * created gemm_add_add_fastgelu instances * created fastegelu instances * added tests for all splitk fastgelus * Added tests. * multiply_add instances created * updates to add_multiply splitk instances * splitk xdl test fixes * added wmma multiply_multiply instances * fixed ONLY_XDL_AND_WMMA_KERNELS tag * Added gemm_add examples for wmma v1 and v3 * fixed / workarounded i8 instances * Modified the v3 code to added one fp16 bxdl instance. * added bf16 xdl instance. * adding gemm_add wmma_cshuffle and other support (cherry picked from commit ec447e7f564095ea969eddc39ec77b843aa52976) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * add instances into camkelists (cherry picked from commit 23bf2d2771c939ea3ca7f493433c55255bffd08e) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * This is work in progress, edited the template parameters in order to build (cherry picked from commit b4fde8a3314cb44659c4bbda35f1a0133c63dc41) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * temp work saved, changed the BDataType to f16 or bf16 since wmma currently not support non-equal A and B datatype (cherry picked from commit 22fbd68f1db458ab50780a394ee2544c7a1484d1) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * added datatype and use clang-format-12 (cherry picked from commit ae4e853682ef1bb27784b2f965b4a66b3751ceec) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * Fixing build errors * Added instances for v3 * Adding instances and executables * Code update of template parameters modified. * Renamed file. * Added tests. * resolved error tests. * Fixing build errors * Updated comments * removed the changes as per the MR review comment. * Updated tests. * fp8 instances - not tested * Restored the Cmake file that was reverted by mistake during rebase. * fixed wmma_op test * Updated comments. * Updated the template parameter description * fixed rdna4 instances * fixed back compatibility on gfx11 * cleanups * fix ckProfiler * one more cmake fix * added fp8 instances * Updated tests to ad BF16 instances as per review comment * Added include file and cleaned up(as per review comment) * Updated and optimized the example code for all types. * Fixed clang format * Resolve "Implement `device_gemm_bilinear` for RDNA4" * test generalization to handle FP16 shuffle better * added missing changes * Added bf16 wmma instance for add_relu * Added f16 wmma instance and corrected bf16 instance errors. * Added instances to Cmake * Modified the template parameters to make the instances work. * Fixed typo in profiler * Added v3 instances for gemm_add_relu * addressed core review comments * Added test for gemm_add_relu wmma instance * Cleaned up the code. * Added examples for gemm_add_relu * Fixing typo to resolve build errors. * Fixes applied to fix the precision loss. * fix billinear test after merge * Removed the old wmma instances. * Added wrapper and renamed the wmma_v3 instances * Updated copyrights and added wrappers. * Fixes applied according to review comments * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Robin Voetter <robin@streamhpc.com> * Removed the old wmma instances. * Updated wrapper for the v3 instances * removed the old wmma examples * Renamed the v3 instances * Deleted the gtest file added by mistake. * Updated thge profiler with wrapper * Fixed test errors. * Fixed the review comments * Fixed the if condition MACROS. * REVERTED THE PROFILER CHANGES * Revert "REVERTED THE PROFILER CHANGES" This reverts commit21cb98546c. * Revert "Fixed test errors." This reverts commit13efcc6fe1. * Revert "Updated thge profiler with wrapper" This reverts commit536f86661d. * Added missing wrapper instances * Updated copyrights. * Fixed typo. * Fixed copyrights. * Updated copyrights. * updated copyrights. * comments on the atomics workaround * fixed cmake comment * Fix bug from merge * clang-format-18 * Fix compilation error * Fix linking error * Fix bug in add and add_relu examples * Fix error including file (typo) * Quick fix to compile examples for different targets * Fix for multi target * implemented f16 and bf16 instances for gemm_silu * addressed review comments * addressed review comments * Fix clang format * Fix clang format --------- Co-authored-by: Anca Hamuraru <anca@streamhpc.com> Co-authored-by: apoorva <apoorva@streamhpc.com> Co-authored-by: Anton Gorenko <anton@streamhpc.com> Co-authored-by: Zoltan Lakatos <zoltan.lakatos@streamhpc.com> Co-authored-by: Cenxuan <cenxuan@streamhpc.com> Co-authored-by: Robin Voetter <robin@streamhpc.com> Co-authored-by: Kiefer van Teutem <kiefer.van.teutem@streamhpc.com> Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
@@ -22,9 +22,10 @@ namespace ck {
|
||||
///
|
||||
/// @par Overview
|
||||
/// This GEMM kernel is carrying out following mathematical equation:
|
||||
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
|
||||
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
|
||||
/// elementwise operations that could be applied on each tensor respectively.
|
||||
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
|
||||
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
|
||||
// operations that could be applied on each tensor respectively. The CDE_op is an
|
||||
// elementwise operation applied to the C and all D tensors.
|
||||
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
|
||||
/// scenarios. That's why it's called \"universal\". It's universal through it's design
|
||||
/// and versatilty.
|
||||
@@ -36,18 +37,20 @@ namespace ck {
|
||||
///
|
||||
/// @tparam ALayout A tensor data layout.
|
||||
/// @tparam BLayout B tensor data layout.
|
||||
/// @tparam CLayout C tensor data layout.
|
||||
/// @tparam DsLayout D tensors data layouts.
|
||||
/// @tparam ELayout E tensor data layout.
|
||||
/// @tparam ADataType A tensor data type.
|
||||
/// @tparam BDataType B tensor data type.
|
||||
/// @tparam AccDataType The accumulation data type related to the hardware
|
||||
/// matrix-multiplication instruction.
|
||||
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
|
||||
/// LDS memory during \"CShuffle\" data layout optimization.
|
||||
/// @tparam CDataType C tensor data type.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
|
||||
/// (after GEMM).
|
||||
/// @tparam DsDataType D tensors data types.
|
||||
/// @tparam EDataType E tensor data type.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
|
||||
/// GEMM) and D input tensors.
|
||||
/// @tparam GemmSpec Determines used "padding" version.
|
||||
/// @tparam BlockSize The number of threads within workgroup.
|
||||
/// @tparam MPerBlock The input/output data tile size in the M dimension.
|
||||
@@ -105,11 +108,12 @@ namespace ck {
|
||||
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in N dimension.
|
||||
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// thread distribution used for storing data into output
|
||||
/// tensor across output data layout dimensions.
|
||||
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
|
||||
/// Used when storing data to output tensor.
|
||||
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
|
||||
/// Used when loading data from D tensors and storing data
|
||||
/// to output tensor.
|
||||
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
|
||||
/// intrawave).
|
||||
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
|
||||
@@ -123,15 +127,17 @@ namespace ck {
|
||||
/// in global memory (pre-shuffled).
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename CDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
@@ -161,8 +167,8 @@ template <typename ALayout,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA,
|
||||
@@ -173,15 +179,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
: GridwiseGemm_wmma_cshuffle_v3_base<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -211,8 +219,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -223,15 +231,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -261,8 +271,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -297,17 +307,22 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
using Base::CalculateNPadded;
|
||||
using Base::MakeAGridDescriptor_AK0_M_AK1;
|
||||
using Base::MakeBGridDescriptor_BK0_N_BK1;
|
||||
using Base::MakeCGridDescriptor_M_N;
|
||||
using Base::MakeDEGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
|
||||
|
||||
using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
|
||||
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
|
||||
|
||||
using Base::NumDTensor;
|
||||
using typename Base::DsGridPointer;
|
||||
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem(index_t M_,
|
||||
@@ -315,14 +330,16 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
StrideA{StrideA_},
|
||||
StrideB{StrideB_},
|
||||
StrideC{StrideC_},
|
||||
StrideDs{StrideDs_},
|
||||
StrideE{StrideE_},
|
||||
KBatch{KBatch_},
|
||||
MPadded{CalculateMPadded(M_)},
|
||||
NPadded{CalculateNPadded(N_)},
|
||||
@@ -338,11 +355,19 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
__host__ void Print() const
|
||||
{
|
||||
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
|
||||
<< ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
|
||||
<< "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
|
||||
<< ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
|
||||
<< "NBlock: " << NBlock << "}" << std::endl;
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
std::cout << "SDs: { ";
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
|
||||
});
|
||||
std::cout << " }, ";
|
||||
}
|
||||
std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
|
||||
<< ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
|
||||
<< "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
|
||||
<< ", " << "NBlock: " << NBlock << "}" << std::endl;
|
||||
}
|
||||
|
||||
index_t M;
|
||||
@@ -350,7 +375,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
index_t K;
|
||||
index_t StrideA;
|
||||
index_t StrideB;
|
||||
index_t StrideC;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
index_t KBatch;
|
||||
index_t MPadded;
|
||||
index_t NPadded;
|
||||
@@ -367,21 +393,35 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
{
|
||||
__host__ Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
EDataType* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
p_ds_grid{},
|
||||
p_e_grid{p_e_grid_},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
cde_element_op{cde_element_op_},
|
||||
is_reduce(is_reduce_)
|
||||
{
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
|
||||
});
|
||||
}
|
||||
|
||||
__host__ __device__ inline bool IsReduceAdd() const
|
||||
@@ -396,42 +436,49 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
const ADataType* p_a_grid;
|
||||
const BDataType* p_b_grid;
|
||||
CDataType* p_c_grid;
|
||||
DsGridPointer p_ds_grid;
|
||||
EDataType* p_e_grid;
|
||||
|
||||
const AElementwiseOperation a_element_op;
|
||||
const BElementwiseOperation b_element_op;
|
||||
const CDEElementwiseOperation cde_element_op;
|
||||
|
||||
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
|
||||
bool is_reduce;
|
||||
};
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
|
||||
__device__ SplitKBatchOffset(Argument& karg)
|
||||
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
|
||||
{
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
|
||||
b_k_split_offset = k_id * karg.KRead / BPackedSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
|
||||
b_k_split_offset = k_id * k0_offset / BPackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
|
||||
if(k_id < karg.KBatch - 1)
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
}
|
||||
@@ -442,7 +489,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
if(karg.IsReduceAdd())
|
||||
{
|
||||
c_reduce_offset = blockIdx.z * karg.M * karg.N;
|
||||
c_reduce_offset = k_id * karg.M * karg.N;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -465,23 +512,32 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
__device__ static index_t GetKBlockPerScale() { return 1; }
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem)
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
@@ -491,8 +547,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
@@ -508,17 +564,23 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(b_scale_struct),
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
@@ -528,17 +590,21 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
// Wrapper function to have __global__ function in common
|
||||
// between gemm_universal, b_scale, ab_scale, etc.
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg)
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
|
||||
{
|
||||
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared,
|
||||
karg);
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -20,15 +20,17 @@ namespace ck {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename CDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockN, // scale N
|
||||
@@ -60,11 +62,11 @@ template <typename ALayout,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
@@ -72,15 +74,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
: GridwiseGemm_wmma_cshuffle_v3_base<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -110,8 +114,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -124,15 +128,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -162,8 +168,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -198,17 +204,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
using Base::CalculateNPadded;
|
||||
using Base::MakeAGridDescriptor_AK0_M_AK1;
|
||||
using Base::MakeBGridDescriptor_BK0_N_BK1;
|
||||
using Base::MakeCGridDescriptor_M_N;
|
||||
using Base::MakeDEGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
|
||||
|
||||
using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
|
||||
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
|
||||
|
||||
using Base::NumDTensor;
|
||||
using typename Base::DsGridPointer;
|
||||
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem(index_t M_,
|
||||
@@ -216,7 +227,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t StrideScaleB_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
@@ -224,7 +236,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
K{K_},
|
||||
StrideA{StrideA_},
|
||||
StrideB{StrideB_},
|
||||
StrideC{StrideC_},
|
||||
StrideDs{StrideDs_},
|
||||
StrideE{StrideE_},
|
||||
StrideScaleB{StrideScaleB_},
|
||||
KBatch{KBatch_},
|
||||
MPadded{CalculateMPadded(M_)},
|
||||
@@ -241,11 +254,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
__host__ void Print() const
|
||||
{
|
||||
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
|
||||
<< ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", "
|
||||
<< "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded
|
||||
<< ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
|
||||
<< "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
std::cout << "SDs: { ";
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
|
||||
});
|
||||
std::cout << " }, ";
|
||||
}
|
||||
std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", "
|
||||
<< "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead
|
||||
<< ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0
|
||||
<< ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
index_t M;
|
||||
@@ -253,7 +275,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
index_t K;
|
||||
index_t StrideA;
|
||||
index_t StrideB;
|
||||
index_t StrideC;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
index_t StrideScaleB;
|
||||
index_t KBatch;
|
||||
index_t MPadded;
|
||||
@@ -271,30 +294,38 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
{
|
||||
__host__ Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
EDataType* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t StrideScaleB_,
|
||||
const BScaleType* p_b_scale_grid_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CElementwiseOperation c_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_},
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, StrideScaleB_, k_batch_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
p_ds_grid{},
|
||||
p_e_grid{p_e_grid_},
|
||||
p_b_scale_grid{p_b_scale_grid_},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
c_element_op{c_element_op_},
|
||||
cde_element_op{cde_element_op_},
|
||||
is_reduce(is_reduce_)
|
||||
{
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
|
||||
});
|
||||
}
|
||||
|
||||
__host__ __device__ inline bool IsReduceAdd() const
|
||||
@@ -309,57 +340,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
const ADataType* p_a_grid;
|
||||
const BDataType* p_b_grid;
|
||||
CDataType* p_c_grid;
|
||||
DsGridPointer p_ds_grid;
|
||||
EDataType* p_e_grid;
|
||||
|
||||
const BScaleType* p_b_scale_grid;
|
||||
const AElementwiseOperation a_element_op;
|
||||
const BElementwiseOperation b_element_op;
|
||||
const CElementwiseOperation c_element_op;
|
||||
const CDEElementwiseOperation cde_element_op;
|
||||
bool is_reduce;
|
||||
};
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
|
||||
__device__ SplitKBatchOffset(Argument& karg)
|
||||
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
|
||||
{
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
|
||||
b_k_split_offset = k_id * karg.KRead / BPackedSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
|
||||
b_k_split_offset = k_id * k0_offset / BPackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate B scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
|
||||
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
|
||||
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
|
||||
}
|
||||
|
||||
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
|
||||
if(k_id < karg.KBatch - 1)
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
}
|
||||
@@ -370,7 +402,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
if(karg.IsReduceAdd())
|
||||
{
|
||||
c_reduce_offset = blockIdx.z * karg.M * karg.N;
|
||||
c_reduce_offset = k_id * karg.M * karg.N;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -454,24 +486,33 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
const BScaleType* p_b_scale_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem)
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
// B Scale grid
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
@@ -487,8 +528,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
@@ -503,17 +544,23 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(b_scale_struct),
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
@@ -523,18 +570,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
// NOTE: Wrapper function to have __global__ function in common
|
||||
// between gemm_universal, b_scale, ab_scale, etc.
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg)
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
|
||||
{
|
||||
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
|
||||
p_shared,
|
||||
karg);
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
@@ -19,7 +19,7 @@ namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
@@ -31,17 +31,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
|
||||
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<c_data_type, ck::half_t> ||
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<e_data_type, ck::half_t> ||
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
@@ -54,15 +54,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename CDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
@@ -92,8 +94,8 @@ template <typename ALayout,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA,
|
||||
@@ -112,6 +114,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
static constexpr auto EShuffleBlockTransferScalarPerVector =
|
||||
CDEShuffleBlockTransferScalarPerVectors{}[I0];
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
|
||||
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
|
||||
@@ -430,17 +435,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return MakeWmmaTileDescriptor<NRepeat, NWaves, NPerWmma>(BBlockDesc_BK0_N_BK1{});
|
||||
}
|
||||
|
||||
template <typename DELayout>
|
||||
__host__ __device__ static auto
|
||||
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
|
||||
MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideDE, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideDE));
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -493,6 +499,44 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
#endif
|
||||
}
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto MakeDsGridPointer()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
return static_cast<const DDataType*>(nullptr);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
|
||||
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
return MakeDEGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
template <typename DsGridDesc>
|
||||
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n[i], MBlock, NBlock);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
@@ -805,18 +849,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
NRepeat,
|
||||
KPack>())>;
|
||||
|
||||
template <typename CGridDesc>
|
||||
__host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
template <typename DEGridDesc>
|
||||
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
{
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
de_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
|
||||
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
return c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
return de_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
@@ -950,56 +994,51 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
|
||||
{
|
||||
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
if(karg.N % EShuffleBlockTransferScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of "
|
||||
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
"EShuffleBlockTransferScalarPerVector ("
|
||||
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
if(karg.M % EShuffleBlockTransferScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of "
|
||||
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
"EShuffleBlockTransferScalarPerVector ("
|
||||
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, float>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
|
||||
if constexpr(!(is_same<remove_cvref_t<EDataType>, half_t>::value ||
|
||||
is_same<remove_cvref_t<EDataType>, float>::value ||
|
||||
is_same<remove_cvref_t<EDataType>, bhalf_t>::value ||
|
||||
is_same<remove_cvref_t<EDataType>, int32_t>::value))
|
||||
{
|
||||
if(!karg.IsReduceAdd())
|
||||
if(karg.IsAtomicAdd() && karg.KBatch > 1)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet"
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
if(karg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported for this "
|
||||
<< "destination type (EDataType) " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1062,19 +1101,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename BScaleStruct,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DsGridPointer p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
const index_t& block_m_id,
|
||||
const index_t& block_n_id,
|
||||
const index_t& num_k_block_per_scale,
|
||||
@@ -1084,12 +1130,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ds_grid[i],
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
@@ -1330,31 +1379,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
m_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
// shuffle: blockwise copy C from LDS to global
|
||||
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of reference to C/Ds tensor buffers
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_buf),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of starting index of C/Ds blockwise copy
|
||||
const auto idx_c_ds_block_begin = container_concat(
|
||||
make_tuple(make_multi_index(0, 0, 0, 0)),
|
||||
generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// blockwise copy which loads C from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CDEElementwiseOperation, // ElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
CShuffleDataType, // typename SrcData,
|
||||
CDataType, // typename DstData,
|
||||
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
|
||||
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
false> // bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_m_id, 0, block_n_id, 0),
|
||||
c_element_op};
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
|
||||
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
|
||||
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
|
||||
3, // SrcVectorDim,
|
||||
3, // DstVectorDim,
|
||||
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
|
||||
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
|
||||
cde_element_op};
|
||||
|
||||
// space filling curve for local reg & global memory
|
||||
// space filling curve for threadwise C in VGPR
|
||||
@@ -1370,7 +1446,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
MAccVgprs>>{};
|
||||
|
||||
// space filling curve for shuffled blockwise C in global mem
|
||||
constexpr auto sfc_c_global =
|
||||
constexpr auto sfc_cde_global =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
@@ -1380,7 +1456,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
|
||||
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
@@ -1397,20 +1473,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global.Run(
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
// each block loads its C data from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
cde_shuffle_block_copy_lds_and_global.Run(
|
||||
c_ds_desc_refs,
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(e_grid_buf));
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
|
||||
// move on Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
|
||||
c_ds_desc_refs, i + I1, cde_global_step);
|
||||
});
|
||||
|
||||
// move on C
|
||||
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
// move on E
|
||||
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user