mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Wmma support for multiple Ds based GEMMs (#2613)
* Fixed cmake errors related to gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8" * Fixed cmake build errors related to test_fp8 * Updates to support mixed precision (cherry picked from commit e65d71180393e7b66169c56565a6bac740427de6) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Adding support for RRR, F8xF16xF16 gemm_universal_wmma - wip (cherry picked from commit f8c06322df0abcbd5945a56cdf5bffe56480f9f0) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Added support for F8xF16xF16 to gemm_wmma_universal (cherry picked from commit 15c851de6daa513a12c2e3af299bab0176175fb5) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Added support for F16xF8xF16 to gemm_wmma_universal * Added support for BF16xI4xBF16 to gemm_wmma_universal (cherry picked from commit c6a4a69d2d43d59bae8bdabfae80d648646f217e) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Added support for F16xI4xF16 to gemm_wmma_universal * Fixed IsSupportedArgument to check ComputeTypeA, ComputeTypeB instead of ADataType, BDataType * Added missing test class for FP16_KM_NK * Pre-commit hooks fixes * Added padding instances for f16xf16xf16 * Fixed cmake errors related to gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8" (cherry picked from commit5bdc993dbf) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Fixed cmake build errors related to test_fp8 (cherry picked from commit12176616b6) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Ammending changes for adding support for padding instances for f16xf16xf16 * Fixes for padding instances for f16xf16xf16 * Added padding instances for bf16xbf16, f8xf8 * Added packed instances for bf16xi4xbf16 * Added padding instances for f8xf16xf16 * Added padding instances for f16xf8xf16, f16xi4xf16 * Fixed typos for bf16xbf16xbf16 padding instances * Fixed typos for padded instances * Added tests for fp16, KM_KN and KM_NK * Padding not supported for when BDataType is pk_i4_t. Added fix for correct check and removed padding instances. * Fixed typos * Updated the set of tests for FP16 * Updated the set of tests for FP16 * Fix typo * Moved f16xi4 test under the correct data layout group * example for gemm_universal_bf16 * Adding examples for gemm_wmma instances * Added the missing parameters * Fixed review comments and added executable to cmakeLists * Fixing clang format * Fixing build erros * Fixed compilation failure. * Modified some code as per gemm_universal_examples * Fixed the gemm specialization error * Fixed the build errors. * Fix strides of a/b_thread_desc The descriptors are larger than needed (even though the compiler don't alloc registers for unused values). * Load in M/NRepeat dims with thread copy's slice instead of a loop * Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation * Implement Intrawave and Interwave variants of pipeline v1 * Add instances for Interwave and Intrawave v1 * Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0 * Remove instances that are too slow (mostly because of register spilling) * Add a workaround for fp8/bf8->f32 packed conversion issue * Add instances for Interwave and Intrawave v1 * Enable profiling of mixed precision with f8 and int4 on WMMA * Fix segfault in profiler when B is pk_i4_t b_device_buf's size in bytes is larger than b_k_n_permute so b_device_buf.ToDevice reads out-of-bounds. * Remove instances that are too slow (mostly because of register spilling) * Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations * Add test case for bf16_i4 * Add missing Regular tests * Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS They take more than 30 seconds * Fix a bug that fp16_i4 validation passes only with PermuteB A permutation required by conversion from pk_i4_t to half_t does not depend on PermuteB, they can be used independently. * Use PermuteB with f16_i4 in most instances (as xdl) Some instances use PermuteB = false for checking correctness. See also the previous commit. * Fix cache flushing for pk_i4 * Add mixed precision examples * Disable all tests and instances with f8 on gfx11 Even though f8_f16 and f16_f8 don't require f8 WMMA instructions, gfx11 still lacks hardware instructions for fast f8->f32 conversion. * Add FP16 KM_NK and KM_KN test suites for XDL These tests were added to common .inc for better testing of WMMA instances * Support multiple D in GridwiseGemm_wmma_cshuffle_v3 DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters. * Use ThreadGroupTensorSliceTransfer_v7r3 * Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support * Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma * Implement DeviceGemmMultipleD_Wmma_CShuffleV3 * Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3 * Prepare gemma_add tests for adding wmma * Add gemm_add_fastgelu instances and test * Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use DeviceGemmMultipleDSplitK instances there. * removed unnecessary ck parts from compilation * initial gemm_add_multiply instance implementations * fixed profiler help message for gemm_add_multiply * improved multiply_add profiler layout help * fixed template arguments for test instances * added test for gemm_add_multiply * Support multiple D in GridwiseGemm_wmma_cshuffle_v3 DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters. * Use ThreadGroupTensorSliceTransfer_v7r3 * Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support * Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma * Implement DeviceGemmMultipleD_Wmma_CShuffleV3 * Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3 * Prepare gemma_add tests for adding wmma * Add gemm_add_fastgelu instances and test * Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use DeviceGemmMultipleDSplitK instances there. * switched to splitK interface * log print added to splitk benchmarks * revert main cmake comments * newline change reverted * added add_fastgelu instances * revert unintended change in xdl add_fastgelu * created gemm_add_add_fastgelu instances * created fastegelu instances * added tests for all splitk fastgelus * Added tests. * multiply_add instances created * updates to add_multiply splitk instances * splitk xdl test fixes * added wmma multiply_multiply instances * fixed ONLY_XDL_AND_WMMA_KERNELS tag * Added gemm_add examples for wmma v1 and v3 * fixed / workarounded i8 instances * Modified the v3 code to added one fp16 bxdl instance. * added bf16 xdl instance. * adding gemm_add wmma_cshuffle and other support (cherry picked from commit ec447e7f564095ea969eddc39ec77b843aa52976) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * add instances into camkelists (cherry picked from commit 23bf2d2771c939ea3ca7f493433c55255bffd08e) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * This is work in progress, edited the template parameters in order to build (cherry picked from commit b4fde8a3314cb44659c4bbda35f1a0133c63dc41) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * temp work saved, changed the BDataType to f16 or bf16 since wmma currently not support non-equal A and B datatype (cherry picked from commit 22fbd68f1db458ab50780a394ee2544c7a1484d1) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * added datatype and use clang-format-12 (cherry picked from commit ae4e853682ef1bb27784b2f965b4a66b3751ceec) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * Fixing build errors * Added instances for v3 * Adding instances and executables * Code update of template parameters modified. * Renamed file. * Added tests. * resolved error tests. * Fixing build errors * Updated comments * removed the changes as per the MR review comment. * Updated tests. * fp8 instances - not tested * Restored the Cmake file that was reverted by mistake during rebase. * fixed wmma_op test * Updated comments. * Updated the template parameter description * fixed rdna4 instances * fixed back compatibility on gfx11 * cleanups * fix ckProfiler * one more cmake fix * added fp8 instances * Updated tests to ad BF16 instances as per review comment * Added include file and cleaned up(as per review comment) * Updated and optimized the example code for all types. * Fixed clang format * Resolve "Implement `device_gemm_bilinear` for RDNA4" * test generalization to handle FP16 shuffle better * added missing changes * Added bf16 wmma instance for add_relu * Added f16 wmma instance and corrected bf16 instance errors. * Added instances to Cmake * Modified the template parameters to make the instances work. * Fixed typo in profiler * Added v3 instances for gemm_add_relu * addressed core review comments * Added test for gemm_add_relu wmma instance * Cleaned up the code. * Added examples for gemm_add_relu * Fixing typo to resolve build errors. * Fixes applied to fix the precision loss. * fix billinear test after merge * Removed the old wmma instances. * Added wrapper and renamed the wmma_v3 instances * Updated copyrights and added wrappers. * Fixes applied according to review comments * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Robin Voetter <robin@streamhpc.com> * Removed the old wmma instances. * Updated wrapper for the v3 instances * removed the old wmma examples * Renamed the v3 instances * Deleted the gtest file added by mistake. * Updated thge profiler with wrapper * Fixed test errors. * Fixed the review comments * Fixed the if condition MACROS. * REVERTED THE PROFILER CHANGES * Revert "REVERTED THE PROFILER CHANGES" This reverts commit21cb98546c. * Revert "Fixed test errors." This reverts commit13efcc6fe1. * Revert "Updated thge profiler with wrapper" This reverts commit536f86661d. * Added missing wrapper instances * Updated copyrights. * Fixed typo. * Fixed copyrights. * Updated copyrights. * updated copyrights. * comments on the atomics workaround * fixed cmake comment * Fix bug from merge * clang-format-18 * Fix compilation error * Fix linking error * Fix bug in add and add_relu examples * Fix error including file (typo) * Quick fix to compile examples for different targets * Fix for multi target * implemented f16 and bf16 instances for gemm_silu * addressed review comments * addressed review comments * Fix clang format * Fix clang format --------- Co-authored-by: Anca Hamuraru <anca@streamhpc.com> Co-authored-by: apoorva <apoorva@streamhpc.com> Co-authored-by: Anton Gorenko <anton@streamhpc.com> Co-authored-by: Zoltan Lakatos <zoltan.lakatos@streamhpc.com> Co-authored-by: Cenxuan <cenxuan@streamhpc.com> Co-authored-by: Robin Voetter <robin@streamhpc.com> Co-authored-by: Kiefer van Teutem <kiefer.van.teutem@streamhpc.com> Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> [ROCm/composable_kernel commit:b740380906]
This commit is contained in:
@@ -149,50 +149,105 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
|
||||
#endif
|
||||
};
|
||||
|
||||
/// @brief Wrapper for backward compatibility that allows to use instances of
|
||||
/// DeviceGemmMultipleDSplitK in contexts where DeviceGemmMultipleD is expected.
|
||||
///
|
||||
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
|
||||
/// The only difference between API of DeviceGemmMultipleD and DeviceGemmMultipleDSplitK is
|
||||
/// that DeviceGemmMultipleDSplitK::MakeArgumentPointer requires an additional parameter
|
||||
/// KBatch which is explicitly passed as 1 by this wrapper.
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t ScaleBlockSize,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceMoEGemmMXBPreShuffle : public BaseOperator
|
||||
struct DeviceGemmMultipleDSplitKWrapper : public DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
#ifndef __HIPCC_RTC__
|
||||
|
||||
explicit DeviceGemmMultipleDSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
|
||||
: p_op_(std::move(p_op))
|
||||
{
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return p_op_->IsSupportedArgument(p_arg);
|
||||
}
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideAScale,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideBScale,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
ck::index_t StrideE,
|
||||
ck::index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return p_op_->MakeArgumentPointer(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
1, // KBatch
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return p_op_->MakeInvokerPointer();
|
||||
}
|
||||
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
#endif
|
||||
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<DeviceOp> p_op_;
|
||||
|
||||
#endif // __HIPCC_RTC__
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -40,7 +40,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
|
||||
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<c_data_type, ck::half_t> ||
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
@@ -62,14 +62,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
|
||||
p_shared,
|
||||
karg);
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
@@ -272,11 +276,13 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
Tuple<>, // DsDataType
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
@@ -311,7 +317,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -336,17 +342,25 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
|
||||
index_t BatchStrideC_,
|
||||
index_t Batch_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CElementwiseOperation cde_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
: GridwiseGemm::Argument(p_a_grid_,
|
||||
p_b_grid_,
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
p_c_grid_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC_,
|
||||
k_batch_,
|
||||
a_element_op_,
|
||||
b_element_op_,
|
||||
cde_element_op_,
|
||||
is_reduce_),
|
||||
Batch(Batch_),
|
||||
compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_}
|
||||
@@ -443,7 +457,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
|
||||
// Note: This seems incorrect for non-contiguous memory layouts for C
|
||||
// (padding, gaps).
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemsetAsync(arg_.p_c_grid,
|
||||
hipMemsetAsync(arg_.p_e_grid,
|
||||
0,
|
||||
arg_.Batch * arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
@@ -469,7 +483,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
|
||||
// Note: This seems incorrect for non-contiguous memory layouts for C
|
||||
// (padding, gaps).
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemsetAsync(arg.p_c_grid,
|
||||
hipMemsetAsync(arg.p_e_grid,
|
||||
0,
|
||||
arg.Batch * arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
@@ -658,7 +672,10 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
Batch,
|
||||
1 /* KBatch */};
|
||||
1, /* KBatch */
|
||||
AElementwiseOperation{},
|
||||
BElementwiseOperation{},
|
||||
CElementwiseOperation{}};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -694,7 +711,10 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
Batch,
|
||||
1); // KBatch
|
||||
1,
|
||||
AElementwiseOperation{},
|
||||
BElementwiseOperation{},
|
||||
CElementwiseOperation{}); // KBatch
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -0,0 +1,410 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/// @brief \"Universal\" GEMM operation with SplitK support and multiple D tensors.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This GEMM operation implements the following mathematical equation:
|
||||
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
|
||||
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
|
||||
// operations that could be applied on each tensor respectively. The CDE_op is an
|
||||
// elementwise operation applied to the C and all D tensors.
|
||||
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
|
||||
/// scenarios. That's why it's called \"universal\". It's universal through it's design
|
||||
/// and versatilty.
|
||||
///
|
||||
/// @note This Kernel implementation supports SplitK algorithm. It can be configured
|
||||
/// to split the dot product accumulated over the K dimension into multiple working groups.
|
||||
/// The partial products of different workgroups are then reduced using the AtomicAdd
|
||||
/// operation.
|
||||
///
|
||||
/// @tparam ALayout A tensor data layout.
|
||||
/// @tparam BLayout B tensor data layout.
|
||||
/// @tparam DsLayout D tensors data layouts.
|
||||
/// @tparam ELayout E tensor data layout.
|
||||
/// @tparam ADataType A tensor data type.
|
||||
/// @tparam BDataType B tensor data type.
|
||||
/// @tparam DsDataType D tensors data types.
|
||||
/// @tparam EDataType E tensor data type.
|
||||
/// @tparam AccDataType The accumulation data type related to the hardware
|
||||
/// matrix-multiplication instruction.
|
||||
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
|
||||
/// LDS memory during \"CShuffle\" data layout optimization.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
|
||||
/// GEMM) and D input tensors.
|
||||
/// @tparam GemmSpec Determines used "padding" version.
|
||||
/// @tparam BlockSize The number of threads within workgroup.
|
||||
/// @tparam MPerBlock The input/output data tile size in the M dimension.
|
||||
/// @tparam NPerBlock The input/output data tile size in the N dimension.
|
||||
/// @tparam KPerBlock The input data tile size in the K dimension.
|
||||
/// @tparam AK1 The vector load size from global memory for A tensor.
|
||||
/// @tparam BK1 The vector load size from global memory for B tensor.
|
||||
/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
|
||||
/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
|
||||
/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
|
||||
/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
|
||||
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question, "How many threads can be
|
||||
/// arranged on each input data axis?"
|
||||
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question: "How many threads to
|
||||
/// arrange on each input data axis?"
|
||||
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in M dimension.
|
||||
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in N dimension.
|
||||
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// thread distribution used for storing data into output
|
||||
/// tensor across output data layout dimensions.
|
||||
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
|
||||
/// Used when loading data from D tensors and storing data
|
||||
/// to output tensor.
|
||||
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
|
||||
/// intrawave).
|
||||
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
|
||||
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory. Currently not supported!
|
||||
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory (pre-shuffled).
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceGemmMultipleD_Wmma_CShuffleV3
|
||||
: public DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
using DeviceGemmCommon =
|
||||
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
BlockSize,
|
||||
AK1,
|
||||
BK1,
|
||||
GemmSpec,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
// Invoker
|
||||
using Invoker = typename DeviceGemmCommon::Invoker;
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
return DeviceGemmCommon::IsSupportedArgument(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideE,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<EDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideE,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<EDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmMultipleD_Wmma_CShuffleV3"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0];
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
str << std::string(DLayout::name)[0];
|
||||
});
|
||||
str << std::string(ELayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerWmma << "x"<<NPerWmma << ", "
|
||||
<< "WaveMap: "
|
||||
<< MRepeat << "x" << NRepeat << ", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "KPack: "
|
||||
<< GridwiseGemm::KPack;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
REGISTER_EXTRA_PRINTING_METHODS
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -177,15 +177,16 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
Tuple<>, // DsDataType
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
@@ -220,7 +221,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -230,21 +231,24 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
BlockSize,
|
||||
AK1,
|
||||
BK1,
|
||||
GemmSpec,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
using DeviceGemmCommon =
|
||||
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
BlockSize,
|
||||
AK1,
|
||||
BK1,
|
||||
GemmSpec,
|
||||
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
// Invoker
|
||||
using Invoker = typename DeviceGemmCommon::Invoker;
|
||||
@@ -275,11 +279,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -295,20 +313,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation) override
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
KBatch);
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -89,11 +89,13 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
Tuple<>, // DsDataType
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
@@ -130,7 +132,7 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -140,21 +142,24 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
BlockSize,
|
||||
AK1,
|
||||
BK1,
|
||||
GemmSpec,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
using DeviceGemmCommon =
|
||||
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
BlockSize,
|
||||
AK1,
|
||||
BK1,
|
||||
GemmSpec,
|
||||
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
// Invoker
|
||||
using Invoker = typename DeviceGemmCommon::Invoker;
|
||||
@@ -188,23 +193,25 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
CElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
StrideScaleB,
|
||||
p_b_scale,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -228,12 +235,14 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
StrideScaleB,
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
|
||||
@@ -24,7 +24,8 @@ namespace device {
|
||||
template <typename GridwiseGemm,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
@@ -32,6 +33,7 @@ template <typename GridwiseGemm,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
GemmSpecialization GemmSpec,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA,
|
||||
@@ -95,8 +97,22 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / GridwiseGemm::BPackedSize;
|
||||
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
std::array<std::size_t, GridwiseGemm::NumDTensor> size_ds_buffers;
|
||||
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
size_ds_buffers[i] =
|
||||
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
|
||||
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
|
||||
arg_,
|
||||
stream_config.rotating_count,
|
||||
size_a_buffer,
|
||||
size_b_buffer,
|
||||
size_ds_buffers);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
@@ -106,9 +122,9 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid,
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
arg_.M * arg_.N * sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
@@ -124,9 +140,9 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid,
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
arg.M * arg.N * sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
@@ -149,6 +165,16 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
}
|
||||
}();
|
||||
|
||||
// ThreadwiseTensorSliceTransfer_v7r3 (used in ThreadGroupTensorSliceTransfer_v7r3) is
|
||||
// currently implemented in such a way that all SrcScalarPerVectors must be the same, so
|
||||
// if one of D matrices is column-major, then all SrcScalarPerVectors must be 1. On the
|
||||
// other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot
|
||||
// be odd.
|
||||
constexpr bool AtomicsImplementationExists =
|
||||
!(std::is_same_v<EDataType, ck::half_t> ||
|
||||
std::is_same_v<EDataType, ck::bhalf_t>) ||
|
||||
(CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0);
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
@@ -157,12 +183,15 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -186,12 +215,15 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -229,8 +261,8 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CDataType, ck::half_t> ||
|
||||
std::is_same_v<CDataType, ck::bhalf_t>)
|
||||
if constexpr(std::is_same_v<EDataType, ck::half_t> ||
|
||||
std::is_same_v<EDataType, ck::bhalf_t>)
|
||||
{
|
||||
if(arg.KBatch > 1 && ck::is_gfx11_supported())
|
||||
{
|
||||
|
||||
@@ -47,7 +47,7 @@ struct Add
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
y = type_convert<half_t>(x0) + x1;
|
||||
y = x0 + type_convert<float>(x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
@@ -22,9 +22,10 @@ namespace ck {
|
||||
///
|
||||
/// @par Overview
|
||||
/// This GEMM kernel is carrying out following mathematical equation:
|
||||
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
|
||||
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
|
||||
/// elementwise operations that could be applied on each tensor respectively.
|
||||
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
|
||||
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
|
||||
// operations that could be applied on each tensor respectively. The CDE_op is an
|
||||
// elementwise operation applied to the C and all D tensors.
|
||||
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
|
||||
/// scenarios. That's why it's called \"universal\". It's universal through it's design
|
||||
/// and versatilty.
|
||||
@@ -36,18 +37,20 @@ namespace ck {
|
||||
///
|
||||
/// @tparam ALayout A tensor data layout.
|
||||
/// @tparam BLayout B tensor data layout.
|
||||
/// @tparam CLayout C tensor data layout.
|
||||
/// @tparam DsLayout D tensors data layouts.
|
||||
/// @tparam ELayout E tensor data layout.
|
||||
/// @tparam ADataType A tensor data type.
|
||||
/// @tparam BDataType B tensor data type.
|
||||
/// @tparam AccDataType The accumulation data type related to the hardware
|
||||
/// matrix-multiplication instruction.
|
||||
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
|
||||
/// LDS memory during \"CShuffle\" data layout optimization.
|
||||
/// @tparam CDataType C tensor data type.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
|
||||
/// (after GEMM).
|
||||
/// @tparam DsDataType D tensors data types.
|
||||
/// @tparam EDataType E tensor data type.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
|
||||
/// GEMM) and D input tensors.
|
||||
/// @tparam GemmSpec Determines used "padding" version.
|
||||
/// @tparam BlockSize The number of threads within workgroup.
|
||||
/// @tparam MPerBlock The input/output data tile size in the M dimension.
|
||||
@@ -105,11 +108,12 @@ namespace ck {
|
||||
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in N dimension.
|
||||
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// thread distribution used for storing data into output
|
||||
/// tensor across output data layout dimensions.
|
||||
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
|
||||
/// Used when storing data to output tensor.
|
||||
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
|
||||
/// Used when loading data from D tensors and storing data
|
||||
/// to output tensor.
|
||||
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
|
||||
/// intrawave).
|
||||
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
|
||||
@@ -123,15 +127,17 @@ namespace ck {
|
||||
/// in global memory (pre-shuffled).
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename CDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
@@ -161,8 +167,8 @@ template <typename ALayout,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA,
|
||||
@@ -173,15 +179,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
: GridwiseGemm_wmma_cshuffle_v3_base<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -211,8 +219,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -223,15 +231,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -261,8 +271,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -297,17 +307,22 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
using Base::CalculateNPadded;
|
||||
using Base::MakeAGridDescriptor_AK0_M_AK1;
|
||||
using Base::MakeBGridDescriptor_BK0_N_BK1;
|
||||
using Base::MakeCGridDescriptor_M_N;
|
||||
using Base::MakeDEGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
|
||||
|
||||
using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
|
||||
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
|
||||
|
||||
using Base::NumDTensor;
|
||||
using typename Base::DsGridPointer;
|
||||
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem(index_t M_,
|
||||
@@ -315,14 +330,16 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
StrideA{StrideA_},
|
||||
StrideB{StrideB_},
|
||||
StrideC{StrideC_},
|
||||
StrideDs{StrideDs_},
|
||||
StrideE{StrideE_},
|
||||
KBatch{KBatch_},
|
||||
MPadded{CalculateMPadded(M_)},
|
||||
NPadded{CalculateNPadded(N_)},
|
||||
@@ -338,11 +355,19 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
__host__ void Print() const
|
||||
{
|
||||
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
|
||||
<< ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
|
||||
<< "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
|
||||
<< ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
|
||||
<< "NBlock: " << NBlock << "}" << std::endl;
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
std::cout << "SDs: { ";
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
|
||||
});
|
||||
std::cout << " }, ";
|
||||
}
|
||||
std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
|
||||
<< ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
|
||||
<< "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
|
||||
<< ", " << "NBlock: " << NBlock << "}" << std::endl;
|
||||
}
|
||||
|
||||
index_t M;
|
||||
@@ -350,7 +375,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
index_t K;
|
||||
index_t StrideA;
|
||||
index_t StrideB;
|
||||
index_t StrideC;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
index_t KBatch;
|
||||
index_t MPadded;
|
||||
index_t NPadded;
|
||||
@@ -367,21 +393,35 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
{
|
||||
__host__ Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
EDataType* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
p_ds_grid{},
|
||||
p_e_grid{p_e_grid_},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
cde_element_op{cde_element_op_},
|
||||
is_reduce(is_reduce_)
|
||||
{
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
|
||||
});
|
||||
}
|
||||
|
||||
__host__ __device__ inline bool IsReduceAdd() const
|
||||
@@ -396,42 +436,49 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
const ADataType* p_a_grid;
|
||||
const BDataType* p_b_grid;
|
||||
CDataType* p_c_grid;
|
||||
DsGridPointer p_ds_grid;
|
||||
EDataType* p_e_grid;
|
||||
|
||||
const AElementwiseOperation a_element_op;
|
||||
const BElementwiseOperation b_element_op;
|
||||
const CDEElementwiseOperation cde_element_op;
|
||||
|
||||
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
|
||||
bool is_reduce;
|
||||
};
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
|
||||
__device__ SplitKBatchOffset(Argument& karg)
|
||||
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
|
||||
{
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
|
||||
b_k_split_offset = k_id * karg.KRead / BPackedSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
|
||||
b_k_split_offset = k_id * k0_offset / BPackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
|
||||
if(k_id < karg.KBatch - 1)
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
}
|
||||
@@ -442,7 +489,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
if(karg.IsReduceAdd())
|
||||
{
|
||||
c_reduce_offset = blockIdx.z * karg.M * karg.N;
|
||||
c_reduce_offset = k_id * karg.M * karg.N;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -465,23 +512,32 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
__device__ static index_t GetKBlockPerScale() { return 1; }
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem)
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
@@ -491,8 +547,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
@@ -508,17 +564,23 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(b_scale_struct),
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
@@ -528,17 +590,21 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
// Wrapper function to have __global__ function in common
|
||||
// between gemm_universal, b_scale, ab_scale, etc.
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg)
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
|
||||
{
|
||||
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared,
|
||||
karg);
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -20,15 +20,17 @@ namespace ck {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename CDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockN, // scale N
|
||||
@@ -60,11 +62,11 @@ template <typename ALayout,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
@@ -72,15 +74,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
: GridwiseGemm_wmma_cshuffle_v3_base<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -110,8 +114,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -124,15 +128,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
@@ -162,8 +168,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -198,17 +204,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
using Base::CalculateNPadded;
|
||||
using Base::MakeAGridDescriptor_AK0_M_AK1;
|
||||
using Base::MakeBGridDescriptor_BK0_N_BK1;
|
||||
using Base::MakeCGridDescriptor_M_N;
|
||||
using Base::MakeDEGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_M_N;
|
||||
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
|
||||
|
||||
using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
|
||||
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
|
||||
|
||||
using Base::NumDTensor;
|
||||
using typename Base::DsGridPointer;
|
||||
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem(index_t M_,
|
||||
@@ -216,7 +227,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t StrideScaleB_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
@@ -224,7 +236,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
K{K_},
|
||||
StrideA{StrideA_},
|
||||
StrideB{StrideB_},
|
||||
StrideC{StrideC_},
|
||||
StrideDs{StrideDs_},
|
||||
StrideE{StrideE_},
|
||||
StrideScaleB{StrideScaleB_},
|
||||
KBatch{KBatch_},
|
||||
MPadded{CalculateMPadded(M_)},
|
||||
@@ -241,11 +254,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
__host__ void Print() const
|
||||
{
|
||||
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
|
||||
<< ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", "
|
||||
<< "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded
|
||||
<< ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
|
||||
<< "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
std::cout << "SDs: { ";
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
|
||||
});
|
||||
std::cout << " }, ";
|
||||
}
|
||||
std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", "
|
||||
<< "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead
|
||||
<< ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0
|
||||
<< ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
index_t M;
|
||||
@@ -253,7 +275,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
index_t K;
|
||||
index_t StrideA;
|
||||
index_t StrideB;
|
||||
index_t StrideC;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
index_t StrideScaleB;
|
||||
index_t KBatch;
|
||||
index_t MPadded;
|
||||
@@ -271,30 +294,38 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
{
|
||||
__host__ Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
EDataType* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t StrideScaleB_,
|
||||
const BScaleType* p_b_scale_grid_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CElementwiseOperation c_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_},
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, StrideScaleB_, k_batch_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
p_ds_grid{},
|
||||
p_e_grid{p_e_grid_},
|
||||
p_b_scale_grid{p_b_scale_grid_},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
c_element_op{c_element_op_},
|
||||
cde_element_op{cde_element_op_},
|
||||
is_reduce(is_reduce_)
|
||||
{
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
|
||||
});
|
||||
}
|
||||
|
||||
__host__ __device__ inline bool IsReduceAdd() const
|
||||
@@ -309,57 +340,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
const ADataType* p_a_grid;
|
||||
const BDataType* p_b_grid;
|
||||
CDataType* p_c_grid;
|
||||
DsGridPointer p_ds_grid;
|
||||
EDataType* p_e_grid;
|
||||
|
||||
const BScaleType* p_b_scale_grid;
|
||||
const AElementwiseOperation a_element_op;
|
||||
const BElementwiseOperation b_element_op;
|
||||
const CElementwiseOperation c_element_op;
|
||||
const CDEElementwiseOperation cde_element_op;
|
||||
bool is_reduce;
|
||||
};
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
|
||||
__device__ SplitKBatchOffset(Argument& karg)
|
||||
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
|
||||
{
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
|
||||
b_k_split_offset = k_id * karg.KRead / BPackedSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
|
||||
b_k_split_offset = k_id * k0_offset / BPackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate B scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
|
||||
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
|
||||
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
|
||||
}
|
||||
|
||||
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
|
||||
if(k_id < karg.KBatch - 1)
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
}
|
||||
@@ -370,7 +402,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
if(karg.IsReduceAdd())
|
||||
{
|
||||
c_reduce_offset = blockIdx.z * karg.M * karg.N;
|
||||
c_reduce_offset = k_id * karg.M * karg.N;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -454,24 +486,33 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
const BScaleType* p_b_scale_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem)
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
// B Scale grid
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
@@ -487,8 +528,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
@@ -503,17 +544,23 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(b_scale_struct),
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
@@ -523,18 +570,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
// NOTE: Wrapper function to have __global__ function in common
|
||||
// between gemm_universal, b_scale, ab_scale, etc.
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum>
|
||||
__device__ static void
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg)
|
||||
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
|
||||
{
|
||||
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
|
||||
p_shared,
|
||||
karg);
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
@@ -19,7 +19,7 @@ namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
@@ -31,17 +31,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
|
||||
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<c_data_type, ck::half_t> ||
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<e_data_type, ck::half_t> ||
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
@@ -54,15 +54,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename CDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
@@ -92,8 +94,8 @@ template <typename ALayout,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA,
|
||||
@@ -112,6 +114,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
static constexpr auto EShuffleBlockTransferScalarPerVector =
|
||||
CDEShuffleBlockTransferScalarPerVectors{}[I0];
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
|
||||
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
|
||||
@@ -430,17 +435,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return MakeWmmaTileDescriptor<NRepeat, NWaves, NPerWmma>(BBlockDesc_BK0_N_BK1{});
|
||||
}
|
||||
|
||||
template <typename DELayout>
|
||||
__host__ __device__ static auto
|
||||
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
|
||||
MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideDE, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideDE));
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -493,6 +499,44 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
#endif
|
||||
}
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto MakeDsGridPointer()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
return static_cast<const DDataType*>(nullptr);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
|
||||
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
return MakeDEGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
template <typename DsGridDesc>
|
||||
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n[i], MBlock, NBlock);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
@@ -805,18 +849,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
NRepeat,
|
||||
KPack>())>;
|
||||
|
||||
template <typename CGridDesc>
|
||||
__host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
template <typename DEGridDesc>
|
||||
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
{
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
de_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
|
||||
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
return c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
return de_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
@@ -950,56 +994,51 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
|
||||
{
|
||||
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
if(karg.N % EShuffleBlockTransferScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of "
|
||||
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
"EShuffleBlockTransferScalarPerVector ("
|
||||
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
if(karg.M % EShuffleBlockTransferScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of "
|
||||
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
"EShuffleBlockTransferScalarPerVector ("
|
||||
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, float>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
|
||||
if constexpr(!(is_same<remove_cvref_t<EDataType>, half_t>::value ||
|
||||
is_same<remove_cvref_t<EDataType>, float>::value ||
|
||||
is_same<remove_cvref_t<EDataType>, bhalf_t>::value ||
|
||||
is_same<remove_cvref_t<EDataType>, int32_t>::value))
|
||||
{
|
||||
if(!karg.IsReduceAdd())
|
||||
if(karg.IsAtomicAdd() && karg.KBatch > 1)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet"
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
if(karg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported for this "
|
||||
<< "destination type (EDataType) " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1062,19 +1101,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename BScaleStruct,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DsGridPointer p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
const index_t& block_m_id,
|
||||
const index_t& block_n_id,
|
||||
const index_t& num_k_block_per_scale,
|
||||
@@ -1084,12 +1130,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ds_grid[i],
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
@@ -1330,31 +1379,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
m_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
// shuffle: blockwise copy C from LDS to global
|
||||
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of reference to C/Ds tensor buffers
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_buf),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of starting index of C/Ds blockwise copy
|
||||
const auto idx_c_ds_block_begin = container_concat(
|
||||
make_tuple(make_multi_index(0, 0, 0, 0)),
|
||||
generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// blockwise copy which loads C from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CDEElementwiseOperation, // ElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
CShuffleDataType, // typename SrcData,
|
||||
CDataType, // typename DstData,
|
||||
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
|
||||
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
false> // bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_m_id, 0, block_n_id, 0),
|
||||
c_element_op};
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
|
||||
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
|
||||
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
|
||||
3, // SrcVectorDim,
|
||||
3, // DstVectorDim,
|
||||
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
|
||||
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
|
||||
cde_element_op};
|
||||
|
||||
// space filling curve for local reg & global memory
|
||||
// space filling curve for threadwise C in VGPR
|
||||
@@ -1370,7 +1446,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
MAccVgprs>>{};
|
||||
|
||||
// space filling curve for shuffled blockwise C in global mem
|
||||
constexpr auto sfc_c_global =
|
||||
constexpr auto sfc_cde_global =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
@@ -1380,7 +1456,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
|
||||
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
@@ -1397,20 +1473,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global.Run(
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
// each block loads its C data from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
cde_shuffle_block_copy_lds_and_global.Run(
|
||||
c_ds_desc_refs,
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(e_grid_buf));
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
|
||||
// move on Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
|
||||
c_ds_desc_refs, i + I1, cde_global_step);
|
||||
});
|
||||
|
||||
// move on C
|
||||
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
// move on E
|
||||
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -165,6 +165,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
|
||||
oob_val = oob_val & is_src_valid;
|
||||
|
||||
// TODO: With column-major matrices this step restricts the transferred tensor slice
|
||||
// to just one element, which consequently prevents using atomic operations if the
|
||||
// matrix data type is on 16 bits.
|
||||
if constexpr(SrcScalarPerVectors{}[i] == 1)
|
||||
{
|
||||
auto data_types = SrcDatas{};
|
||||
|
||||
@@ -270,8 +270,8 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC,
|
||||
bool neg_a = false,
|
||||
bool neg_b = false,
|
||||
bool neg_a = true,
|
||||
bool neg_b = true,
|
||||
bool clamp = false>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
@@ -390,8 +390,8 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC,
|
||||
bool neg_a = false,
|
||||
bool neg_b = false,
|
||||
bool neg_a = true,
|
||||
bool neg_b = true,
|
||||
bool clamp = false>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
@@ -793,6 +793,9 @@ struct WmmaGemm
|
||||
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
|
||||
"((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!");
|
||||
static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
|
||||
// Integer wmma operators need extra input flags to indicate if the input is signed or
|
||||
// unsigned. At the moment CK supports only signed integer inputs, so these flags are
|
||||
// hardcoded.
|
||||
if constexpr(!TransposeC)
|
||||
{
|
||||
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
|
||||
|
||||
Reference in New Issue
Block a user