Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-15 10:37:44 +00:00)

Merge commit '87dd073887933fc2c75c234871e3885cee970a98' into develop

@@ -0,0 +1,764 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <iostream>
#include <map>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename GridwiseGemm,
          typename ComputePtrOffsetOfStridedBatch,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_batched_gemm_multi_d_wmma_cshuffle_v3(
        typename GridwiseGemm::Argument karg, // This works for now, but the kernel actually
                                              // receives a DeviceBatchedGemmMultiD_Wmma_CShuffleV3::
                                              // Argument, which is sliced down to this base class
                                              // on the way in!
        const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
    using EDataType = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
    if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
                   (std::is_same_v<EDataType, ck::half_t> ||
                    std::is_same_v<EDataType, ck::bhalf_t>)))
    {
#endif
        // The normal approach to batching would be to increase the grid size by stretching out
        // the grid Z dimension (the outermost dimension), but this depends on lower-level
        // functions not using the Z dimension directly for other calculations. As it turns out,
        // k-batching does rely directly on blockIdx.z through SplitKBatchOffset. Therefore, for
        // now, we use the grid Y dimension for batching. This may be a bit fragile.
        const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);

        const long_index_t a_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
        const long_index_t b_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
        const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
        const long_index_t c_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));

        constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
            typename GridwiseGemm::EpilogueCShuffle>();
        __shared__ char p_shared[LDS_size];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        static_for<0, GridwiseGemm::NumATensor, 1>{}(
            [&](auto i) { splitk_batch_offset.a_k_split_offset[i] += a_batch_offset; });

        static_for<0, GridwiseGemm::NumBTensor, 1>{}(
            [&](auto i) { splitk_batch_offset.b_k_split_offset[i] += b_batch_offset; });

        splitk_batch_offset.c_reduce_offset += c_batch_offset;

        // populate pointer, desc for Ds
        static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
            // D pointer
            karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
        });

        auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};

        GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
            p_shared, splitk_batch_offset, karg, epilogue_args);
#if defined(__gfx11__)
    }
#endif
#else
    ignore = karg;
    ignore = compute_ptr_offset_of_batch;
#endif
}
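
For illustration, a minimal standalone sketch of the two conventions the kernel above relies on: a derived argument struct passed by value to a __global__ function that takes the base class (the object is sliced to the base at the call), and batching carried on grid Y while split-K uses grid Z. All names here (BaseArg, DerivedArg, toy_kernel) are hypothetical, not part of composable_kernel.

#include <hip/hip_runtime.h>
#include <cstdio>

struct BaseArg
{
    int m, n;
};

struct DerivedArg : BaseArg
{
    int batch; // extra host-side field, invisible to the kernel after slicing
};

__global__ void toy_kernel(BaseArg karg, const int* batch_strides)
{
    const int g_idx = blockIdx.y; // batch index lives on Y
    const int k_idx = blockIdx.z; // split-K index lives on Z
    const long off  = static_cast<long>(batch_strides[0]) * g_idx;
    if(threadIdx.x == 0 && blockIdx.x == 0)
        printf("batch %d, kbatch %d, offset %ld, m %d\n", g_idx, k_idx, off, karg.m);
}

int main()
{
    DerivedArg arg{{128, 64}, 8};
    int* d_strides  = nullptr;
    int h_stride    = 128 * 64;
    (void)hipMalloc(&d_strides, sizeof(int));
    (void)hipMemcpy(d_strides, &h_stride, sizeof(int), hipMemcpyHostToDevice);
    // Y carries the batch count and Z the split-K factor, mirroring the launch in Invoker::Run.
    toy_kernel<<<dim3(1, arg.batch, 2), dim3(64)>>>(arg, d_strides); // DerivedArg slices to BaseArg
    (void)hipDeviceSynchronize();
    (void)hipFree(d_strides);
}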

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          typename ComputeTypeA = ADataType,
          typename ComputeTypeB = BDataType>
struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3
    : public DeviceBatchedGemmV2MultiD<ALayout,
                                       BLayout,
                                       DsLayout,
                                       ELayout,
                                       ADataType,
                                       BDataType,
                                       DsDataType,
                                       EDataType,
                                       AElementwiseOperation,
                                       BElementwiseOperation,
                                       CDEElementwiseOperation>
{
    using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors;
    using CDataType_                               = EDataType;
    // GridwiseGemm
    using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
        ALayout,
        BLayout,
        DsLayout,
        ELayout,
        Tuple<ADataType>,
        Tuple<BDataType>,
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        GemmSpec,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerWmma,
        NPerWmma,
        MRepeat,
        NRepeat,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVectors,
        BlkGemmPipeSched,
        BlkGemmPipelineVer,
        ComputeTypeA,
        ComputeTypeB,
        false,
        false>;

    struct ComputePtrOffsetOfStridedBatch
    {
        ComputePtrOffsetOfStridedBatch() = default;
        ComputePtrOffsetOfStridedBatch(
            index_t BatchStrideA,
            index_t BatchStrideB,
            std::array<ck::index_t, GridwiseGemm::NumDTensor> BatchStrideDs,
            index_t BatchStrideC)
            : BatchStrideA_(BatchStrideA),
              BatchStrideB_(BatchStrideB),
              BatchStrideDs_(BatchStrideDs),
              BatchStrideC_(BatchStrideC)
        {
        }

        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            return static_cast<long_index_t>(BatchStrideA_) * g_idx;
        }

        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            return static_cast<long_index_t>(BatchStrideB_) * g_idx;
        }

        __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
        {
            std::array<long_index_t, GridwiseGemm::NumDTensor> ds_offset_;

            static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
                ds_offset_[i] = static_cast<long_index_t>(BatchStrideDs_[i]) * g_idx;
            });

            return ds_offset_;
        }

        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
        {
            return static_cast<long_index_t>(BatchStrideC_) * g_idx;
        }

        private:
        index_t BatchStrideA_;
        index_t BatchStrideB_;
        std::array<ck::index_t, GridwiseGemm::NumDTensor> BatchStrideDs_;
        index_t BatchStrideC_;
    };
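
The struct above is pure offset bookkeeping: each getter widens the 32-bit per-batch stride to 64 bits before multiplying, so large batched tensors do not overflow the index arithmetic. A minimal standalone sketch of the same arithmetic (values are made up):

#include <cstdint>
#include <cassert>

int main()
{
    const std::int32_t batch_stride = 8 * 1024 * 1024; // elements per batch slice
    const std::int32_t g_idx        = 300;             // batch index from blockIdx.y

    // A 32-bit product would overflow (~2.5e9 > 2^31 - 1); widening first keeps it exact.
    const std::int64_t offset = static_cast<std::int64_t>(batch_stride) * g_idx;
    assert(offset == 2516582400LL);
    return 0;
}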

    struct Argument : public GridwiseGemm::Argument
    {
        index_t Batch;
        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;

        Argument() = default;
        Argument(const ADataType* p_a_grid_,
                 const BDataType* p_b_grid_,
                 std::array<const void*, GridwiseGemm::NumDTensor> p_ds_grid_,
                 EDataType* p_e_grid_,
                 index_t M_,
                 index_t N_,
                 index_t K_,
                 index_t StrideA_,
                 index_t StrideB_,
                 std::array<index_t, GridwiseGemm::NumDTensor> StrideDs_,
                 index_t StrideE_,
                 index_t BatchStrideA_,
                 index_t BatchStrideB_,
                 const std::array<ck::index_t, GridwiseGemm::NumDTensor>& BatchStrideDs_,
                 index_t BatchStrideE_,
                 index_t Batch_,
                 AElementwiseOperation a_element_op_,
                 BElementwiseOperation b_element_op_,
                 CDEElementwiseOperation cde_element_op_,
                 index_t KBatch_)
            : GridwiseGemm::Argument{std::array<const void*, 1>{p_a_grid_},
                                     std::array<const void*, 1>{p_b_grid_},
                                     p_ds_grid_,
                                     p_e_grid_,
                                     M_,
                                     N_,
                                     K_,
                                     std::array<index_t, 1>{StrideA_},
                                     std::array<index_t, 1>{StrideB_},
                                     StrideDs_,
                                     StrideE_,
                                     KBatch_,
                                     a_element_op_,
                                     b_element_op_,
                                     cde_element_op_,
                                     false},
              Batch{Batch_},
              compute_ptr_offset_of_batch{
                  BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_}
        {
        }
        template <typename EType>
        void SetEPointer(void* ptr)
        {
            this->p_e_grid = static_cast<EType*>(ptr);
        }
    };

    struct ActiveWorkgroupsPerCU
    {
        ActiveWorkgroupsPerCU()
        {
            constexpr int dynamic_smem_size = 0;
            int max_occupancy               = 0;

            constexpr index_t minimum_occupancy = []() {
                if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
                {
                    return 2;
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
                }
                else
                {
                    return 1;
                }
            }();

            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_occupancy,
                kernel_batched_gemm_multi_d_wmma_cshuffle_v3<GridwiseGemm,
                                                             ComputePtrOffsetOfStridedBatch,
                                                             true,
                                                             InMemoryDataOperationEnum::AtomicAdd,
                                                             minimum_occupancy>,
                BlockSize,
                dynamic_smem_size));

            max_occupancy_ = std::max(1, max_occupancy);
        }
        int max_occupancy_;
    };
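
ActiveWorkgroupsPerCU caches a one-time occupancy query for the AtomicAdd instantiation of the kernel above. A minimal standalone sketch of the same HIP occupancy API on a trivial kernel (toy_kernel and the block size are placeholders):

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void toy_kernel(float* p) { p[threadIdx.x] = 0.f; }

int main()
{
    int max_blocks_per_cu = 0;
    // Same call the struct above makes: kernel, block size, and dynamic LDS bytes in,
    // maximum resident blocks per compute unit out.
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_cu, toy_kernel, 256, 0);
    printf("max active blocks per CU: %d\n", max_blocks_per_cu);
    return 0;
}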

    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);

            gdy *= arg.Batch;

            float ave_time = 0;

            index_t k_grain = arg.KBatch * KPerBlock;
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
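            // Worked example (values illustrative): with K = 4096, KBatch = 4 and
            // KPerBlock = 64, k_grain = 4 * 64 = 256 and K_split = ceil(4096 / 256) * 64
            // = 1024 -- the K extent each split-K slice loops over, which decides below
            // whether the main K block loop is needed.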

            const auto Run = [&](const auto& kernel) {
                if(stream_config.flush_cache)
                {
                    Argument arg_ = arg;

                    const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
                        arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
                    const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
                        arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);

                    // Packed sizes are 1 for all implemented data types, but we include them
                    // anyway for future compatibility.
                    std::array<std::size_t, 1> size_as_buffers;
                    size_as_buffers[0] = arg_.Batch *
                                         a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
                                         sizeof(ADataType) / GridwiseGemm::APackedSize;

                    std::array<std::size_t, 1> size_bs_buffers;
                    size_bs_buffers[0] = arg_.Batch *
                                         b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
                                         sizeof(BDataType) / GridwiseGemm::BPackedSize;

                    const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
                        arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);

                    std::array<std::size_t, GridwiseGemm::NumDTensor> size_ds_buffers;
                    static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
                        using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
                        size_ds_buffers[i] =
                            ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
                    });
                    ck::utility::RotatingMemWrapperMultiABD<Argument,
                                                            Tuple<ADataType>,
                                                            Tuple<BDataType>,
                                                            DsDataType>
                        rotating_mem(arg_,
                                     stream_config.rotating_count,
                                     size_as_buffers,
                                     size_bs_buffers,
                                     size_ds_buffers);
                    rotating_mem.Print();

                    auto run_flush_cache = [&]() {
                        // flush icache
                        ck::utility::flush_icache();
                        // rotating mem
                        rotating_mem.Next();
                        // clear c mem
                        if(arg_.KBatch > 1)
                            HIP_CHECK_ERROR(
                                hipMemsetAsync(arg_.p_e_grid,
                                               0,
                                               arg.Batch * arg_.M * arg_.N * sizeof(EDataType),
                                               stream_config.stream_id_));
                    };

                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                        stream_config,
                        run_flush_cache,
                        kernel,
                        dim3(gdx, gdy, gdz),
                        dim3(BlockSize),
                        0,
                        arg_,
                        arg_.compute_ptr_offset_of_batch);
                }
                else
                {
                    const auto clear_workspace = [&]() {
                        if(arg.KBatch > 1)
                            HIP_CHECK_ERROR(
                                hipMemsetAsync(arg.p_e_grid,
                                               0,
                                               arg.Batch * arg.M * arg.N * sizeof(EDataType),
                                               stream_config.stream_id_));
                    };

                    ave_time =
                        launch_and_time_kernel_with_preprocess(stream_config,
                                                               clear_workspace,
                                                               kernel,
                                                               dim3(gdx, gdy, gdz),
                                                               dim3(BlockSize),
                                                               0,
                                                               arg,
                                                               arg.compute_ptr_offset_of_batch);
                }
            };

            constexpr index_t minimum_occupancy = []() {
                if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
                {
                    return 2;
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
                }
                else
                {
                    return 1;
                }
            }();

            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(arg.KBatch > 1)
                    {
                        const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
                            GridwiseGemm,
                            ComputePtrOffsetOfStridedBatch,
                            true,
                            InMemoryDataOperationEnum::AtomicAdd,
                            minimum_occupancy>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
                            GridwiseGemm,
                            ComputePtrOffsetOfStridedBatch,
                            true,
                            InMemoryDataOperationEnum::Set,
                            minimum_occupancy>;
                        Run(kernel);
                    }
                }
            }
            else
            {
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    if(arg.KBatch > 1)
                    {
                        const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
                            GridwiseGemm,
                            ComputePtrOffsetOfStridedBatch,
                            false,
                            InMemoryDataOperationEnum::AtomicAdd,
                            minimum_occupancy>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3<
                            GridwiseGemm,
                            ComputePtrOffsetOfStridedBatch,
                            false,
                            InMemoryDataOperationEnum::Set,
                            minimum_occupancy>;
                        Run(kernel);
                    }
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << "Unsupported: Architecture must be gfx11/gfx12." << std::endl;
            }
            return false;
        }

        if constexpr(std::is_same_v<EDataType, ck::half_t> ||
                     std::is_same_v<EDataType, ck::bhalf_t>)
        {
            if(arg.KBatch > 1 && ck::is_gfx11_supported())
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Unsupported splitK on gfx11." << std::endl;
                }
                // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
                return false;
            }
        }

        if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
                     std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
        {
            if(ck::is_gfx11_supported())
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Unsupported f8 / bf8 on gfx11." << std::endl;
                }
                return false;
            }
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
                                                       GemmSpec == GemmSpecialization::NKPadding ||
                                                       GemmSpec == GemmSpecialization::MNKPadding ||
                                                       GemmSpec == GemmSpecialization::KPadding))
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << "Unsupported K dimension without padding." << std::endl;
            }
            return false;
        }

        return GridwiseGemm::CheckValidity(arg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const void* p_a,
                             const void* p_b,
                             std::array<const void*, GridwiseGemm::NumDTensor> p_ds,
                             void* p_e,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t Batch,
                             index_t StrideA,
                             index_t StrideB,
                             std::array<index_t, GridwiseGemm::NumDTensor> StrideDs,
                             index_t StrideE,
                             index_t BatchStrideA,
                             index_t BatchStrideB,
                             const std::array<ck::index_t, GridwiseGemm::NumDTensor>& BatchStrideDs,
                             index_t BatchStrideE,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op,
                             index_t KBatch = 1)
    {
        return Argument{static_cast<const ADataType*>(p_a),
                        static_cast<const BDataType*>(p_b),
                        p_ds,
                        static_cast<EDataType*>(p_e),
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideDs,
                        StrideE,
                        BatchStrideA,
                        BatchStrideB,
                        BatchStrideDs,
                        BatchStrideE,
                        Batch,
                        a_element_op,
                        b_element_op,
                        cde_element_op,
                        KBatch};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        const std::array<const void*, GridwiseGemm::NumDTensor>& p_ds,
                        void* p_e,
                        index_t M,
                        index_t N,
                        index_t K,
                        index_t Batch,
                        index_t StrideA,
                        index_t StrideB,
                        const std::array<ck::index_t, GridwiseGemm::NumDTensor>& StrideDs,
                        index_t StrideE,
                        index_t BatchStrideA,
                        index_t BatchStrideB,
                        const std::array<ck::index_t, GridwiseGemm::NumDTensor>& BatchStrideDs,
                        index_t BatchStrideE,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op,
                        index_t KBatch = 1) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          p_ds,
                                          static_cast<EDataType*>(p_e),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideDs,
                                          StrideE,
                                          BatchStrideA,
                                          BatchStrideB,
                                          BatchStrideDs,
                                          BatchStrideE,
                                          Batch,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op,
                                          KBatch);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceBatchedGemmMultipleD_Wmma_CShuffleV3"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(ELayout::name)[0]
            << ">"
            << " BlkSize: "
            << BlockSize << ", "
            << "BlkTile: "
            << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
            << "WaveTile: "
            << MPerWmma<<"x"<<NPerWmma << ", "
            << "WaveMap: "
            << MRepeat<<"x" << NRepeat<<", "
            << "VmemReadVec: "
            << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
        // clang-format on

        return str.str();
    }

    static ck::index_t GetMaxOccupancy()
    {
        static ActiveWorkgroupsPerCU active_workgroups_per_cu;
        return active_workgroups_per_cu.max_occupancy_;
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
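
For context, the canonical host-side calling sequence for a device op with the interface above, sketched under assumptions: DeviceOp stands for some valid DeviceBatchedGemmMultiD_Wmma_CShuffleV3<...> instantiation with no D tensors, p_a/p_b/p_e are device buffers of the matching types, and PassThrough element-wise ops are used. Applications typically obtain such ops pre-instantiated from the composable_kernel instance factory rather than spelling out the template arguments.

// Assumed: DeviceOp, p_a, p_b, p_e, and the problem sizes/strides are defined by the caller.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

auto op       = DeviceOp{};
auto invoker  = op.MakeInvoker();
auto argument = op.MakeArgument(p_a, p_b, {}, p_e,
                                M, N, K, Batch,
                                StrideA, StrideB, {}, StrideE,
                                BatchStrideA, BatchStrideB, {}, BatchStrideE,
                                PassThrough{}, PassThrough{}, PassThrough{},
                                /*KBatch=*/1);

if(!DeviceOp::IsSupportedArgument(argument))
    throw std::runtime_error("this device op instance does not support this problem");

float ave_time = invoker.Run(argument, StreamConfig{nullptr, /*time_kernel=*/true});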

@@ -350,6 +350,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                          BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_}
        {
        }
        template <typename EType>
        void SetEPointer(void* ptr)
        {
            this->p_c_grid = static_cast<EType*>(ptr);
        }
    };
    using Argument = ArgumentBase<GridwiseGemm64>;

@@ -18,6 +18,7 @@
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"

namespace ck {
namespace tensor_operation {

@@ -807,7 +808,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
    using Block2CTileMap =
        decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));

-   struct Argument : public BaseArgument
+   struct Argument : public BaseArgument, public ArgumentSplitK
    {
        Argument(const InDataType* p_in_grid,
                 WeiDataType* p_wei_grid,

@@ -844,9 +845,10 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
              conv_filter_strides_{conv_filter_strides},
              conv_filter_dilations_{conv_filter_dilations},
              input_left_pads_{input_left_pads},
-             input_right_pads_{input_right_pads},
-             k_batch_{split_k}
+             input_right_pads_{input_right_pads}
        {
            k_batch_ = split_k;

            const auto descs =
                DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
                    a_g_n_c_wis_lengths, // input

@@ -915,7 +917,6 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
        const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
        const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
-       index_t k_batch_;
    };

    // Invoker

@@ -32,7 +32,7 @@ template <ck::index_t NDimSpatial,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          typename DeviceGemmV3Op>
-struct DeviceGroupedConvBwdWeight_Explicit_Xdl
+struct DeviceGroupedConvBwdWeight_Explicit
    : public DeviceGroupedConvBwdWeight<NDimSpatial,
                                        InLayout,
                                        WeiLayout,

@@ -56,7 +56,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
        sizeof(WeiDataType) % 4 != 0 &&
        DeviceGemmV3Op::CDEShuffleBlockTransferScalarPerVectors_::At(I0) % 2 != 0;

-   using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl;
+   using DeviceOp = DeviceGroupedConvBwdWeight_Explicit;
    using TwoStageIntermediateType = typename DeviceGemmV3Op::CDataType_;

    static constexpr index_t ElementwiseBlockSize = 256;

@@ -95,7 +95,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
                                        I1,
                                        I1>;

-   struct Argument : public BaseArgument
+   struct Argument : public BaseArgument, public ArgumentSplitK
    {
        using GemmArgument = typename DeviceGemmV3Op::Argument;

@@ -153,11 +153,11 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
                std::tie(gdx, gdy, gdz) =
                    DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
                const index_t grid_size = gdx * gdy * gdz;
-               split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
+               k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
            }
            else
            {
-               split_k_ = split_k;
+               k_batch_ = split_k;
            }
        }
        else

@@ -170,12 +170,12 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
                std::tie(gdx, gdy, gdz) =
                    DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
                const index_t grid_size = gdx * gdy * gdz;
-               split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
+               k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
            }
            else
#endif
            {
-               split_k_ = split_k;
+               k_batch_ = split_k;
            }
        }

@@ -213,7 +213,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
                               out_element_op,
                               in_element_op,
                               wei_element_op,
-                              split_k_};
+                              k_batch_};
        }
        else
        {

@@ -236,7 +236,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
                               out_element_op,
                               in_element_op,
                               wei_element_op,
-                              split_k_};
+                              k_batch_};
        }
    }

@@ -273,7 +273,6 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
        bool is_filter_data_packed;
        CElementwiseGridDesc elementwise_desc_;
        Block2TileMapElementwise elementwise_block_2_ctile_map_;
-       ck::index_t split_k_;
    };

    // Invoker

@@ -288,8 +287,8 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
        {
            // Modify to use workspace as output
            GemmArgument explicit_gemm_args_with_workspace = arg.explicit_gemm_args;
-           explicit_gemm_args_with_workspace.p_c_grid =
-               static_cast<TwoStageIntermediateType*>(arg.p_workspace_);
+           explicit_gemm_args_with_workspace.template SetEPointer<TwoStageIntermediateType>(
+               arg.p_workspace_);
            float avg_time =
                explicit_gemm_op.Run(explicit_gemm_args_with_workspace, stream_config);
            const index_t grid_size =

@@ -342,7 +341,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
        if constexpr(!IsTwoStageNeeded)
        {
-           if(arg.split_k_ < 0)
+           if(arg.k_batch_ < 0)
            {
                return false;
            }

@@ -353,6 +352,10 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
        {
            if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Unsupported layout." << std::endl;
                }
                return false;
            }
        }

@@ -360,11 +363,19 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
        {
            if constexpr(!is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>())
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Unsupported layout." << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << "Unsupported layout." << std::endl;
            }
            return false;
        }

@@ -374,6 +385,10 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
            if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
                 arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Unsupported stride / pad." << std::endl;
                }
                return false;
            }
        }

@@ -381,6 +396,10 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
        {
            if(!arg.is_filter_data_packed)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Unsupported: Filter data must be packed." << std::endl;
                }
                return false;
            }
            // Check this here, it allows to use other instances from factory even

(File diff suppressed because it is too large.)
(File diff suppressed because it is too large.)
@@ -1745,6 +1745,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
            {
                return false;
            }
            // TODO: this is needed because there is a bug
            if(arg.k_batch_ > 1)
            {
                return false;
            }
        }

        // Check this here, it allows to use other instances from factory even

@@ -17,6 +17,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"

namespace ck {
namespace tensor_operation {

@@ -450,7 +451,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
    using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(
        CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */));

-   struct Argument : public BaseArgument
+   struct Argument : public BaseArgument, public ArgumentSplitK
    {
        Argument(const InDataType* p_in_grid,
                 WeiDataType* p_wei_grid,

@@ -490,8 +491,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
              output_spatial_lengths_{},
              conv_filter_strides_{conv_filter_strides},
              input_left_pads_{input_left_pads},
-             input_right_pads_{input_right_pads},
-             k_batch_{split_k}
+             input_right_pads_{input_right_pads}
        {
            constexpr index_t spatial_offset = 3;
            std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,

@@ -504,6 +504,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
                      end(e_g_n_k_wos_lengths),
                      begin(output_spatial_lengths_));

            k_batch_ = split_k;

            const auto descs =
                DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
                    Conv_N_,

@@ -576,7 +578,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
        const std::array<index_t, NDimSpatial>& conv_filter_strides_;
        const std::array<index_t, NDimSpatial>& input_left_pads_;
        const std::array<index_t, NDimSpatial>& input_right_pads_;
-       const index_t k_batch_;
    };

    // Invoker

(File diff suppressed because it is too large.)

@@ -295,7 +295,7 @@ struct ABTransferThreadTiles
        BlockDescriptor& block_descriptor,
        ABElementwiseOperation& ab_element_op,
        const index_t block_mn_id,
-       const index_t)
+       const index_t k_id)
    {
        constexpr index_t NumABTensor = ABsDataType::Size();
        const index_t mn_block_data_idx_on_grid =

@@ -304,7 +304,7 @@ struct ABTransferThreadTiles
        if constexpr(NumABTensor > 1)
        {
            const auto idx_as_block_begin = generate_tuple(
-               [&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); },
+               [&](auto) { return make_multi_index(k_id, mn_block_data_idx_on_grid, 0); },
                Number<NumABTensor>{});

            return ThreadGroupTensorSliceTransfer_v7r2<

@@ -357,7 +357,7 @@ struct ABTransferThreadTiles
            ABThreadTransferSrcResetCoordinateAfterRun,
            true,
            GlobalBufferNum>(grid_descriptor[I0],
-                            make_multi_index(0, mn_block_data_idx_on_grid, 0),
+                            make_multi_index(k_id, mn_block_data_idx_on_grid, 0),
                             ab_element_op,
                             block_descriptor,
                             make_multi_index(0, 0, 0),

@@ -333,6 +333,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
    struct Problem
    {
        __host__ Problem() = default;
        __host__ Problem(index_t M_,
                         index_t N_,
                         index_t K_,

@@ -409,6 +410,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
    // Argument
    struct Argument : public tensor_operation::device::BaseArgument, public Problem
    {
        __host__ Argument() = default;
        __host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
                          std::array<const void*, NumBTensor> p_bs_grid_,
                          std::array<const void*, NumDTensor> p_ds_grid_,

@@ -583,7 +585,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
        BElementwiseOperation b_element_op,
        CDEElementwiseOperation cde_element_op,
        EpilogueArgument& epilogue_args,
-       const index_t k_id = 0)
+       const index_t A_k_id = 0,
+       const index_t B_k_id = 0)
    {
        const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);

@@ -651,7 +654,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
            a_scale_struct,
            b_scale_struct,
            epilogue_args,
-           k_id);
+           A_k_id,
+           B_k_id);
    }

    template <bool HasMainKBlockLoop,

@@ -700,7 +704,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
        Argument& karg,
        const Block2CTileMap& block_2_ctile_map,
        EpilogueArgument& epilogue_args,
-       const index_t k_id = 0)
+       const index_t A_k_id = 0,
+       const index_t B_k_id = 0)
    {
        // shift A matrices pointer for splitk
        AsGridPointer p_as_grid_splitk;

@@ -735,7 +740,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
            karg.b_element_op,
            karg.cde_element_op,
            epilogue_args,
-           k_id);
+           A_k_id,
+           B_k_id);
    }

    // Wrapper function to have __global__ function in common

@@ -748,20 +754,146 @@ struct GridwiseGemm_wmma_cshuffle_v3
        const SplitKBatchOffset& splitk_batch_offset,
        Argument& karg,
        EpilogueArgument& epilogue_args,
-       const index_t k_id = 0)
+       const index_t A_k_id = 0,
+       const index_t B_k_id = 0)
    {
        Run<HasMainKBlockLoop,
            EGlobalMemoryDataOperation,
            TailNum,
            Block2CTileMap,
-           EpilogueArgument>(
-           p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args, k_id);
+           EpilogueArgument>(p_shared,
+                             splitk_batch_offset,
+                             karg,
+                             DefaultBlock2CTileMap(karg),
+                             epilogue_args,
+                             A_k_id,
+                             B_k_id);
    }

    __device__ static auto DefaultBlock2CTileMap(const Problem& problem)
    {
        return Block2CTileMap{problem.M, problem.N, 4};
    }

    // Run method for convolution (grid descriptors are passed as arguments,
    // not generated internally)
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              typename ComputePtrOffsetOfBatch,
              index_t NumGroupsToMerge,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum,
              typename EpilogueArgument>
    __device__ static void Run(void* p_shared,
                               const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
                               const index_t num_k_per_block,
                               Argument& karg,
                               EpilogueArgument& epilogue_args)
    {
        const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
        const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);

        const long_index_t a_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
        const long_index_t b_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
        const long_index_t e_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));

        AsGridPointer p_as_grid_;
        static_for<0, NumATensor, 1>{}([&](auto i) {
            using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
            p_as_grid_(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset;
        });

        BsGridPointer p_bs_grid_;
        static_for<0, NumBTensor, 1>{}([&](auto i) {
            using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
            p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
        });

        const auto ds_grid_desc_m_n =
            MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);

        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                ds_grid_desc_m_n, karg.MBlock, karg.NBlock);

        const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
            [&](auto i) {
                ignore = i;
                return a_grid_desc_ak0_m_ak1;
            },
            Number<NumATensor>{});

        const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
            [&](auto i) {
                ignore = i;
                return b_grid_desc_bk0_n_bk1;
            },
            Number<NumBTensor>{});

        // divide block work by [M, N]
        const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // Scale structs (Empty)
        using Scale = typename BlockwiseGemmPipe::Empty;
        auto b_scale_struct = Scale{};
        auto a_scale_struct = Scale{};

        const index_t num_k_block_per_scale = GetKBlockPerScale();

        Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
                           decltype(bs_grid_desc_bk0_n_bk1),
                           decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
                           decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                           decltype(a_scale_struct),
                           decltype(b_scale_struct),
                           decltype(epilogue_args),
                           HasMainKBlockLoop,
                           CGlobalMemoryDataOperation,
                           TailNum>(p_as_grid_,
                                    p_bs_grid_,
                                    karg.p_ds_grid,
                                    karg.p_e_grid + e_batch_offset,
                                    p_shared,
                                    as_grid_desc_ak0_m_ak1,
                                    bs_grid_desc_bk0_n_bk1,
                                    ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                    karg.a_element_op,
                                    karg.b_element_op,
                                    karg.cde_element_op,
                                    block_m_id,
                                    block_n_id,
                                    num_k_block_per_scale,
                                    a_scale_struct,
                                    b_scale_struct,
                                    epilogue_args,
                                    k_idx,
                                    k_idx,
                                    karg.KBatch);
    }
};

} // namespace ck
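
The convolution-oriented Run added above inverts the grid mapping used by the batched GEMM kernel at the top of this diff: groups ride on blockIdx.z (scaled by NumGroupsToMerge) and the split-K slice on blockIdx.y (scaled by num_k_per_block), with the K offset applied through the tensor-descriptor start indices (A_k_id/B_k_id) rather than by shifting pointers. A minimal sketch of just that index arithmetic (standalone, names illustrative):

#include <hip/hip_runtime.h>

__device__ inline void conv_block_indices(int num_groups_to_merge, int num_k_per_block,
                                          int& g_idx, int& k_idx)
{
    g_idx = blockIdx.z * num_groups_to_merge; // group / batch index
    k_idx = blockIdx.y * num_k_per_block;     // first K block of this split-K slice
}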

@@ -723,7 +723,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
        BElementwiseOperation b_element_op,
        CDEElementwiseOperation cde_element_op,
        EpilogueArgument& epilogue_args,
-       const index_t k_id = 0)
+       const index_t A_k_id = 0,
+       const index_t B_k_id = 0)
    {
        const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);

@@ -793,7 +794,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
            a_scale_struct,
            b_scale_struct,
            epilogue_args,
-           k_id);
+           A_k_id,
+           B_k_id);
    }

    // NOTE: Wrapper function to have __global__ function in common

@@ -806,7 +808,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
        const SplitKBatchOffset& splitk_batch_offset,
        Argument& karg,
        EpilogueArgument& epilogue_args,
-       const index_t k_id = 0)
+       const index_t A_k_id = 0,
+       const index_t B_k_id = 0)
    {
        // shift A matrices pointer for splitk
        AsGridPointer p_as_grid_splitk;

@@ -857,7 +860,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
            karg.b_element_op,
            karg.cde_element_op,
            epilogue_args,
-           k_id);
+           A_k_id,
+           B_k_id);
    }
};

@@ -101,7 +101,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
    auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};

    GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
-       p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
+       p_shared,
+       splitk_batch_offset,
+       karg,
+       epilogue_args,
+       0, /* A_k_id == 0 (we shift the pointer for splitk) */
+       k_id);

#if defined(__gfx11__)
    }

@@ -344,11 +349,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
    // return block_id to C matrix tile idx (m0, n0) mapping
    using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;

    // Calculate grid size taking into account splitk (KBatch)
    // 2D grid (x,z)
    __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
    {
        return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
    }

    // Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch)
    // 3D grid (x,y,z)
    __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
    {
        return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
    }
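    // Worked example (sizes illustrative): with MPerBlock = 256, NPerBlock = 128 and
    // M = N = 512, the tile grid has (512/256) * (512/128) = 8 blocks, so KBatch = 2 and
    // Batch = 16 yield dim3(8, 2, 16): split-K on Y and groups on Z, which is the mapping
    // the convolution Run method reads back via blockIdx.y / blockIdx.z.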

    __host__ static auto CalculateMPadded(index_t M)
    {
        return math::integer_least_multiple(M, MPerBlock);

@@ -706,8 +720,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
        ReduceTrait>;

    template <typename DEGridDesc>
-   __device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-       const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
+   __host__ __device__ static constexpr auto
+   MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc& de_grid_desc_m_n,
+                                                          index_t MBlock,
+                                                          index_t NBlock)
    {
        const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            de_grid_desc_m_n,

@@ -1004,6 +1020,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
        }
    }

    // Note: arguments k_batch and k_id should be set if splitk is used
    // with implicit gemm (no pointer shift but shift using tensor descriptors)
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,

@@ -1034,7 +1052,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
        AScaleStruct& a_scale_struct,
        BScaleStruct& b_scale_struct,
        EpilogueArgument& epilogue_args,
-       const index_t k_id = 0)
+       const index_t A_k_id = 0,
+       const index_t B_k_id = 0,
+       const index_t k_batch = 1)
    {
        const auto as_grid_buf = generate_tuple(
            [&](auto i) {

@@ -1066,7 +1086,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
            AsDataType,
            AElementwiseOperation,
            BlockwiseGemmPipe::GlobalBufferNum>(
-           as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id, k_id);
+           as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id, A_k_id);

        // B matrix blockwise copy
        auto b_blockwise_copy =

@@ -1075,7 +1095,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
            BsDataType,
            BElementwiseOperation,
            BlockwiseGemmPipe::GlobalBufferNum>(
-           bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id, k_id);
+           bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id, B_k_id);

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(

@@ -1100,7 +1120,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
-           ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / KPerBlock);
+           ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / (KPerBlock * k_batch));

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
            get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),

@@ -71,6 +71,29 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
    return vy.template AsType<float2_t>()[I0];
}

template <>
__device__ float4_t atomic_add<float4_t>(float4_t* p_dst, const float4_t& x)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const vector_type<float, 4> vx{x};
    vector_type<float, 4> vy{0};

    vy.template AsType<float>()(I0) =
        atomicAdd(c_style_pointer_cast<float*>(p_dst), vx.template AsType<float>()[I0]);
    vy.template AsType<float>()(I1) =
        atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, vx.template AsType<float>()[I1]);
    vy.template AsType<float>()(I2) =
        atomicAdd(c_style_pointer_cast<float*>(p_dst) + 2, vx.template AsType<float>()[I2]);
    vy.template AsType<float>()(I3) =
        atomicAdd(c_style_pointer_cast<float*>(p_dst) + 3, vx.template AsType<float>()[I3]);

    return vy.template AsType<float4_t>()[I0];
}
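
// Usage note: like the float2_t specialization above, this decomposes the packed vector into
// four independent 32-bit atomicAdd operations -- each word is updated atomically on its own,
// but the 128-bit write as a whole is not a single atomic transaction, and the returned
// float4_t packs the four previous values.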

template <>
__device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
{