mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 05:37:34 +00:00
Merge branch 'develop' into zan_fix_bufferloadlds
This commit is contained in:
@@ -244,12 +244,6 @@
|
||||
// workaround: compiler issue on gfx908
|
||||
#define CK_WORKAROUND_SWDEV_388832 1
|
||||
|
||||
// workaround: compiler issue on gfx950
|
||||
#define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1
|
||||
|
||||
// workaround: compiler issue on gfx950
|
||||
#define CK_TEMP_DISABLE_FP4_TESTS 1
|
||||
|
||||
// workaround: compiler issue on gfx950
|
||||
#define CK_WORKAROUND_FP16_TO_FP8_CONVERSION 1
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ struct HostTensorDescriptor
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
|
||||
std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
|
||||
std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
|
||||
{
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
@@ -600,12 +600,12 @@ struct Tensor
|
||||
ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
T& operator()(std::vector<std::size_t> idx)
|
||||
T& operator()(const std::vector<std::size_t>& idx)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
const T& operator()(std::vector<std::size_t> idx) const
|
||||
const T& operator()(const std::vector<std::size_t>& idx) const
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
|
||||
}
|
||||
|
||||
@@ -122,7 +122,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KGroup;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
@@ -154,9 +153,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack / KGroup;
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat * KGroup;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
@@ -291,14 +290,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
@@ -391,15 +388,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
@@ -483,14 +477,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
@@ -596,7 +588,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
|
||||
Sequence<1, 1, 1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
|
||||
@@ -0,0 +1,759 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename ComputePtrOffsetOfStridedBatch,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_batched_gemm_wmma_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument
|
||||
karg, // This works for now but it actually receives a
|
||||
// DeviceBatchedGemm_Wmma_CShuffleV3::Argument
|
||||
// argument through implicit conversion to base class!
|
||||
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
|
||||
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<c_data_type, ck::half_t> ||
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
// The normal approach to batching would be to increase the grid size by just stretching out
|
||||
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
|
||||
// functions not directly using the Z dimension for other calculations. As it turns out, k
|
||||
// batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
|
||||
// we will use the grid Y dimension for batching. This may be a bit fragile.
|
||||
const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t c_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
|
||||
p_shared,
|
||||
karg);
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = batch;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// @brief \"Universal\" Batched GEMM operation without SplitK support.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This GEMM operation implements the following mathematical equation:
|
||||
/// C{G,M,N} = C_op(A_op(A{G,M,K}) * B_op(B{G,K,N}))
|
||||
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
|
||||
/// elementwise operations applied to the A, B, and C tensors, respectively.
|
||||
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
|
||||
/// scenarios. That's why it's called \"universal\". It's universal through its design
|
||||
/// and versatilty.
|
||||
///
|
||||
/// @note This Kernel implementation currently does not support the SplitK algorithm.
|
||||
///
|
||||
/// @tparam ALayout A tensor data layout.
|
||||
/// @tparam BLayout B tensor data layout.
|
||||
/// @tparam CLayout C tensor data layout.
|
||||
/// @tparam ADataType A tensor data type.
|
||||
/// @tparam BDataType B tensor data type.
|
||||
/// @tparam CDataType C tensor data type.
|
||||
/// @tparam AccDataType The accumulation data type related to the hardware
|
||||
/// matrix-multiplication instruction.
|
||||
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
|
||||
/// LDS memory during \"CShuffle\" data layout optimization.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
|
||||
/// (after GEMM).
|
||||
/// @tparam GemmSpec Determines used "padding" version.
|
||||
/// @tparam BlockSize The number of threads within workgroup.
|
||||
/// @tparam MPerBlock The input/output data tile size in the M dimension.
|
||||
/// @tparam NPerBlock The input/output data tile size in the N dimension.
|
||||
/// @tparam KPerBlock The input data tile size in the K dimension.
|
||||
/// @tparam AK1 The vector load size from global memory for A tensor.
|
||||
/// @tparam BK1 The vector load size from global memory for B tensor.
|
||||
/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
|
||||
/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
|
||||
/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
|
||||
/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
|
||||
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question, "How many threads can be
|
||||
/// arranged on each input data axis?"
|
||||
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question: "How many threads to
|
||||
/// arrange on each input data axis?"
|
||||
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in M dimension.
|
||||
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in N dimension.
|
||||
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// thread distribution used for storing data into output
|
||||
/// tensor across output data layout dimensions.
|
||||
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
|
||||
/// Used when storing data to output tensor.
|
||||
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
|
||||
/// intrawave).
|
||||
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
|
||||
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory. Currently not supported!
|
||||
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory (pre-shuffled). Currently not supported!
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
// We are inheriting from DeviceBatchedGemm and this base class does not support permuteA and
|
||||
// permuteB arguments so for now we are not including this functionality.
|
||||
static_assert(PermuteA == false,
|
||||
"Permute A functionality not supported by DeviceBatchedGemm operations.\n");
|
||||
static_assert(PermuteB == false,
|
||||
"Permute B functionality not supported by DeviceBatchedGemm operations.\n");
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC)
|
||||
: BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false, // PermuteA not supported by DeviceBatchedGemm base class.
|
||||
false>; // PermuteB not supported by DeviceBatchedGemm base class.
|
||||
|
||||
// Argument
|
||||
struct Argument : public GridwiseGemm::Argument
|
||||
{
|
||||
__host__ Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t BatchStrideA_,
|
||||
index_t BatchStrideB_,
|
||||
index_t BatchStrideC_,
|
||||
index_t Batch_,
|
||||
index_t k_batch_,
|
||||
bool is_reduce_ = false)
|
||||
: GridwiseGemm::Argument(p_a_grid_,
|
||||
p_b_grid_,
|
||||
p_c_grid_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideC_,
|
||||
k_batch_,
|
||||
is_reduce_),
|
||||
Batch(Batch_),
|
||||
compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_}
|
||||
{
|
||||
}
|
||||
|
||||
index_t Batch;
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
|
||||
};
|
||||
|
||||
/// @brief Helper structure responsible for kernel invocation.
|
||||
///
|
||||
/// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU
|
||||
/// kernel function. It usually determines the launched grid size prepares kernel
|
||||
/// arguments as well as perform specific kernel configuration selection based on
|
||||
/// runtime arguments.
|
||||
///
|
||||
/// @note If appropriately configured it may measure kernel execution time.
|
||||
///
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
/// @brief This function issues GPU kernel execution.
|
||||
/// @param arg The GPU kernel arguments.
|
||||
/// @param stream_config The HIP stream configuration helper structure.
|
||||
/// @return The kernel's average execution time (if time measurement is
|
||||
/// enabled).
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
// The normal approach to batching would be to increase the grid size by just stretching
|
||||
// out the grid Z dimension (which is the outermost dimension), but this depends on
|
||||
// lower level functions not directly using the Z dimension for other calculations. As
|
||||
// it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset.
|
||||
// Therefore, for now we will use the grid Y dimension for batching. This may be a bit
|
||||
// fragile.
|
||||
gdy *= arg.Batch;
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
// Packed sizes are 1 for all implemented data types but we include it anyway
|
||||
// for future compatibility.
|
||||
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
|
||||
sizeof(ADataType) / GridwiseGemm::APackedSize;
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / GridwiseGemm::BPackedSize;
|
||||
|
||||
// Note: the grid descriptors and size_a / size_b do *not* take batching into
|
||||
// account, so we have to manually multiply overall buffer sizes for rotating
|
||||
// memory by batch.
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
arg_,
|
||||
stream_config.rotating_count,
|
||||
arg_.Batch * size_a_buffer,
|
||||
arg_.Batch * size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
// Note: we multiply by batch since we want to clear the C matrix for
|
||||
// the whole batch. Untested since we don't have k batching ATM.
|
||||
// Note: This seems incorrect for non-contiguous memory layouts for C
|
||||
// (padding, gaps).
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.Batch * arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_,
|
||||
arg_.compute_ptr_offset_of_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto clear_workspace = [&]() {
|
||||
// clear c mem
|
||||
if(arg.KBatch > 1)
|
||||
// Note: we multiply by batch since we want to clear the C matrix for
|
||||
// the whole batch. Untested since we don't have k batching ATM.
|
||||
// Note: This seems incorrect for non-contiguous memory layouts for C
|
||||
// (padding, gaps).
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.Batch * arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg,
|
||||
arg.compute_ptr_offset_of_batch);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy = []() {
|
||||
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<ComputePtrOffsetOfStridedBatch>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: Implement
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<ComputePtrOffsetOfStridedBatch>,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CDataType, ck::half_t> ||
|
||||
std::is_same_v<CDataType, ck::bhalf_t>)
|
||||
{
|
||||
if(arg.KBatch > 1 && ck::is_gfx11_supported())
|
||||
{
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
|
||||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
// TODO: This is not part of the DeviceBatchedGemm base class but it was part of
|
||||
// DeviceBatchedGemmV2. Remove?
|
||||
// index_t GetKPerBlock() override { return KPerBlock; }
|
||||
// bool GetPermuteA() override { return PermuteA; }
|
||||
// bool GetPermuteB() override { return PermuteB; }
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t Batch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
Batch,
|
||||
1 /* KBatch */};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t Batch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
Batch,
|
||||
1); // KBatch
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemm_Wmma_CShuffleV3"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerWmma << "x"<<NPerWmma << ", "
|
||||
<< "WaveMap: "
|
||||
<< MRepeat << "x" << NRepeat << ", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "KPack: "
|
||||
<< GridwiseGemm::KPack;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
REGISTER_EXTRA_PRINTING_METHODS
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -77,7 +77,8 @@ template <typename GridwiseGemm,
|
||||
typename ComputePtrOffsetOfN,
|
||||
bool HasMainKBlockLoop,
|
||||
bool isMultiA,
|
||||
bool isMultiB>
|
||||
bool isMultiB,
|
||||
bool CTranspose>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -171,17 +172,22 @@ __global__ void
|
||||
}
|
||||
else
|
||||
{
|
||||
const long_index_t a_group_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_group_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
|
||||
|
||||
CTranspose
|
||||
? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx))
|
||||
: amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
|
||||
const long_index_t a_group_offset =
|
||||
CTranspose
|
||||
? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx))
|
||||
: amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_n_offset =
|
||||
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
|
||||
const long_index_t a_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
|
||||
p_as_grid + a_group_offset + a_n_offset,
|
||||
p_bs_grid + b_group_offset,
|
||||
p_bs_grid + b_group_offset + b_n_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_group_offset + e_n_offset,
|
||||
p_shared,
|
||||
@@ -335,12 +341,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
static constexpr bool isATensorColMajor =
|
||||
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) &&
|
||||
(ABlockTransferSrcVectorDim == 1) && (NumGroupsToMerge == 1) &&
|
||||
(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>());
|
||||
|
||||
static constexpr bool NeedTransposeKernel =
|
||||
(isATensorColMajor == false) && (is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>());
|
||||
|
||||
static constexpr bool CTranspose = (NeedTransposeKernel == false) && (isMultiAB == false) &&
|
||||
(is_same_v<ELayout, tensor_layout::convolution::NGKHW> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGKDHW>);
|
||||
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
true /*SplitN*/,
|
||||
ADataType,
|
||||
EDataType,
|
||||
NumGroupsToMerge>;
|
||||
NumGroupsToMerge,
|
||||
index_t,
|
||||
CTranspose>;
|
||||
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
@@ -361,9 +383,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGC, ALay>>;
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::NDHWGC,
|
||||
ALay>>;
|
||||
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
|
||||
@@ -379,9 +403,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::GKYXC,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::GKZYXC, BLay>>;
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::GKZYXC,
|
||||
BLay>>;
|
||||
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
|
||||
@@ -397,17 +423,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGK, ELay>>;
|
||||
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
|
||||
ctc::NDHWGK,
|
||||
ELay>>;
|
||||
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
|
||||
return out_gemmm_gemmn_desc;
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
constexpr auto matrix_padder_trans =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{NPerBlock, MPerBlock, KPerBlock};
|
||||
return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
}
|
||||
}
|
||||
|
||||
// Shape of Ds and E must be aligned. Strides can be different.
|
||||
@@ -471,11 +504,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
|
||||
#define GridwiseGemmCTransposeTemplateParameters \
|
||||
GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
|
||||
NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
|
||||
MPerXDL, NXdlPerWave, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
|
||||
BComputeDataType, DoElementwiseBeforeCShuffle
|
||||
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = std::conditional_t<
|
||||
isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmMultiABDTemplateParameters>,
|
||||
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
|
||||
using GridwiseGemmCTranspose = std::conditional_t<
|
||||
CTranspose,
|
||||
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>,
|
||||
GridwiseGemm>;
|
||||
|
||||
// If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers.
|
||||
using APointers =
|
||||
@@ -497,15 +551,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
remove_cvref_t<decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(
|
||||
EGridDesc_M_N{}))>;
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
|
||||
|
||||
using NGCHWTransposeDescType =
|
||||
@@ -612,16 +667,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
|
||||
a_g_n_c_wis_strides_{NeedTransposeKernel
|
||||
? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)
|
||||
: a_g_n_c_wis_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
|
||||
b_g_k_c_xs_strides_{NeedTransposeKernel
|
||||
? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
|
||||
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
|
||||
: b_g_k_c_xs_strides},
|
||||
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
|
||||
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
|
||||
e_g_n_k_wos_strides_{NeedTransposeKernel
|
||||
? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)
|
||||
: e_g_n_k_wos_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
@@ -651,7 +712,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
block_2_etile_map_{
|
||||
GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
compute_ptr_offset_of_groups_{},
|
||||
compute_ptr_offset_of_n_{},
|
||||
a_element_op_{a_element_op},
|
||||
@@ -783,24 +845,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
bool valid = false;
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_,
|
||||
a_grid_desc_m_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_);
|
||||
}
|
||||
if(valid)
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_);
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
@@ -835,8 +907,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -851,8 +922,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
|
||||
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -867,8 +937,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -1007,7 +1076,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB>;
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
@@ -1035,68 +1105,118 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
|
||||
const BDataType* p_b_grid = arg.p_bs_grid_.At(I0);
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
}
|
||||
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
const ADataType*,
|
||||
const BDataType*,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB>;
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemmCTranspose,
|
||||
const BDataType*,
|
||||
const ADataType*,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
BElementwiseOperation,
|
||||
AElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_b_grid,
|
||||
p_a_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.b_element_op_,
|
||||
arg.a_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
const ADataType*,
|
||||
const BDataType*,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB,
|
||||
CTranspose>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_groups_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1114,8 +1234,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const index_t a_grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
@@ -1166,8 +1285,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
avg_time += RunGemm(arg, stream_config);
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
@@ -1215,9 +1333,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
|
||||
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
|
||||
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
|
||||
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
|
||||
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
|
||||
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
// check device
|
||||
if(get_device_name() == "gfx908")
|
||||
@@ -1310,7 +1430,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
|
||||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
|
||||
is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
|
||||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
|
||||
NeedTransposeKernel)
|
||||
{
|
||||
// Check access per C
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
|
||||
@@ -1326,6 +1446,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
|
||||
{
|
||||
static_assert(NeedTransposeKernel == false);
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
|
||||
if constexpr(ABlockTransferSrcScalarPerVector != 1)
|
||||
{
|
||||
if(ABlockTransferSrcVectorDim != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(input_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
@@ -1350,7 +1487,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of Ds
|
||||
bool valid = true;
|
||||
|
||||
@@ -1396,8 +1532,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
if constexpr(NeedTransposeKernel)
|
||||
{
|
||||
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
@@ -1409,8 +1544,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
@@ -1457,9 +1590,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
|
||||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
if(CTranspose == false)
|
||||
{
|
||||
return false;
|
||||
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1483,11 +1629,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return GridwiseGemmCTranspose::CheckValidity(arg.b_grid_desc_n_k_,
|
||||
arg.a_grid_desc_m_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return GridwiseGemmCTranspose::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1473,7 +1473,12 @@ struct GridwiseMoeGemm
|
||||
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
|
||||
const index_t token_offset = fused_token & 0xffffff;
|
||||
return token_offset < problem.NumTokens
|
||||
? p_sorted_weights_0[token_offset]
|
||||
? p_sorted_weights_0[IsInputGemm
|
||||
? token_offset
|
||||
: token_offset *
|
||||
problem.TopK +
|
||||
(fused_token >>
|
||||
24)]
|
||||
: 0.0;
|
||||
}
|
||||
else
|
||||
@@ -2190,7 +2195,12 @@ struct GridwiseMoeGemm
|
||||
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
|
||||
const index_t token_offset = fused_token & 0xffffff;
|
||||
return token_offset < problem.NumTokens
|
||||
? p_sorted_weights_0[token_offset]
|
||||
? p_sorted_weights_0[IsInputGemm
|
||||
? token_offset
|
||||
: token_offset *
|
||||
problem.TopK +
|
||||
(fused_token >>
|
||||
24)]
|
||||
: 0.0;
|
||||
}
|
||||
else
|
||||
|
||||
@@ -19,7 +19,8 @@ template <index_t NDimSpatial,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
index_t NumGroupsToMerge = 1,
|
||||
typename IndexType = index_t>
|
||||
typename IndexType = index_t,
|
||||
bool CTranspose = false>
|
||||
struct TransformConvFwdToGemm
|
||||
{
|
||||
private:
|
||||
@@ -1253,6 +1254,83 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename ck::enable_if<NDimSpatial == 1 &&
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGCW>,
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeADescriptor_M_K() const
|
||||
{
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
static_assert(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);
|
||||
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_)), make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename ck::enable_if<NDimSpatial == 2 &&
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGCHW>,
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeADescriptor_M_K() const
|
||||
{
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
static_assert(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);
|
||||
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Ho_ * Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename ck::enable_if<NDimSpatial == 3 &&
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGCDHW>,
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeADescriptor_M_K() const
|
||||
{
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
static_assert(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);
|
||||
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Do_ * Ho_ * Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <typename BLayout,
|
||||
typename ck::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKCX> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::GKCYX> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::GKCZYX>,
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeBDescriptor_N_K() const
|
||||
{
|
||||
static_assert(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
|
||||
ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Pad0);
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, C_));
|
||||
}
|
||||
|
||||
template <typename BLayout,
|
||||
typename ck::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
|
||||
@@ -1338,8 +1416,16 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
|
||||
make_tuple(KStrideTensorC_, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -1350,8 +1436,16 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
|
||||
make_tuple(KStrideTensorC_, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -1362,12 +1456,21 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
|
||||
make_tuple(KStrideTensorC_, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
|
||||
make_tuple(I0, KStrideTensorC_));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename ck::enable_if<NDimSp == 1 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
|
||||
@@ -1375,6 +1478,7 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
static_assert(CTranspose == false);
|
||||
const IndexType NDoHoWo = N_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
@@ -1429,6 +1533,7 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
static_assert(CTranspose == false);
|
||||
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
@@ -1486,7 +1591,7 @@ struct TransformConvFwdToGemm
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
|
||||
static_assert(CTranspose == false);
|
||||
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
@@ -1536,6 +1641,101 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename ck::enable_if<NDimSp == 1 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::GNKW> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NGKW>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
auto n_k_wo_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, K_, Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1));
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
n_k_wo_desc,
|
||||
make_tuple(make_pass_through_transform(K_),
|
||||
make_merge_transform(make_tuple(N_, Wo_))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(n_k_wo_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename ck::enable_if<NDimSp == 2 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::GNKHW> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NGKHW>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
auto n_k_howo_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, K_, Ho_ * Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1));
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
n_k_howo_desc,
|
||||
make_tuple(make_pass_through_transform(K_),
|
||||
make_merge_transform(make_tuple(N_, Ho_ * Wo_))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
n_k_howo_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CLayout,
|
||||
index_t NDimSp = NDimSpatial,
|
||||
|
||||
typename ck::enable_if<NDimSp == 3 &&
|
||||
(is_same_v<CLayout, tensor_layout::convolution::GNKDHW> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NGKDHW>),
|
||||
bool>::type = false>
|
||||
__host__ __device__ auto MakeCDescriptor_M_N() const
|
||||
{
|
||||
static_assert(NumGroupsToMerge == 1);
|
||||
auto n_k_dohowo_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, K_, Do_ * Ho_ * Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1));
|
||||
|
||||
if constexpr(CTranspose)
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
n_k_dohowo_desc,
|
||||
make_tuple(make_pass_through_transform(K_),
|
||||
make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
n_k_dohowo_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
IndexType N_;
|
||||
IndexType Di_, Hi_, Wi_;
|
||||
IndexType Do_, Ho_, Wo_;
|
||||
|
||||
@@ -5,19 +5,16 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/enable_if.hpp"
|
||||
#include "ck/utility/get_id.hpp"
|
||||
#include "ck/utility/random_gen.hpp"
|
||||
#include "ck/utility/functional.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
|
||||
#ifdef CK_USE_FNUZ_FP8
|
||||
#define CK_USE_FNUZ_FP8 1
|
||||
#else
|
||||
#ifndef CK_USE_FNUZ_FP8
|
||||
#define CK_USE_FNUZ_FP8 0
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_OCP_FP8
|
||||
#define CK_USE_OCP_FP8 1
|
||||
#else
|
||||
#ifndef CK_USE_OCP_FP8
|
||||
#define CK_USE_OCP_FP8 0
|
||||
#endif
|
||||
|
||||
@@ -431,7 +428,7 @@ __host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a)
|
||||
namespace fp8_impl {
|
||||
|
||||
// Assertions to check for supported conversion types
|
||||
#define __assert_ocp_support(interp) \
|
||||
#define __fp8_impl_assert_ocp_support(interp) \
|
||||
{ \
|
||||
if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \
|
||||
interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \
|
||||
@@ -439,7 +436,7 @@ namespace fp8_impl {
|
||||
__hip_assert(false && "type is unsupported by current target device"); \
|
||||
} \
|
||||
}
|
||||
#define __assert_fnuz_support(interp) \
|
||||
#define __fp8_impl_assert_fnuz_support(interp) \
|
||||
{ \
|
||||
if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \
|
||||
interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \
|
||||
@@ -453,10 +450,10 @@ __is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp)
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
|
||||
#if CK_USE_OCP_FP8
|
||||
__assert_ocp_support(interp);
|
||||
__fp8_impl_assert_ocp_support(interp);
|
||||
#endif
|
||||
#if CK_USE_FNUZ_FP8
|
||||
__assert_fnuz_support(interp);
|
||||
__fp8_impl_assert_fnuz_support(interp);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
@@ -1396,12 +1393,18 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
|
||||
f, rng);
|
||||
@@ -1416,12 +1419,18 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
|
||||
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
|
||||
@@ -1487,12 +1496,18 @@ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f[0]);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
|
||||
f, rng);
|
||||
@@ -1532,12 +1547,18 @@ __host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
return cast_to_f8_from_f16<interp,
|
||||
@@ -1574,12 +1595,18 @@ __host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
return cast_to_f8_from_f16<interp,
|
||||
@@ -1616,13 +1643,19 @@ __host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
|
||||
static_cast<float>(x));
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), static_cast<float>(x));
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
return cast_to_f8_from_bf16<interp,
|
||||
@@ -1664,14 +1697,20 @@ __host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
|
||||
static_cast<float>(x[0]));
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x),
|
||||
static_cast<float>(x[0]));
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
return cast_to_f8_from_bf16<interp,
|
||||
|
||||
66
include/ck/utility/filter_tuple.hpp
Normal file
66
include/ck/utility/filter_tuple.hpp
Normal file
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "ck/utility/functional.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck::util {
|
||||
|
||||
template <typename Tuple, std::size_t Stride, std::size_t Offset>
|
||||
struct filter_tuple_by_modulo
|
||||
{
|
||||
// Validate Stride and Offset.
|
||||
static_assert(Stride > 0, "Offset must be positive.");
|
||||
static_assert(Offset >= 0 && Offset < Stride,
|
||||
"Offset must be positive and less than the stride.");
|
||||
|
||||
// Generate filtered indices for this stride and offset.
|
||||
static constexpr int new_size = (std::tuple_size_v<Tuple> + Stride - Offset - 1) / Stride;
|
||||
|
||||
template <std::size_t... Is>
|
||||
static constexpr auto to_index(std::index_sequence<Is...>)
|
||||
{
|
||||
return std::index_sequence<(Offset + Is * Stride)...>{};
|
||||
}
|
||||
|
||||
using filtered_indices = decltype(to_index(std::make_index_sequence<new_size>{}));
|
||||
|
||||
// Helper struct to construct the new tuple type from the filtered indices.
|
||||
template <typename T, typename Indices>
|
||||
struct make_filtered_tuple_type_impl;
|
||||
|
||||
template <typename T, std::size_t... Is>
|
||||
struct make_filtered_tuple_type_impl<T, std::index_sequence<Is...>>
|
||||
{
|
||||
using type = std::tuple<std::tuple_element_t<Is, T>...>;
|
||||
};
|
||||
|
||||
using type = typename make_filtered_tuple_type_impl<Tuple, filtered_indices>::type;
|
||||
};
|
||||
|
||||
// Filter a tuple with a stride and offset.
|
||||
//
|
||||
// Tuple is a std::tuple or equivalent
|
||||
// Stride is a positive integer
|
||||
// Offset is a positive integer smaller than ofset
|
||||
//
|
||||
// Evaluates to a smaller tuple type from elements of T with stride M and offset I.
|
||||
//
|
||||
// Can be used to filter a tuple of types for sharded instantiations.
|
||||
template <typename Tuple, std::size_t Stride, std::size_t Offset>
|
||||
using filter_tuple_by_modulo_t = typename filter_tuple_by_modulo<Tuple, Stride, Offset>::type;
|
||||
|
||||
// Example compile-time test:
|
||||
// using OriginalTuple =
|
||||
// std::tuple<int, double, char, float, long, short, bool, char, long long, unsigned int>;
|
||||
// using NewTuple_Every3rdFrom2nd = filter_tuple_by_modulo_t<OriginalTuple, 3, 1>;
|
||||
// static_assert(std::is_same_v<NewTuple_Every3rdFrom2nd, std::tuple<double, long, char>>,
|
||||
// "Test Case 1 Failed: Every 3rd from 2nd");
|
||||
|
||||
} // namespace ck::util
|
||||
@@ -197,8 +197,9 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const fl
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
}
|
||||
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
|
||||
}
|
||||
@@ -221,8 +222,9 @@ __host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
}
|
||||
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/f8_utils.hpp"
|
||||
#include "ck/utility/get_id.hpp"
|
||||
#include "ck/utility/mxf4_utils.hpp"
|
||||
#include "ck/utility/mxf6_utils.hpp"
|
||||
#include "ck/utility/random_gen.hpp"
|
||||
@@ -234,12 +235,18 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
|
||||
template <>
|
||||
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
@@ -296,12 +303,18 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
|
||||
template <>
|
||||
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
@@ -1446,13 +1459,10 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0
|
||||
// convert fp32 to fp4 with stochastic rounding
|
||||
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
@@ -1468,6 +1478,12 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
|
||||
value.bitwise, float_values.float2_array, rng, scale, 0);
|
||||
return value.f4_array[0];
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
|
||||
#endif
|
||||
}
|
||||
@@ -1475,30 +1491,26 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
|
||||
// convert vector of 2 fp32 to vector of 2 fp4 with sr
|
||||
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
f4x2_t f4x2_array[4];
|
||||
} value{0};
|
||||
// apply a temporary workaround for gfx950
|
||||
#if CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION
|
||||
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
|
||||
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
|
||||
value.bitwise = (h << 4) | l;
|
||||
#else
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0);
|
||||
#endif // CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION
|
||||
return value.f4x2_array[0];
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
@@ -1514,13 +1526,10 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
|
||||
// convert vector of 32 fp32 to vector of 32 fp4 with sr
|
||||
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
union
|
||||
{
|
||||
__uint128_t bitwise;
|
||||
@@ -1546,6 +1555,12 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f
|
||||
|
||||
return f4_values.f4x32_array;
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
union
|
||||
{
|
||||
__uint128_t bitwise;
|
||||
@@ -1776,13 +1791,10 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0
|
||||
*/
|
||||
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
union
|
||||
{
|
||||
float32_t float_vector;
|
||||
@@ -1799,6 +1811,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
|
||||
|
||||
return out.f6_array[0];
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
|
||||
#endif
|
||||
}
|
||||
@@ -1815,6 +1833,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
|
||||
*/
|
||||
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
union
|
||||
{
|
||||
@@ -1828,9 +1852,6 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f
|
||||
uint32_t rng =
|
||||
prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
|
||||
#else
|
||||
union
|
||||
{
|
||||
float32_t float_vector;
|
||||
@@ -2044,13 +2065,10 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1
|
||||
*/
|
||||
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
union
|
||||
{
|
||||
float32_t float_vector;
|
||||
@@ -2067,6 +2085,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
|
||||
|
||||
return out.bf6_array[0];
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
|
||||
#endif
|
||||
}
|
||||
@@ -2085,6 +2109,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
|
||||
*/
|
||||
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
union
|
||||
{
|
||||
@@ -2098,9 +2128,6 @@ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.
|
||||
uint32_t rng =
|
||||
prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
|
||||
#else
|
||||
union
|
||||
{
|
||||
float32_t float_vector;
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using index_t = int32_t;
|
||||
using int32_t = int32_t;
|
||||
using long_index_t = int64_t;
|
||||
using int8_t = int8_t;
|
||||
|
||||
|
||||
@@ -1009,6 +1009,15 @@ struct buffer_view<address_space_enum::lds,
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
// int8 on thread buffer
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
|
||||
// ext_vector_type for pk_int4 must use int8_t as type
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
|
||||
@@ -1031,6 +1040,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>))
|
||||
{
|
||||
@@ -1041,6 +1052,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>))
|
||||
{
|
||||
@@ -1051,6 +1064,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
|
||||
{
|
||||
@@ -1061,6 +1076,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -129,7 +129,10 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_in ? static_cast<index_t>(idx_y_start[ii]) + i
|
||||
: static_cast<index_t>(idx_y_start[ii]);
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_y_out =
|
||||
|
||||
@@ -314,8 +314,7 @@ struct tile_window_linear
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
auto dst_tensor =
|
||||
make_static_distributed_tensor<typename Base::DataTypeDataType>(tile_dstr);
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
@@ -348,8 +347,9 @@ struct tile_window_linear
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() = vec_value.template get_as<
|
||||
typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize];
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value
|
||||
.template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
|
||||
});
|
||||
};
|
||||
|
||||
@@ -400,8 +400,9 @@ struct tile_window_linear
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() = vec_value.template get_as<
|
||||
typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize];
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value
|
||||
.template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
|
||||
});
|
||||
};
|
||||
|
||||
@@ -805,8 +806,7 @@ struct tile_window_linear
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<typename Base::DataTypeDataType>()(
|
||||
j / Base::Traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
@@ -861,8 +861,7 @@ struct tile_window_linear
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<typename Base::DataTypeDataType>()(
|
||||
j / Base::Traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
|
||||
@@ -230,7 +230,7 @@ struct HostTensorDescriptor
|
||||
* @param iss Vector containing the multi-dimensional indices
|
||||
* @return The calculated linear offset as a size_t
|
||||
*/
|
||||
std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
|
||||
std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
|
||||
{
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
@@ -540,9 +540,12 @@ struct HostTensor
|
||||
return mData[GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
|
||||
T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }
|
||||
T& operator()(const std::vector<std::size_t>& idx)
|
||||
{
|
||||
return mData[GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
|
||||
const T& operator()(std::vector<std::size_t> idx) const
|
||||
const T& operator()(const std::vector<std::size_t>& idx) const
|
||||
{
|
||||
return mData[GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
@@ -719,6 +722,8 @@ struct HostTensor
|
||||
file << type_convert<float>(itm) << std::endl;
|
||||
else if(dtype == "int")
|
||||
file << type_convert<int>(itm) << std::endl;
|
||||
else if(dtype == "int8_t")
|
||||
file << static_cast<int>(type_convert<ck_tile::int8_t>(itm)) << std::endl;
|
||||
else
|
||||
// TODO: we didn't implement operator<< for all custom
|
||||
// data types, here fall back to float in case compile error
|
||||
|
||||
@@ -75,7 +75,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
#if defined(USING_MFMA_16x16x32) || defined(USING_MFMA_32x32x16)
|
||||
constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
@@ -91,64 +90,68 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad;
|
||||
constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp;
|
||||
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
|
||||
#endif
|
||||
#if defined(USING_MFMA_16x16x32)
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
});
|
||||
|
||||
#elif defined(USING_MFMA_32x32x16)
|
||||
static_for<0,
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
#endif
|
||||
if constexpr(WG::kM == 16 && WG::kN == 16)
|
||||
{
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
});
|
||||
}
|
||||
else if constexpr(WG::kM == 32 && WG::kN == 32 &&
|
||||
(A_LDS_Read_Inst_Num / 2 >
|
||||
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
|
||||
{
|
||||
static_for<0,
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
|
||||
@@ -19,55 +19,61 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
#if defined(USING_MFMA_16x16x32)
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
if constexpr(MPerXdl == 16 && NPerXdl == 16)
|
||||
{
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
#elif defined(USING_MFMA_32x32x16)
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
return a_lds_block_desc;
|
||||
#endif
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
/*xor*/
|
||||
#if 0
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
@@ -138,6 +144,21 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
if constexpr(TileShape::WarpTile::at(TileShape::idxN) == 32)
|
||||
{
|
||||
return TileShape::WarpTile::at(TileShape::idxK) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(TileShape::WarpTile::at(TileShape::idxN) == 16);
|
||||
return TileShape::WarpTile::at(TileShape::idxK) / 4;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
@@ -189,7 +210,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
@@ -232,19 +253,17 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad =
|
||||
Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
|
||||
@@ -316,56 +316,56 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const void* kv_last_page_lens,
|
||||
ck_tile::index_t page_block_size,
|
||||
#endif
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -468,51 +468,51 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const void* kv_last_page_lens,
|
||||
ck_tile::index_t page_block_size,
|
||||
#endif
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
|
||||
@@ -808,6 +808,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
@@ -847,7 +848,7 @@ struct FmhaFwdKernel
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
0, // min_seqlen_q
|
||||
min_seqlen_q,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
@@ -890,6 +891,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset)
|
||||
@@ -929,6 +931,7 @@ struct FmhaFwdKernel
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
min_seqlen_q,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -787,12 +787,29 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t N0 = kNPerBlock / N1; // P
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
|
||||
{
|
||||
constexpr index_t kNPack = 32;
|
||||
static_assert(kNPerBlock % kNPack == 0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = 2;
|
||||
constexpr index_t N1_m = kNPack / N2;
|
||||
constexpr index_t N0_m = kNPerBlock / kNPack;
|
||||
constexpr index_t K1 = get_warp_size() / N1_m;
|
||||
constexpr index_t K2_m = kKPerBlock / K1;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<1, 2, 1>, // N0 K2 N2
|
||||
sequence<0, 2, 2>>{});
|
||||
}
|
||||
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
@@ -860,12 +877,28 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
|
||||
{
|
||||
constexpr index_t kNPack = 32;
|
||||
static_assert(kNPerBlock % kNPack == 0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = 2;
|
||||
constexpr index_t N1_m = kNPack / N2;
|
||||
constexpr index_t N0_m = kNPerBlock / kNPack;
|
||||
constexpr index_t K1 = get_warp_size() / N1_m;
|
||||
constexpr index_t K2_m = kKPerBlock / K1;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<1, 1, 2>, // N0 K2 <-> N2
|
||||
sequence<0, 2, 2>>{});
|
||||
}
|
||||
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
|
||||
@@ -101,7 +101,7 @@ struct FusedMoeGemmShape
|
||||
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
|
||||
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;
|
||||
|
||||
static constexpr index_t BlockSize = WarpSize * NumWarps;
|
||||
static constexpr index_t BlockSize = get_warp_size() * NumWarps;
|
||||
|
||||
// some assert
|
||||
static_assert(Block_M0 == Block_M1);
|
||||
|
||||
@@ -388,7 +388,7 @@ struct MoeSortingKernel
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
template <typename T, typename F, index_t wave_size_ = WarpSize>
|
||||
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
|
||||
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
@@ -625,7 +625,7 @@ struct MoeSortingKernel
|
||||
{
|
||||
const index_t prefill_token = topk_mdiv.div(numel);
|
||||
// TODO: only support expert-tile like 8, 16, 32
|
||||
static constexpr index_t experts_per_wave = WarpSize / Problem::ExpertTile;
|
||||
static constexpr index_t experts_per_wave = get_warp_size() / Problem::ExpertTile;
|
||||
{
|
||||
index_t eid = tid / experts_per_wave;
|
||||
index_t expert_offset = cumsum[eid] +
|
||||
@@ -693,7 +693,7 @@ struct MoeSortingKernel
|
||||
void* smem) const
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
const index_t wid = __builtin_amdgcn_readfirstlane(tid / WarpSize);
|
||||
const index_t wid = __builtin_amdgcn_readfirstlane(tid / get_warp_size());
|
||||
const index_t lid = __lane_id();
|
||||
constexpr index_t block_size = 256; // blockDim.x;
|
||||
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
|
||||
@@ -798,7 +798,7 @@ struct MoeSortingKernel
|
||||
// NOTE: under this block can never use __syncthreads!
|
||||
int i_e_ = 0;
|
||||
int local_cumsum_ = 0;
|
||||
for(; i_e_ < num_experts; i_e_ += WarpSize)
|
||||
for(; i_e_ < num_experts; i_e_ += get_warp_size())
|
||||
{
|
||||
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
|
||||
int local_cnt = smem_cumsum(i_e_ + lid + 1);
|
||||
@@ -843,7 +843,7 @@ struct MoeSortingKernel
|
||||
// cumsum padded in case local cumsum is zero, but
|
||||
// pre_sumsum has value, which will result int
|
||||
// zero local cumsum(but we want at least padded)
|
||||
wave_cumsum<int, WarpSize>(local_cumsum_);
|
||||
wave_cumsum<int, get_warp_size()>(local_cumsum_);
|
||||
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
|
||||
@@ -851,7 +851,7 @@ struct MoeSortingKernel
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
local_masking += pre_cumsum_masking;
|
||||
wave_cumsum<int, WarpSize>(local_masking);
|
||||
wave_cumsum<int, get_warp_size()>(local_masking);
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumdup(i_e_ + lid + 1) = local_masking;
|
||||
}
|
||||
@@ -861,7 +861,7 @@ struct MoeSortingKernel
|
||||
// than 0(which is not we want)
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
}
|
||||
if((lid + i_e_ - WarpSize) == (num_experts - 1))
|
||||
if((lid + i_e_ - get_warp_size()) == (num_experts - 1))
|
||||
{
|
||||
*p_total_tokens_post_pad = local_cumsum_;
|
||||
}
|
||||
@@ -1109,7 +1109,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
|
||||
return chunk * sizeof(index_t);
|
||||
};
|
||||
|
||||
template <typename T, typename F, index_t wave_size_ = WarpSize>
|
||||
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
|
||||
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
@@ -1504,7 +1504,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
|
||||
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1546,8 +1546,8 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -1560,7 +1560,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
|
||||
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -1660,7 +1660,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
|
||||
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1786,8 +1786,8 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -1801,7 +1801,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
|
||||
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -1880,7 +1880,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
// return 2 * BLOCK_SIZE * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / WarpSize) * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
@@ -1905,8 +1905,8 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
|
||||
IndexType prev_cumsum_a = 0;
|
||||
IndexType prev_cumsum_b = 0;
|
||||
@@ -1951,22 +1951,22 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
IndexType cumsum_b = b_;
|
||||
|
||||
// Note: we first cumsum local round, then add previous cumsum
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_b);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == WarpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -2083,7 +2083,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return (4 + BLOCK_SIZE / WarpSize) * sizeof(IndexType);
|
||||
return (4 + BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -2110,8 +2110,8 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
}
|
||||
}();
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / WarpSize;
|
||||
int lane_id = threadIdx.x % WarpSize;
|
||||
int wave_id = threadIdx.x / get_warp_size();
|
||||
int lane_id = threadIdx.x % get_warp_size();
|
||||
int e_start = p_expert_cumsum[eid];
|
||||
int e_end = p_expert_cumsum[eid + 1];
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
@@ -2141,17 +2141,17 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == WarpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2196,7 +2196,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
|
||||
{
|
||||
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
|
||||
const index_t expert_cumsum_elem = num_experts_ + 1;
|
||||
return (4 + 2 * BLOCK_SIZE / WarpSize + expert_cumsum_elem) * sizeof(int);
|
||||
return (4 + 2 * BLOCK_SIZE / get_warp_size() + expert_cumsum_elem) * sizeof(int);
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
@@ -2303,15 +2303,15 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
|
||||
IndexType* p_total_tokens_post_pad =
|
||||
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
|
||||
IndexType* p_sorted_expert_ids =
|
||||
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
|
||||
IndexType prev_cumsum_a = 0;
|
||||
IndexType prev_cumsum_b = 0;
|
||||
@@ -2356,22 +2356,22 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
IndexType cumsum_b = b_;
|
||||
|
||||
// Note: we first cumsum local round, then add previous cumsum
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_b);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == WarpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -2441,13 +2441,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / WarpSize;
|
||||
int lane_id = threadIdx.x % WarpSize;
|
||||
int wave_id = threadIdx.x / get_warp_size();
|
||||
int lane_id = threadIdx.x % get_warp_size();
|
||||
int e_start = p_expert_cumsum_smem[eid];
|
||||
int e_end = p_expert_cumsum_smem[eid + 1];
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
@@ -2518,17 +2518,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == WarpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2569,17 +2569,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
cumsum_store += i_show[j];
|
||||
});
|
||||
int cumsum = cumsum_store;
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == WarpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2624,17 +2624,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
int i_topk_1 = x1 - 1; // topk of this token
|
||||
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show_0 + i_show_1;
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == WarpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
|
||||
@@ -215,7 +215,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
|
||||
@@ -59,14 +59,23 @@ struct GemmHostArgs
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
void* e_ptr;
|
||||
union
|
||||
{
|
||||
void* e_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
index_t stride_E;
|
||||
union
|
||||
{
|
||||
index_t stride_E;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
#if defined(__gfx950__)
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
#else
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
@@ -282,4 +282,19 @@ using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
2,
|
||||
swizzle_factor>>;
|
||||
|
||||
// int8
|
||||
using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1578,8 +1578,8 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x8i8(
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
@@ -1609,6 +1609,183 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int32_t;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 8>;
|
||||
using BVecType = ext_vector_t<BDataType, 8>;
|
||||
using CVecType = ext_vector_t<CDataType, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 32;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x32_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
CVecType c_vec{0};
|
||||
operator()(c_vec, a_vec, b_vec);
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_i32_16x16x64_i8
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int32_t;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 16>;
|
||||
using BVecType = ext_vector_t<BDataType, 16>;
|
||||
using CVecType = ext_vector_t<CDataType, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 64;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 16;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x64_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
CVecType c_vec{0};
|
||||
operator()(c_vec, a_vec, b_vec);
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int32_t;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 16>;
|
||||
using BVecType = ext_vector_t<BDataType, 16>;
|
||||
using CVecType = ext_vector_t<CDataType, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
static constexpr index_t kK = 32;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 32;
|
||||
static constexpr index_t kBNLane = 32;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABKPerLane = 16;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 32;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x32_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx95__)
|
||||
c_vec =
|
||||
__builtin_amdgcn_mfma_i32_32x32x32_i8(a_vec, bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
CVecType c_vec{0};
|
||||
operator()(c_vec, a_vec, b_vec);
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
#undef DISPATCH_MFMA_
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,7 +11,7 @@ namespace ck_tile {
|
||||
namespace impl {
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
@@ -22,6 +22,7 @@ struct WarpGemmMfmaDispatcher;
|
||||
|
||||
// clang-format off
|
||||
// fp16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
|
||||
@@ -37,10 +38,12 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
|
||||
// fp16 2:4 structural sparsity
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
|
||||
|
||||
// bf16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
|
||||
@@ -56,6 +59,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
|
||||
// fp8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
|
||||
@@ -81,12 +85,19 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float,
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; };
|
||||
|
||||
// int8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
|
||||
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
@@ -95,7 +106,7 @@ template <typename AType,
|
||||
bool UseStructuredSparsity = false>
|
||||
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
BType,
|
||||
CType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
|
||||
@@ -250,7 +250,7 @@ struct BlockNormReduceCrossWarpSync
|
||||
// | w0 | w1 | w2 | w3 | -----> | w0123 |
|
||||
//
|
||||
// -> also store data from every wave into LDS
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / WarpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
return num_warps * 4 * thread_buf_size * sizeof(float);
|
||||
}
|
||||
|
||||
@@ -276,7 +276,7 @@ struct BlockNormReduceCrossWarpSync
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / WarpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
const index_t smem_offset = warp_id;
|
||||
|
||||
// skip if nonthing to do
|
||||
|
||||
Reference in New Issue
Block a user