Merge branch 'develop' into zan_fix_bufferloadlds

This commit is contained in:
zanzhang
2025-06-27 14:33:46 +08:00
122 changed files with 6086 additions and 2232 deletions

View File

@@ -244,12 +244,6 @@
// workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1
// workaround: compiler issue on gfx950
#define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1
// workaround: compiler issue on gfx950
#define CK_TEMP_DISABLE_FP4_TESTS 1
// workaround: compiler issue on gfx950
#define CK_WORKAROUND_FP16_TO_FP8_CONVERSION 1

View File

@@ -167,7 +167,7 @@ struct HostTensorDescriptor
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
{
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
@@ -600,12 +600,12 @@ struct Tensor
ck::packed_size_v<ck::remove_cvref_t<T>>];
}
T& operator()(std::vector<std::size_t> idx)
T& operator()(const std::vector<std::size_t>& idx)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
}
const T& operator()(std::vector<std::size_t> idx) const
const T& operator()(const std::vector<std::size_t>& idx) const
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
}

View File

@@ -122,7 +122,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
using Base::B_K1;
using Base::I0;
using Base::I1;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -154,9 +153,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K2 = KPack;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
constexpr index_t K0 = KRepeat;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
@@ -291,14 +290,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
// B VGPR->VGPR dequant
@@ -391,15 +388,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
// B VGPR->VGPR dequant
@@ -483,14 +477,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
// B VGPR->VGPR dequant
@@ -596,7 +588,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,

View File

@@ -0,0 +1,759 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/// @brief Device entry point of the batched WMMA CShuffle-V3 GEMM.
///
/// Every workgroup reads its batch index from `blockIdx.y`, converts it into per-tensor offsets
/// through `compute_ptr_offset_of_batch`, and forwards to `GridwiseGemm::Run` with the A/B/C base
/// pointers advanced by both the split-K offset and the batch offset.
///
/// The body is only compiled for the host pass and for gfx11 / gfx12 device passes; for every
/// other device target the kernel is an empty stub. On gfx11 the body is additionally compiled
/// out when C is half_t / bhalf_t and is written with AtomicAdd.
///
/// @tparam GridwiseGemm                   Grid-level GEMM providing `Argument`,
///                                        `SplitKBatchOffset`, `GetSharedMemoryNumberOfByte`
///                                        and `Run`.
/// @tparam ComputePtrOffsetOfStridedBatch Functor providing Get{A,B,C}PtrOffset(batch index).
/// @tparam HasMainKBlockLoop              Forwarded to `GridwiseGemm::Run`.
/// @tparam CGlobalMemoryDataOperation     Forwarded to `GridwiseGemm::Run`.
/// @tparam MinimumOccupancy               Second argument of `__launch_bounds__` (when enabled).
/// @tparam TailNum                        Forwarded to `GridwiseGemm::Run`.
/// @param karg                        Gridwise GEMM argument (passed by value).
/// @param compute_ptr_offset_of_batch Batch-index -> pointer-offset functor.
template <typename GridwiseGemm,
          typename ComputePtrOffsetOfStridedBatch,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_batched_gemm_wmma_cshuffle_v3(
        typename GridwiseGemm::Argument
            karg, // This works for now but it actually receives a
                  // DeviceBatchedGemm_Wmma_CShuffleV3::Argument
                  // argument through implicit conversion to base class!
        const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
    using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
    if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
                   (std::is_same_v<c_data_type, ck::half_t> ||
                    std::is_same_v<c_data_type, ck::bhalf_t>)))
    {
#endif
        // The normal approach to batching would be to increase the grid size by just stretching
        // out the grid Z dimension (which is the outermost dimension), but this depends on lower
        // level functions not directly using the Z dimension for other calculations. As it turns
        // out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore,
        // for now we will use the grid Y dimension for batching. This may be a bit fragile.
        const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);

        const long_index_t a_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
        const long_index_t b_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
        const long_index_t c_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));

        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset,
            karg.p_c_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
            p_shared,
            karg);
#if defined(__gfx11__)
    }
#endif
#else
    // Unsupported device target: keep the parameters "used" so the empty stub compiles without
    // unused-parameter diagnostics. Only the two actual parameters may be named here; there is
    // no `batch` parameter.
    ignore = karg;
    ignore = compute_ptr_offset_of_batch;
#endif
}
/// @brief \"Universal\" Batched GEMM operation without SplitK support.
///
/// @par Overview
/// This GEMM operation implements the following mathematical equation:
/// C{G,M,N} = C_op(A_op(A{G,M,K}) * B_op(B{G,K,N}))
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
/// elementwise operations applied to the A, B, and C tensors, respectively.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through its design
/// and versatility.
///
/// @note This Kernel implementation currently does not support the SplitK algorithm.
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam CLayout C tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam CDataType C tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during \"CShuffle\" data layout optimization.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
/// (after GEMM).
/// @tparam GemmSpec Determines used "padding" version.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
/// @tparam NPerBlock The input/output data tile size in the N dimension.
/// @tparam KPerBlock The input data tile size in the K dimension.
/// @tparam AK1 The vector load size from global memory for A tensor.
/// @tparam BK1 The vector load size from global memory for B tensor.
/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question, "How many threads can be
/// arranged on each input data axis?"
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question: "How many threads to
/// arrange on each input data axis?"
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam BBlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in M dimension.
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
/// Used when storing data to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
/// instructions.
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
/// instructions.
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled). Currently not supported!
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
// The DeviceBatchedGemm base class has no notion of permuteA / permuteB arguments, so this
// operation rejects both at compile time instead of silently ignoring them.
static_assert(!PermuteA,
              "Permute A functionality not supported by DeviceBatchedGemm operations.\n");
static_assert(!PermuteB,
              "Permute B functionality not supported by DeviceBatchedGemm operations.\n");
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC)
: BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
};
// GridwiseGemm
// Grid-level GEMM implementation wrapped by this device operation. The device-level template
// parameters are forwarded positionally. Note that the order differs from the device-level
// parameter list: AccDataType and CShuffleDataType are passed *before* CDataType here.
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
// NOTE(review): hard-coded flag for the A block transfer; its meaning is defined by
// GridwiseGemm_wmma_cshuffle_v3, which is not visible here - confirm against that header.
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
// NOTE(review): hard-coded flag for the B block transfer (counterpart of the A flag above) -
// confirm against GridwiseGemm_wmma_cshuffle_v3.
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
false, // PermuteA not supported by DeviceBatchedGemm base class.
false>; // PermuteB not supported by DeviceBatchedGemm base class.
// Argument
struct Argument : public GridwiseGemm::Argument
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t BatchStrideA_,
index_t BatchStrideB_,
index_t BatchStrideC_,
index_t Batch_,
index_t k_batch_,
bool is_reduce_ = false)
: GridwiseGemm::Argument(p_a_grid_,
p_b_grid_,
p_c_grid_,
M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
k_batch_,
is_reduce_),
Batch(Batch_),
compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_}
{
}
index_t Batch;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
};
/// @brief Helper structure responsible for kernel invocation.
///
/// @par Overview
/// The `Invoker` class is responsible for the preparation and invocation of the actual GPU
/// kernel function. It usually determines the launch grid size, prepares the kernel
/// arguments, and selects a specific kernel configuration based on runtime arguments.
///
/// @note If appropriately configured it may measure kernel execution time.
///
struct Invoker : public BaseInvoker
{
    /// @brief This function issues GPU kernel execution.
    /// @param arg           The GPU kernel arguments.
    /// @param stream_config The HIP stream configuration helper structure.
    /// @return              The kernel's average execution time (if time measurement is
    ///                      enabled).
    /// @throws std::runtime_error if GridwiseGemm::CheckValidity rejects @p arg.
    /// @note If no kernel is implemented for the pipeline version / main-loop combination
    ///       (see the TODO below), nothing is launched and 0 is returned.
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
    {
        if(stream_config.log_level_ > 0)
        {
            arg.Print();
            GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
        }

        if(!GridwiseGemm::CheckValidity(arg))
        {
            throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
        }

        index_t gdx, gdy, gdz;
        std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);

        // The normal approach to batching would be to increase the grid size by just stretching
        // out the grid Z dimension (which is the outermost dimension), but this depends on
        // lower level functions not directly using the Z dimension for other calculations. As
        // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset.
        // Therefore, for now we will use the grid Y dimension for batching. This may be a bit
        // fragile.
        gdy *= arg.Batch;

        float ave_time = 0;

        index_t k_grain = arg.KBatch * KPerBlock;
        index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;

        const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

        // Number of bytes of C to clear for the whole batch when k batching is active.
        // sizeof() is the first factor on purpose: it makes the whole product be evaluated in
        // the unsigned type of sizeof instead of overflowing in 32-bit index_t for large
        // Batch * M * N.
        // Note: This seems incorrect for non-contiguous memory layouts for C (padding, gaps).
        // Untested since we don't have k batching ATM.
        const auto c_batch_bytes = sizeof(CDataType) * arg.Batch * arg.M * arg.N;

        const auto Run = [&](const auto& kernel) {
            if(stream_config.flush_cache)
            {
                // Work on a copy so the rotating-memory wrapper gets a mutable argument.
                Argument arg_ = arg;

                const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
                    arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
                const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
                    arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);

                // Packed sizes are 1 for all implemented data types but we include it anyway
                // for future compatibility.
                auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
                                     sizeof(ADataType) / GridwiseGemm::APackedSize;
                auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
                                     sizeof(BDataType) / GridwiseGemm::BPackedSize;

                // Note: the grid descriptors and size_a / size_b do *not* take batching into
                // account, so we have to manually multiply overall buffer sizes for rotating
                // memory by batch.
                ck::utility::RotatingMemWrapper<Argument> rotating_mem(
                    arg_,
                    stream_config.rotating_count,
                    arg_.Batch * size_a_buffer,
                    arg_.Batch * size_b_buffer);
                rotating_mem.Print();

                auto run_flush_cache = [&]() {
                    // flush icache
                    ck::utility::flush_icache();
                    // rotating mem
                    rotating_mem.Next();
                    // clear c mem
                    if(arg_.KBatch > 1)
                    {
                        HIP_CHECK_ERROR(hipMemsetAsync(
                            arg_.p_c_grid, 0, c_batch_bytes, stream_config.stream_id_));
                    }
                };

                ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                    stream_config,
                    run_flush_cache,
                    kernel,
                    dim3(gdx, gdy, gdz),
                    dim3(BlockSize),
                    0,
                    arg_,
                    arg_.compute_ptr_offset_of_batch);
            }
            else
            {
                auto clear_workspace = [&]() {
                    // clear c mem
                    if(arg.KBatch > 1)
                    {
                        HIP_CHECK_ERROR(hipMemsetAsync(
                            arg.p_c_grid, 0, c_batch_bytes, stream_config.stream_id_));
                    }
                };

                ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                    stream_config,
                    clear_workspace,
                    kernel,
                    dim3(gdx, gdy, gdz),
                    dim3(BlockSize),
                    0,
                    arg,
                    arg.compute_ptr_offset_of_batch);
            }
        };

        // Launch-bounds occupancy hint passed to the kernel template.
        constexpr index_t minimum_occupancy = []() {
            if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
            {
                return 2;
            }
            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
            {
                return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
            }
            else
            {
                return 1;
            }
        }();

        // k batching accumulates partial results with AtomicAdd; otherwise C is written with Set.
        if(has_main_k_block_loop)
        {
            // Tail number always full
            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                         BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
            {
                if(arg.KBatch > 1)
                {
                    const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
                        GridwiseGemm,
                        ComputePtrOffsetOfStridedBatch,
                        true,
                        InMemoryDataOperationEnum::AtomicAdd,
                        minimum_occupancy>;
                    Run(kernel);
                }
                else
                {
                    const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
                        GridwiseGemm,
                        ComputePtrOffsetOfStridedBatch,
                        true,
                        InMemoryDataOperationEnum::Set,
                        minimum_occupancy>;
                    Run(kernel);
                }
            }
            else
            {
                // TODO: Implement
            }
        }
        else
        {
            // Tail number always 1
            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
            {
                if(arg.KBatch > 1)
                {
                    const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
                        GridwiseGemm,
                        ComputePtrOffsetOfStridedBatch,
                        false,
                        InMemoryDataOperationEnum::AtomicAdd,
                        minimum_occupancy>;
                    Run(kernel);
                }
                else
                {
                    const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
                        GridwiseGemm,
                        ComputePtrOffsetOfStridedBatch,
                        false,
                        InMemoryDataOperationEnum::Set,
                        minimum_occupancy>;
                    Run(kernel);
                }
            }
        }

        return ave_time;
    }

    // polymorphic
    /// @brief Type-erased entry point; forwards to the typed Run overload.
    /// @throws std::runtime_error if @p p_arg is null or is not this operation's Argument.
    float Run(const BaseArgument* p_arg,
              const StreamConfig& stream_config = StreamConfig{}) override
    {
        const auto* p_typed_arg = dynamic_cast<const Argument*>(p_arg);
        if(p_typed_arg == nullptr)
        {
            // Dereferencing a failed dynamic_cast would be undefined behavior.
            throw std::runtime_error(
                "wrong! argument is not a DeviceBatchedGemm_Wmma_CShuffleV3::Argument");
        }
        return Run(*p_typed_arg, stream_config);
    }
};
/// Compile-time validity check of the template configuration.
/// Currently a placeholder that accepts every configuration.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
/// @brief Checks whether @p arg can be executed by this instance on the current device.
/// @return false if the device is neither gfx11 nor gfx12, if gfx11 would need AtomicAdd on a
///         half_t / bhalf_t output (KBatch > 1), if an f8 / bf8 compute type is requested on
///         gfx11, or if K is not a multiple of AK1 and BK1 while GemmSpec does not pad K.
///         Otherwise the result of GridwiseGemm::CheckValidity.
static bool IsSupportedArgument(const Argument& arg)
{
    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
    {
        return false;
    }

    constexpr bool c_is_half_or_bhalf =
        std::is_same_v<CDataType, ck::half_t> || std::is_same_v<CDataType, ck::bhalf_t>;
    if constexpr(c_is_half_or_bhalf)
    {
        // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
        if(arg.KBatch > 1 && ck::is_gfx11_supported())
        {
            return false;
        }
    }

    constexpr bool compute_is_f8_or_bf8 =
        std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
        std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>;
    if constexpr(compute_is_f8_or_bf8)
    {
        if(ck::is_gfx11_supported())
        {
            return false;
        }
    }

    // Specializations whose name includes K padding.
    constexpr bool pads_k = GemmSpec == GemmSpecialization::MKPadding ||
                            GemmSpec == GemmSpecialization::NKPadding ||
                            GemmSpec == GemmSpecialization::MNKPadding ||
                            GemmSpec == GemmSpecialization::KPadding;

    const bool k_fits_both_vectors = (arg.K % AK1 == 0) && (arg.K % BK1 == 0);
    if(!k_fits_both_vectors && !pads_k)
    {
        return false;
    }

    return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
// TODO: This is not part of the DeviceBatchedGemm base class but it was part of
// DeviceBatchedGemmV2. Remove?
// index_t GetKPerBlock() override { return KPerBlock; }
// bool GetPermuteA() override { return PermuteA; }
// bool GetPermuteB() override { return PermuteB; }
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t Batch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
Batch,
1 /* KBatch */};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t Batch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
Batch,
1); // KBatch
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceBatchedGemm_Wmma_CShuffleV3"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
<< "WaveTile: "
<< MPerWmma << "x"<<NPerWmma << ", "
<< "WaveMap: "
<< MRepeat << "x" << NRepeat << ", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "KPack: "
<< GridwiseGemm::KPack;
// clang-format on
return str.str();
}
REGISTER_EXTRA_PRINTING_METHODS
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -77,7 +77,8 @@ template <typename GridwiseGemm,
typename ComputePtrOffsetOfN,
bool HasMainKBlockLoop,
bool isMultiA,
bool isMultiB>
bool isMultiB,
bool CTranspose>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -171,17 +172,22 @@ __global__ void
}
else
{
const long_index_t a_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
CTranspose
? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t a_group_offset =
CTranspose
? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_n_offset =
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
p_as_grid + a_group_offset + a_n_offset,
p_bs_grid + b_group_offset,
p_bs_grid + b_group_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_group_offset + e_n_offset,
p_shared,
@@ -335,12 +341,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr bool isATensorColMajor =
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) &&
(ABlockTransferSrcVectorDim == 1) && (NumGroupsToMerge == 1) &&
(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>());
static constexpr bool NeedTransposeKernel =
(isATensorColMajor == false) && (is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>());
static constexpr bool CTranspose = (NeedTransposeKernel == false) && (isMultiAB == false) &&
(is_same_v<ELayout, tensor_layout::convolution::NGKHW> ||
is_same_v<ELayout, tensor_layout::convolution::NGKDHW>);
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType,
NumGroupsToMerge>;
NumGroupsToMerge,
index_t,
CTranspose>;
static constexpr index_t ClusterLengthNPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
@@ -361,9 +383,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::NHWGC,
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGC, ALay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::NDHWGC,
ALay>>;
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
@@ -379,9 +403,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::GKYXC,
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::GKZYXC, BLay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::GKZYXC,
BLay>>;
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
@@ -397,17 +423,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::NHWGK,
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGK, ELay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
ctc::NDHWGK,
ELay>>;
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
return out_gemmm_gemmn_desc;
if constexpr(CTranspose)
{
constexpr auto matrix_padder_trans =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{NPerBlock, MPerBlock, KPerBlock};
return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
}
else
{
return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
}
}
// Shape of Ds and E must be aligned. Strides can be different.
@@ -471,11 +504,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType, DoElementwiseBeforeCShuffle
// Template-argument list for the gridwise GEMM used when CTranspose is true.
// The B-side arguments are listed in the A-side positions and vice versa (data types,
// element-wise ops, per-block and per-XDL tile sizes, K1 values, block-transfer settings).
// NOTE(review): CShuffleMXdlPerWavePerShuffle / CShuffleNXdlPerWavePerShuffle and the CDE
// block-transfer cluster lengths are passed in their original order, not swapped - confirm
// this is intended.
#define GridwiseGemmCTransposeTemplateParameters \
GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
MPerXDL, NXdlPerWave, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType, DoElementwiseBeforeCShuffle
// Use appropriate gridwise gemm
// Default gridwise GEMM: the multiple-ABD variant when either A or B is a tuple of tensors
// (isMultiA || isMultiB), otherwise the multiple-D variant.
using GridwiseGemm = std::conditional_t<
isMultiA || isMultiB,
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmMultiABDTemplateParameters>,
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
// Gridwise GEMM used for the E/Ds descriptors, the block-to-tile map and validity checks.
// With CTranspose it is the multiple-D GEMM instantiated with the swapped A/B parameter
// list; otherwise it is simply an alias of GridwiseGemm.
using GridwiseGemmCTranspose = std::conditional_t<
CTranspose,
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>,
GridwiseGemm>;
// If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers.
using APointers =
@@ -497,15 +551,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
remove_cvref_t<decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(
EGridDesc_M_N{}))>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
using NGCHWTransposeDescType =
@@ -612,16 +667,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
a_g_n_c_wis_strides_{NeedTransposeKernel
? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)
: a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
b_g_k_c_xs_strides_{NeedTransposeKernel
? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
: b_g_k_c_xs_strides},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
e_g_n_k_wos_strides_{NeedTransposeKernel
? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)
: e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
@@ -651,7 +712,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
block_2_etile_map_{
GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{},
a_element_op_{a_element_op},
@@ -783,24 +845,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
bool valid = false;
if constexpr(CTranspose)
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_,
a_grid_desc_m_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_);
}
else
{
valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_);
}
if(valid)
{
e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_);
ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_);
}
}
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
// Use not modified base strides
a_in_transpose_desc_ =
@@ -835,8 +907,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
std::size_t GetWorkspaceATensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -851,8 +922,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
std::size_t GetWorkspaceBTensorSizeBytes() const
{
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -867,8 +937,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -1007,7 +1076,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB>;
isMultiB,
CTranspose>;
return launch_and_time_kernel(
stream_config,
@@ -1035,68 +1105,118 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
const BDataType* p_b_grid = arg.p_bs_grid_.At(I0);
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
}
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemm,
const ADataType*,
const BDataType*,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB>;
if constexpr(CTranspose)
{
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemmCTranspose,
const BDataType*,
const ADataType*,
typename GridwiseGemm::DsGridPointer,
EDataType,
BElementwiseOperation,
AElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB,
CTranspose>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
return launch_and_time_kernel(
stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_b_grid,
p_a_grid,
arg.p_ds_grid_,
p_e_grid,
arg.b_element_op_,
arg.a_element_op_,
arg.cde_element_op_,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_ak0_m_ak1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
}
else
{
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
GridwiseGemm,
const ADataType*,
const BDataType*,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB,
CTranspose>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
}
}
};
@@ -1114,8 +1234,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
float avg_time = 0.f;
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const index_t a_grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
@@ -1166,8 +1285,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
avg_time += RunGemm(arg, stream_config);
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
@@ -1215,9 +1333,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
// check device
if(get_device_name() == "gfx908")
@@ -1310,7 +1430,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
NeedTransposeKernel)
{
// Check access per C
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
@@ -1326,6 +1446,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
}
else if constexpr(is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
{
static_assert(NeedTransposeKernel == false);
static_assert(NumGroupsToMerge == 1);
if constexpr(ABlockTransferSrcScalarPerVector != 1)
{
if(ABlockTransferSrcVectorDim != 1)
{
return false;
}
if(input_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
}
else
{
return false;
@@ -1350,7 +1487,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return false;
}
// check vector access of Ds
bool valid = true;
@@ -1396,8 +1532,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
});
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(NeedTransposeKernel)
{
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
@@ -1409,8 +1544,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return false;
}
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
@@ -1457,9 +1590,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
{
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
if(CTranspose == false)
{
return false;
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
}
else
@@ -1483,11 +1629,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
if constexpr(CTranspose)
{
return GridwiseGemmCTranspose::CheckValidity(arg.b_grid_desc_n_k_,
arg.a_grid_desc_m_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
else
{
return GridwiseGemmCTranspose::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
}
}

View File

@@ -1473,7 +1473,12 @@ struct GridwiseMoeGemm
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
const index_t token_offset = fused_token & 0xffffff;
return token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset]
? p_sorted_weights_0[IsInputGemm
? token_offset
: token_offset *
problem.TopK +
(fused_token >>
24)]
: 0.0;
}
else
@@ -2190,7 +2195,12 @@ struct GridwiseMoeGemm
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
const index_t token_offset = fused_token & 0xffffff;
return token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset]
? p_sorted_weights_0[IsInputGemm
? token_offset
: token_offset *
problem.TopK +
(fused_token >>
24)]
: 0.0;
}
else

View File

@@ -19,7 +19,8 @@ template <index_t NDimSpatial,
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1,
typename IndexType = index_t>
typename IndexType = index_t,
bool CTranspose = false>
struct TransformConvFwdToGemm
{
private:
@@ -1253,6 +1254,83 @@ struct TransformConvFwdToGemm
}
}
// A (input) descriptor for the 1-D NGCW layout, shaped as (GemmM, GemmK).
// A strided (N, Wo, C) view is built first, with unit stride on the spatial dimension;
// N and Wo are then merged into GemmM while C is passed through as GemmK.
// Only compiled for NumGroupsToMerge == 1 and the Filter1x1Stride1Pad0 specialization.
// NOTE(review): the output length Wo_ is used for the input view - presumably valid because
// a 1x1 / stride 1 / pad 0 filter keeps input and output lengths equal; confirm.
template <typename ALayout,
          typename ck::enable_if<NDimSpatial == 1 &&
                                     is_same_v<ALayout, tensor_layout::convolution::NGCW>,
                                 bool>::type = false>
__host__ __device__ auto MakeADescriptor_M_K() const
{
    static_assert(NumGroupsToMerge == 1);
    static_assert(ConvForwardSpecialization ==
                  device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);

    // Strided 3-D view: (N, Wo, C).
    const auto n_wo_c_desc = make_naive_tensor_descriptor(
        make_tuple(N_, Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_));

    const auto gemm_m_transform = make_merge_transform(make_tuple(N_, Wo_));
    const auto gemm_k_transform = make_pass_through_transform(C_);

    // (N, Wo) -> GemmM, C -> GemmK.
    return transform_tensor_descriptor(n_wo_c_desc,
                                       make_tuple(gemm_m_transform, gemm_k_transform),
                                       make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                                       make_tuple(Sequence<0>{}, Sequence<1>{}));
}
// A (input) descriptor for the 2-D NGCHW layout, shaped as (GemmM, GemmK).
// The two spatial dimensions are treated as one flattened dimension of length Ho * Wo with
// unit stride; N and that flattened dimension are merged into GemmM, C becomes GemmK.
// Only compiled for NumGroupsToMerge == 1 and the Filter1x1Stride1Pad0 specialization.
// NOTE(review): output spatial lengths are used for the input view - presumably valid
// because a 1x1 / stride 1 / pad 0 filter keeps input and output sizes equal; confirm.
template <typename ALayout,
          typename ck::enable_if<NDimSpatial == 2 &&
                                     is_same_v<ALayout, tensor_layout::convolution::NGCHW>,
                                 bool>::type = false>
__host__ __device__ auto MakeADescriptor_M_K() const
{
    static_assert(NumGroupsToMerge == 1);
    static_assert(ConvForwardSpecialization ==
                  device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);

    // Strided 3-D view: (N, Ho * Wo, C).
    const auto n_howo_c_desc = make_naive_tensor_descriptor(
        make_tuple(N_, Ho_ * Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_));

    const auto gemm_m_transform = make_merge_transform(make_tuple(N_, Ho_ * Wo_));
    const auto gemm_k_transform = make_pass_through_transform(C_);

    // (N, Ho * Wo) -> GemmM, C -> GemmK.
    return transform_tensor_descriptor(n_howo_c_desc,
                                       make_tuple(gemm_m_transform, gemm_k_transform),
                                       make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                                       make_tuple(Sequence<0>{}, Sequence<1>{}));
}
// A (input) descriptor for the 3-D NGCDHW layout, shaped as (GemmM, GemmK).
// The three spatial dimensions are treated as one flattened dimension of length
// Do * Ho * Wo with unit stride; N and that flattened dimension are merged into GemmM,
// C becomes GemmK.
// Only compiled for NumGroupsToMerge == 1 and the Filter1x1Stride1Pad0 specialization.
// NOTE(review): output spatial lengths are used for the input view - presumably valid
// because a 1x1 / stride 1 / pad 0 filter keeps input and output sizes equal; confirm.
template <typename ALayout,
          typename ck::enable_if<NDimSpatial == 3 &&
                                     is_same_v<ALayout, tensor_layout::convolution::NGCDHW>,
                                 bool>::type = false>
__host__ __device__ auto MakeADescriptor_M_K() const
{
    static_assert(NumGroupsToMerge == 1);
    static_assert(ConvForwardSpecialization ==
                  device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0);

    // Strided 3-D view: (N, Do * Ho * Wo, C).
    const auto n_dohowo_c_desc = make_naive_tensor_descriptor(
        make_tuple(N_, Do_ * Ho_ * Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_));

    const auto gemm_m_transform = make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_));
    const auto gemm_k_transform = make_pass_through_transform(C_);

    // (N, Do * Ho * Wo) -> GemmM, C -> GemmK.
    return transform_tensor_descriptor(n_dohowo_c_desc,
                                       make_tuple(gemm_m_transform, gemm_k_transform),
                                       make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                                       make_tuple(Sequence<0>{}, Sequence<1>{}));
}
// B (weight) descriptor for the GKCX / GKCYX / GKCZYX layouts, shaped as (GemmN, GemmK).
// Returns a packed (K, C) descriptor. Only compiled for the Filter1x1Stride1Pad0 or
// Filter1x1Pad0 specializations and NumGroupsToMerge == 1.
// NOTE(review): the weight strides are not consulted here - presumably the 1x1 filter makes
// each group's (K, C) slice contiguous; confirm against callers.
template <typename BLayout,
          typename ck::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKCX> ||
                                     is_same_v<BLayout, tensor_layout::convolution::GKCYX> ||
                                     is_same_v<BLayout, tensor_layout::convolution::GKCZYX>,
                                 bool>::type = false>
__host__ __device__ auto MakeBDescriptor_N_K() const
{
    static_assert(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
                  ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Pad0);
    static_assert(NumGroupsToMerge == 1);

    // GemmN = K, GemmK = C.
    const auto k_c_lengths = make_tuple(K_, C_);
    return make_naive_tensor_descriptor_packed(k_c_lengths);
}
template <typename BLayout,
typename ck::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
@@ -1338,8 +1416,16 @@ struct TransformConvFwdToGemm
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
if constexpr(CTranspose)
{
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
make_tuple(KStrideTensorC_, I0));
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
}
}
template <
@@ -1350,8 +1436,16 @@ struct TransformConvFwdToGemm
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
if constexpr(CTranspose)
{
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
make_tuple(KStrideTensorC_, I0));
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
}
}
template <
@@ -1362,12 +1456,21 @@ struct TransformConvFwdToGemm
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
if constexpr(CTranspose)
{
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
make_tuple(KStrideTensorC_, I0));
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_));
}
}
template <typename CLayout,
index_t NDimSp = NDimSpatial,
index_t NDimSp = NDimSpatial,
typename ck::enable_if<NDimSp == 1 &&
(is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
@@ -1375,6 +1478,7 @@ struct TransformConvFwdToGemm
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
static_assert(CTranspose == false);
const IndexType NDoHoWo = N_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
{
@@ -1429,6 +1533,7 @@ struct TransformConvFwdToGemm
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
static_assert(CTranspose == false);
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
{
@@ -1486,7 +1591,7 @@ struct TransformConvFwdToGemm
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
static_assert(CTranspose == false);
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
{
@@ -1536,6 +1641,101 @@ struct TransformConvFwdToGemm
}
}
// C (output) descriptor for the 1-D GNKW / NGKW layouts.
// Starts from a strided (N, K, Wo) view with unit stride on Wo.
//  - CTranspose == false: dim 0 = merged (N, Wo), dim 1 = K.
//  - CTranspose == true : dim 0 = K, dim 1 = merged (N, Wo).
// Only compiled for NumGroupsToMerge == 1.
template <typename CLayout,
          index_t NDimSp = NDimSpatial,
          typename ck::enable_if<NDimSp == 1 &&
                                     (is_same_v<CLayout, tensor_layout::convolution::GNKW> ||
                                      is_same_v<CLayout, tensor_layout::convolution::NGKW>),
                                 bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
    static_assert(NumGroupsToMerge == 1);

    // Strided 3-D view: (N, K, Wo).
    const auto out_n_k_wo_desc = make_naive_tensor_descriptor(
        make_tuple(N_, K_, Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1));

    const auto merge_n_wo = make_merge_transform(make_tuple(N_, Wo_));
    const auto pass_k     = make_pass_through_transform(K_);

    if constexpr(!CTranspose)
    {
        // Regular orientation: (N, Wo) x K.
        return transform_tensor_descriptor(out_n_k_wo_desc,
                                           make_tuple(merge_n_wo, pass_k),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
                                           make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    else
    {
        // Swapped orientation: K x (N, Wo).
        return transform_tensor_descriptor(out_n_k_wo_desc,
                                           make_tuple(pass_k, merge_n_wo),
                                           make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                                           make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
}
// C (output) descriptor for the 2-D GNKHW / NGKHW layouts.
// Starts from a strided (N, K, Ho * Wo) view with unit stride on the flattened spatial
// dimension.
//  - CTranspose == false: dim 0 = merged (N, Ho * Wo), dim 1 = K.
//  - CTranspose == true : dim 0 = K, dim 1 = merged (N, Ho * Wo).
// Only compiled for NumGroupsToMerge == 1.
template <typename CLayout,
          index_t NDimSp = NDimSpatial,
          typename ck::enable_if<NDimSp == 2 &&
                                     (is_same_v<CLayout, tensor_layout::convolution::GNKHW> ||
                                      is_same_v<CLayout, tensor_layout::convolution::NGKHW>),
                                 bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
    static_assert(NumGroupsToMerge == 1);

    // Strided 3-D view: (N, K, Ho * Wo).
    const auto out_n_k_howo_desc = make_naive_tensor_descriptor(
        make_tuple(N_, K_, Ho_ * Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1));

    const auto merge_n_howo = make_merge_transform(make_tuple(N_, Ho_ * Wo_));
    const auto pass_k       = make_pass_through_transform(K_);

    if constexpr(!CTranspose)
    {
        // Regular orientation: (N, Ho * Wo) x K.
        return transform_tensor_descriptor(out_n_k_howo_desc,
                                           make_tuple(merge_n_howo, pass_k),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
                                           make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    else
    {
        // Swapped orientation: K x (N, Ho * Wo).
        return transform_tensor_descriptor(out_n_k_howo_desc,
                                           make_tuple(pass_k, merge_n_howo),
                                           make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                                           make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
}
// C (output) descriptor for the 3-D GNKDHW / NGKDHW layouts.
// Starts from a strided (N, K, Do * Ho * Wo) view with unit stride on the flattened spatial
// dimension.
//  - CTranspose == false: dim 0 = merged (N, Do * Ho * Wo), dim 1 = K.
//  - CTranspose == true : dim 0 = K, dim 1 = merged (N, Do * Ho * Wo).
// Only compiled for NumGroupsToMerge == 1.
template <typename CLayout,
          index_t NDimSp = NDimSpatial,
          typename ck::enable_if<NDimSp == 3 &&
                                     (is_same_v<CLayout, tensor_layout::convolution::GNKDHW> ||
                                      is_same_v<CLayout, tensor_layout::convolution::NGKDHW>),
                                 bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const
{
    static_assert(NumGroupsToMerge == 1);

    // Strided 3-D view: (N, K, Do * Ho * Wo).
    const auto out_n_k_dohowo_desc = make_naive_tensor_descriptor(
        make_tuple(N_, K_, Do_ * Ho_ * Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1));

    const auto merge_n_dohowo = make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_));
    const auto pass_k         = make_pass_through_transform(K_);

    if constexpr(!CTranspose)
    {
        // Regular orientation: (N, Do * Ho * Wo) x K.
        return transform_tensor_descriptor(out_n_k_dohowo_desc,
                                           make_tuple(merge_n_dohowo, pass_k),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
                                           make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    else
    {
        // Swapped orientation: K x (N, Do * Ho * Wo).
        return transform_tensor_descriptor(out_n_k_dohowo_desc,
                                           make_tuple(pass_k, merge_n_dohowo),
                                           make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                                           make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
}
IndexType N_;
IndexType Di_, Hi_, Wi_;
IndexType Do_, Ho_, Wo_;

View File

@@ -5,19 +5,16 @@
#include "ck/ck.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/type.hpp"
#ifdef CK_USE_FNUZ_FP8
#define CK_USE_FNUZ_FP8 1
#else
#ifndef CK_USE_FNUZ_FP8
#define CK_USE_FNUZ_FP8 0
#endif
#ifdef CK_USE_OCP_FP8
#define CK_USE_OCP_FP8 1
#else
#ifndef CK_USE_OCP_FP8
#define CK_USE_OCP_FP8 0
#endif
@@ -431,7 +428,7 @@ __host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a)
namespace fp8_impl {
// Assertions to check for supported conversion types
#define __assert_ocp_support(interp) \
#define __fp8_impl_assert_ocp_support(interp) \
{ \
if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \
interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \
@@ -439,7 +436,7 @@ namespace fp8_impl {
__hip_assert(false && "type is unsupported by current target device"); \
} \
}
#define __assert_fnuz_support(interp) \
#define __fp8_impl_assert_fnuz_support(interp) \
{ \
if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \
interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \
@@ -453,10 +450,10 @@ __is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp)
{
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
#if CK_USE_OCP_FP8
__assert_ocp_support(interp);
__fp8_impl_assert_ocp_support(interp);
#endif
#if CK_USE_FNUZ_FP8
__assert_fnuz_support(interp);
__fp8_impl_assert_fnuz_support(interp);
#endif
#endif
}
@@ -1396,12 +1393,18 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
f, rng);
@@ -1416,12 +1419,18 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
@@ -1487,12 +1496,18 @@ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f[0]);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
f, rng);
@@ -1532,12 +1547,18 @@ __host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
#if defined(__gfx950__)
return cast_to_f8_from_f16<interp,
@@ -1574,12 +1595,18 @@ __host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
#if defined(__gfx950__)
return cast_to_f8_from_f16<interp,
@@ -1616,13 +1643,19 @@ __host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
static_cast<float>(x));
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), static_cast<float>(x));
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
#if defined(__gfx950__)
return cast_to_f8_from_bf16<interp,
@@ -1664,14 +1697,20 @@ __host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
static_cast<float>(x[0]));
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x),
static_cast<float>(x[0]));
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
#if defined(__gfx950__)
return cast_to_f8_from_bf16<interp,

View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <tuple>
#include <type_traits>
#include <utility>
#include "ck/utility/functional.hpp"
#include "ck/utility/sequence.hpp"
namespace ck::util {
/// Selects every `Stride`-th element type of a tuple, starting at index `Offset`.
///
/// @tparam Tuple  Type usable with std::tuple_size_v and std::tuple_element_t.
/// @tparam Stride Distance between two selected indices; must be greater than zero.
/// @tparam Offset Index of the first selected element; must be less than Stride.
///
/// The member alias `type` is a std::tuple holding the element types of Tuple at
/// indices Offset, Offset + Stride, Offset + 2 * Stride, ... that lie inside Tuple.
template <typename Tuple, std::size_t Stride, std::size_t Offset>
struct filter_tuple_by_modulo
{
    // Validate Stride and Offset.
    static_assert(Stride > 0, "Stride must be positive.");
    // Offset is unsigned, so only its upper bound can be violated.
    static_assert(Offset < Stride, "Offset must be non-negative and less than the stride.");

    // Number of selected elements: count of indices i in [0, tuple_size) with
    // i % Stride == Offset. Because Offset < Stride, the numerator cannot wrap
    // around, and it is smaller than Stride (yielding 0) whenever the tuple has
    // no element at index Offset.
    static constexpr int new_size =
        static_cast<int>((std::tuple_size_v<Tuple> + Stride - Offset - 1) / Stride);

    // Maps the dense sequence 0, 1, 2, ... onto the selected indices of Tuple.
    template <std::size_t... Is>
    static constexpr auto to_index(std::index_sequence<Is...>)
    {
        return std::index_sequence<(Offset + Is * Stride)...>{};
    }

    // Indices of Tuple that survive the filter.
    using filtered_indices = decltype(to_index(std::make_index_sequence<new_size>{}));

    // Helper struct to construct the new tuple type from the filtered indices.
    template <typename T, typename Indices>
    struct make_filtered_tuple_type_impl;

    template <typename T, std::size_t... Is>
    struct make_filtered_tuple_type_impl<T, std::index_sequence<Is...>>
    {
        using type = std::tuple<std::tuple_element_t<Is, T>...>;
    };

    // Resulting std::tuple of the selected element types.
    using type = typename make_filtered_tuple_type_impl<Tuple, filtered_indices>::type;
};
// Filter a tuple with a stride and offset.
//
// Tuple is a std::tuple or equivalent
// Stride is a positive integer
// Offset is a non-negative integer smaller than Stride
//
// Evaluates to a smaller tuple type built from the elements of Tuple at indices
// Offset, Offset + Stride, Offset + 2 * Stride, ...
//
// Can be used to filter a tuple of types for sharded instantiations.
template <typename Tuple, std::size_t Stride, std::size_t Offset>
using filter_tuple_by_modulo_t = typename filter_tuple_by_modulo<Tuple, Stride, Offset>::type;
// Example compile-time test:
// using OriginalTuple =
// std::tuple<int, double, char, float, long, short, bool, char, long long, unsigned int>;
// using NewTuple_Every3rdFrom2nd = filter_tuple_by_modulo_t<OriginalTuple, 3, 1>;
// static_assert(std::is_same_v<NewTuple_Every3rdFrom2nd, std::tuple<double, long, char>>,
// "Test Case 1 Failed: Every 3rd from 2nd");
} // namespace ck::util

View File

@@ -197,8 +197,9 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const fl
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
}
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
@@ -221,8 +222,9 @@ __host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
}
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}

View File

@@ -5,6 +5,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/mxf6_utils.hpp"
#include "ck/utility/random_gen.hpp"
@@ -234,12 +235,18 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
template <>
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
#if defined(__gfx94__)
union
{
@@ -296,12 +303,18 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
template <>
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
#if defined(__gfx94__)
union
{
@@ -1446,13 +1459,10 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0
// convert fp32 to fp4 with stochastic rounding
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
union
{
uint32_t bitwise;
@@ -1468,6 +1478,12 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
value.bitwise, float_values.float2_array, rng, scale, 0);
return value.f4_array[0];
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
#endif
}
@@ -1475,30 +1491,26 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
// convert vector of 2 fp32 to vector of 2 fp4 with sr
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
// apply a temporary workaround for gfx950
#if CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
value.bitwise = (h << 4) | l;
#else
// permute high bits and low bits to match the order of the original vector
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0);
#endif // CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION
return value.f4x2_array[0];
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif
union
{
uint32_t bitwise;
@@ -1514,13 +1526,10 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
union
{
__uint128_t bitwise;
@@ -1546,6 +1555,12 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f
return f4_values.f4x32_array;
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif
union
{
__uint128_t bitwise;
@@ -1776,13 +1791,10 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0
*/
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
union
{
float32_t float_vector;
@@ -1799,6 +1811,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
return out.f6_array[0];
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
#endif
}
@@ -1815,6 +1833,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
*/
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
constexpr int seed = 1254739;
union
{
@@ -1828,9 +1852,6 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
#endif
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
union
{
float32_t float_vector;
@@ -2044,13 +2065,10 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1
*/
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
union
{
float32_t float_vector;
@@ -2067,6 +2085,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
return out.bf6_array[0];
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
#endif
}
@@ -2085,6 +2109,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
*/
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
constexpr int seed = 1254739;
union
{
@@ -2098,9 +2128,6 @@ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
#endif
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
union
{
float32_t float_vector;

View File

@@ -7,6 +7,7 @@
namespace ck_tile {
using index_t = int32_t;
using int32_t = int32_t;
using long_index_t = int64_t;
using int8_t = int8_t;

View File

@@ -1009,6 +1009,15 @@ struct buffer_view<address_space_enum::lds,
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
// int8 on thread buffer
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
// ext_vector_type for pk_int4 must use int8_t as type
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
@@ -1031,6 +1040,8 @@ struct buffer_view<address_space_enum::lds,
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>))
{
@@ -1041,6 +1052,8 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>))
{
@@ -1051,6 +1064,8 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
{
@@ -1061,6 +1076,8 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -129,7 +129,10 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
[&](auto ii) {
return ii == y_dim_vec_in ? static_cast<index_t>(idx_y_start[ii]) + i
: static_cast<index_t>(idx_y_start[ii]);
},
number<NDimY>{});
constexpr auto idx_y_out =

View File

@@ -314,8 +314,7 @@ struct tile_window_linear
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor =
make_static_distributed_tensor<typename Base::DataTypeDataType>(tile_dstr);
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
@@ -348,8 +347,9 @@ struct tile_window_linear
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Base::Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() = vec_value.template get_as<
typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize];
dst_tensor.get_thread_buffer().template at<d>() =
vec_value
.template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
});
};
@@ -400,8 +400,9 @@ struct tile_window_linear
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Base::Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() = vec_value.template get_as<
typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize];
dst_tensor.get_thread_buffer().template at<d>() =
vec_value
.template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
});
};
@@ -805,8 +806,7 @@ struct tile_window_linear
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Base::Traits::PackedSize;
vec_value.template get_as<typename Base::DataTypeDataType>()(
j / Base::Traits::PackedSize) =
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
@@ -861,8 +861,7 @@ struct tile_window_linear
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Base::Traits::PackedSize;
vec_value.template get_as<typename Base::DataTypeDataType>()(
j / Base::Traits::PackedSize) =
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});

View File

@@ -230,7 +230,7 @@ struct HostTensorDescriptor
* @param iss Vector containing the multi-dimensional indices
* @return The calculated linear offset as a size_t
*/
std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
{
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
@@ -540,9 +540,12 @@ struct HostTensor
return mData[GetOffsetFromMultiIndex(is...)];
}
T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }
T& operator()(const std::vector<std::size_t>& idx)
{
return mData[GetOffsetFromMultiIndex(idx)];
}
const T& operator()(std::vector<std::size_t> idx) const
const T& operator()(const std::vector<std::size_t>& idx) const
{
return mData[GetOffsetFromMultiIndex(idx)];
}
@@ -719,6 +722,8 @@ struct HostTensor
file << type_convert<float>(itm) << std::endl;
else if(dtype == "int")
file << type_convert<int>(itm) << std::endl;
else if(dtype == "int8_t")
file << static_cast<int>(type_convert<ck_tile::int8_t>(itm)) << std::endl;
else
// TODO: we didn't implement operator<< for all custom
// data types, here fall back to float in case compile error

View File

@@ -75,7 +75,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
#if defined(USING_MFMA_16x16x32) || defined(USING_MFMA_32x32x16)
constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
@@ -91,64 +90,68 @@ struct FlatmmPipelineAGmemBGmemCRegV1
constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad;
constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp;
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
#endif
#if defined(USING_MFMA_16x16x32)
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
});
#elif defined(USING_MFMA_32x32x16)
static_for<0,
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
#endif
if constexpr(WG::kM == 16 && WG::kN == 16)
{
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
});
}
else if constexpr(WG::kM == 32 && WG::kN == 32 &&
(A_LDS_Read_Inst_Num / 2 >
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
{
static_for<0,
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
}
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>

View File

@@ -19,55 +19,61 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
#if defined(USING_MFMA_16x16x32)
/*reduce transform layers,compare with old ck*/
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
if constexpr(MPerXdl == 16 && NPerXdl == 16)
{
/*reduce transform layers,compare with old ck*/
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(
make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
return a_lds_block_desc;
#elif defined(USING_MFMA_32x32x16)
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
return a_lds_block_desc;
}
else
{
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
return a_lds_block_desc;
#endif
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
/*xor*/
#if 0
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
@@ -138,6 +144,21 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
}
// Derives the KBPerLoad value from the warp tile shape of Problem::BlockGemmShape:
//   - warp tile N extent == 32: warp tile K extent / 2
//   - warp tile N extent == 16: warp tile K extent / 4
// Any other N extent is rejected at compile time by the static_assert below.
// NOTE(review): presumably this is the number of B elements each thread loads
// along K per load instruction -- confirm against the callers.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
if constexpr(TileShape::WarpTile::at(TileShape::idxN) == 32)
{
return TileShape::WarpTile::at(TileShape::idxK) / 2;
}
else
{
// Only N extents of 32 and 16 are handled; fail the build for anything else.
static_assert(TileShape::WarpTile::at(TileShape::idxN) == 16);
return TileShape::WarpTile::at(TileShape::idxK) / 4;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
@@ -189,7 +210,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
else
{
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
// coalesce reading for each blocks
@@ -232,19 +253,17 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad =
Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;

View File

@@ -316,56 +316,56 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
int32_t num_total_pages,
const void* kv_indptr,
const void* kv_page_indices,
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
int32_t num_total_pages,
const void* kv_indptr,
const void* kv_page_indices,
#if 0 // we assume page_block_size=1 for now
const void* kv_last_page_lens,
ck_tile::index_t page_block_size,
#endif
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -468,51 +468,51 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
int32_t num_total_pages,
const void* kv_indptr,
const void* kv_page_indices,
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
int32_t num_total_pages,
const void* kv_indptr,
const void* kv_page_indices,
#if 0 // we assume page_block_size=1 for now
const void* kv_last_page_lens,
ck_tile::index_t page_block_size,
#endif
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,

View File

@@ -808,6 +808,7 @@ struct FmhaFwdKernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
@@ -847,7 +848,7 @@ struct FmhaFwdKernel
window_size_left,
window_size_right,
mask_type,
0, // min_seqlen_q
min_seqlen_q,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
@@ -890,6 +891,7 @@ struct FmhaFwdKernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
const std::tuple<const void*, const void*>& drop_seed_offset)
@@ -929,6 +931,7 @@ struct FmhaFwdKernel
window_size_left,
window_size_right,
mask_type,
min_seqlen_q,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -787,12 +787,29 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t K3 = total_pixels / N1;
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
if constexpr(get_warp_size() % (K2 * N0) == 0)
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
{
constexpr index_t kNPack = 32;
static_assert(kNPerBlock % kNPack == 0);
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t N2 = 2;
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
tuple<sequence<0>, sequence<1, 1>>,
sequence<1, 2, 1>, // N0 K2 N2
sequence<0, 2, 2>>{});
}
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
@@ -860,12 +877,28 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
if constexpr(get_warp_size() % (K2 * N0) == 0)
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
{
constexpr index_t kNPack = 32;
static_assert(kNPerBlock % kNPack == 0);
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t N2 = 2;
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
tuple<sequence<0>, sequence<1, 1>>,
sequence<1, 1, 2>, // N0 K2 <-> N2
sequence<0, 2, 2>>{});
}
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();

View File

@@ -101,7 +101,7 @@ struct FusedMoeGemmShape
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;
static constexpr index_t BlockSize = WarpSize * NumWarps;
static constexpr index_t BlockSize = get_warp_size() * NumWarps;
// some assert
static_assert(Block_M0 == Block_M1);

View File

@@ -388,7 +388,7 @@ struct MoeSortingKernel
}
// reduce single pixel within a wave
template <typename T, typename F, index_t wave_size_ = WarpSize>
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
@@ -625,7 +625,7 @@ struct MoeSortingKernel
{
const index_t prefill_token = topk_mdiv.div(numel);
// TODO: only support expert-tile like 8, 16, 32
static constexpr index_t experts_per_wave = WarpSize / Problem::ExpertTile;
static constexpr index_t experts_per_wave = get_warp_size() / Problem::ExpertTile;
{
index_t eid = tid / experts_per_wave;
index_t expert_offset = cumsum[eid] +
@@ -693,7 +693,7 @@ struct MoeSortingKernel
void* smem) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
const index_t wid = __builtin_amdgcn_readfirstlane(tid / WarpSize);
const index_t wid = __builtin_amdgcn_readfirstlane(tid / get_warp_size());
const index_t lid = __lane_id();
constexpr index_t block_size = 256; // blockDim.x;
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
@@ -798,7 +798,7 @@ struct MoeSortingKernel
// NOTE: under this block can never use __syncthreads!
int i_e_ = 0;
int local_cumsum_ = 0;
for(; i_e_ < num_experts; i_e_ += WarpSize)
for(; i_e_ < num_experts; i_e_ += get_warp_size())
{
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
int local_cnt = smem_cumsum(i_e_ + lid + 1);
@@ -843,7 +843,7 @@ struct MoeSortingKernel
// cumsum is padded in case the local cumsum is zero but
// pre_cumsum has a value, which would otherwise result in a
// zero local cumsum (but we want at least the padded value)
wave_cumsum<int, WarpSize>(local_cumsum_);
wave_cumsum<int, get_warp_size()>(local_cumsum_);
if((i_e_ + lid) < num_experts)
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
@@ -851,7 +851,7 @@ struct MoeSortingKernel
if constexpr(Problem::LocalExpertMasking)
{
local_masking += pre_cumsum_masking;
wave_cumsum<int, WarpSize>(local_masking);
wave_cumsum<int, get_warp_size()>(local_masking);
if((i_e_ + lid) < num_experts)
smem_cumdup(i_e_ + lid + 1) = local_masking;
}
@@ -861,7 +861,7 @@ struct MoeSortingKernel
// than 0(which is not we want)
__builtin_amdgcn_s_waitcnt(0xc07f);
}
if((lid + i_e_ - WarpSize) == (num_experts - 1))
if((lid + i_e_ - get_warp_size()) == (num_experts - 1))
{
*p_total_tokens_post_pad = local_cumsum_;
}
@@ -1109,7 +1109,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
return chunk * sizeof(index_t);
};
template <typename T, typename F, index_t wave_size_ = WarpSize>
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
{
// constexpr int wave_size = 64;
@@ -1504,7 +1504,7 @@ struct MoeSortingMultiPhaseKernel_P1
// in byte
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -1546,8 +1546,8 @@ struct MoeSortingMultiPhaseKernel_P1
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
}
index_t lane_id = threadIdx.x % WarpSize;
index_t wave_id = threadIdx.x / WarpSize;
index_t lane_id = threadIdx.x % get_warp_size();
index_t wave_id = threadIdx.x / get_warp_size();
// reduce cross wave
IndexType* s = reinterpret_cast<IndexType*>(smem);
@@ -1560,7 +1560,7 @@ struct MoeSortingMultiPhaseKernel_P1
if(threadIdx.x == 0)
{
index_t c = 0;
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
{
c += s[i];
}
@@ -1660,7 +1660,7 @@ struct MoeSortingMultiPhaseKernel_P01
// in byte
CK_TILE_HOST static constexpr auto GetSmemSize()
{
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -1786,8 +1786,8 @@ struct MoeSortingMultiPhaseKernel_P01
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
}
index_t lane_id = threadIdx.x % WarpSize;
index_t wave_id = threadIdx.x / WarpSize;
index_t lane_id = threadIdx.x % get_warp_size();
index_t wave_id = threadIdx.x / get_warp_size();
// reduce cross wave
IndexType* s = reinterpret_cast<IndexType*>(smem);
@@ -1801,7 +1801,7 @@ struct MoeSortingMultiPhaseKernel_P01
if(threadIdx.x == 0)
{
index_t c = 0;
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
{
c += s[i];
}
@@ -1880,7 +1880,7 @@ struct MoeSortingMultiPhaseKernel_P2
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
// return 2 * BLOCK_SIZE * sizeof(IndexType);
return (4 + 2 * BLOCK_SIZE / WarpSize) * sizeof(IndexType);
return (4 + 2 * BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
}
// reduce single pixel within a wave
@@ -1905,8 +1905,8 @@ struct MoeSortingMultiPhaseKernel_P2
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
index_t wave_id = threadIdx.x / WarpSize;
index_t lane_id = threadIdx.x % WarpSize;
index_t wave_id = threadIdx.x / get_warp_size();
index_t lane_id = threadIdx.x % get_warp_size();
IndexType prev_cumsum_a = 0;
IndexType prev_cumsum_b = 0;
@@ -1951,22 +1951,22 @@ struct MoeSortingMultiPhaseKernel_P2
IndexType cumsum_b = b_;
// Note: we first cumsum local round, then add previous cumsum
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_a);
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_b);
__syncthreads();
if(lane_id == WarpSize - 1)
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum_a;
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
s[4 + wave_id] = cumsum_a;
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev_a = s[4 + i_w];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
prev_a = wave_id > i_w ? prev_a : 0; // mask out
prev_b = wave_id > i_w ? prev_b : 0; // mask out
cumsum_a += prev_a;
@@ -2083,7 +2083,7 @@ struct MoeSortingMultiPhaseKernel_P3
// in byte
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
{
return (4 + BLOCK_SIZE / WarpSize) * sizeof(IndexType);
return (4 + BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
@@ -2110,8 +2110,8 @@ struct MoeSortingMultiPhaseKernel_P3
}
}();
int eid = blockIdx.x;
int wave_id = threadIdx.x / WarpSize;
int lane_id = threadIdx.x % WarpSize;
int wave_id = threadIdx.x / get_warp_size();
int lane_id = threadIdx.x % get_warp_size();
int e_start = p_expert_cumsum[eid];
int e_end = p_expert_cumsum[eid + 1];
if constexpr(Problem::SkipExpertsWithZeroTokens)
@@ -2141,17 +2141,17 @@ struct MoeSortingMultiPhaseKernel_P3
int i_topk = x - 1; // topk of this token
int i_show = x != 0 ? 1 : 0; // has this token or not
int cumsum = i_show;
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
__syncthreads();
if(lane_id == WarpSize - 1)
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
@@ -2196,7 +2196,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
{
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
const index_t expert_cumsum_elem = num_experts_ + 1;
return (4 + 2 * BLOCK_SIZE / WarpSize + expert_cumsum_elem) * sizeof(int);
return (4 + 2 * BLOCK_SIZE / get_warp_size() + expert_cumsum_elem) * sizeof(int);
}
} // namespace impl
@@ -2303,15 +2303,15 @@ struct MoeSortingMultiPhaseKernel_P23
const IndexType* p_local_expert_mask =
static_cast<const IndexType*>(kargs.p_local_expert_mask);
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
IndexType* p_total_tokens_post_pad =
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
IndexType* p_sorted_expert_ids =
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
index_t wave_id = threadIdx.x / WarpSize;
index_t lane_id = threadIdx.x % WarpSize;
index_t wave_id = threadIdx.x / get_warp_size();
index_t lane_id = threadIdx.x % get_warp_size();
IndexType prev_cumsum_a = 0;
IndexType prev_cumsum_b = 0;
@@ -2356,22 +2356,22 @@ struct MoeSortingMultiPhaseKernel_P23
IndexType cumsum_b = b_;
// Note: we first cumsum local round, then add previous cumsum
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_a);
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_b);
__syncthreads();
if(lane_id == WarpSize - 1)
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum_a;
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
s[4 + wave_id] = cumsum_a;
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev_a = s[4 + i_w];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
prev_a = wave_id > i_w ? prev_a : 0; // mask out
prev_b = wave_id > i_w ? prev_b : 0; // mask out
cumsum_a += prev_a;
@@ -2441,13 +2441,13 @@ struct MoeSortingMultiPhaseKernel_P23
IndexType* s = reinterpret_cast<IndexType*>(smem);
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
int eid = blockIdx.x;
int wave_id = threadIdx.x / WarpSize;
int lane_id = threadIdx.x % WarpSize;
int wave_id = threadIdx.x / get_warp_size();
int lane_id = threadIdx.x % get_warp_size();
int e_start = p_expert_cumsum_smem[eid];
int e_end = p_expert_cumsum_smem[eid + 1];
if constexpr(Problem::SkipExpertsWithZeroTokens)
@@ -2518,17 +2518,17 @@ struct MoeSortingMultiPhaseKernel_P23
int i_topk = x - 1; // topk of this token
int i_show = x != 0 ? 1 : 0; // has this token or not
int cumsum = i_show;
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
__syncthreads();
if(lane_id == WarpSize - 1)
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
@@ -2569,17 +2569,17 @@ struct MoeSortingMultiPhaseKernel_P23
cumsum_store += i_show[j];
});
int cumsum = cumsum_store;
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
__syncthreads();
if(lane_id == WarpSize - 1)
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;
@@ -2624,17 +2624,17 @@ struct MoeSortingMultiPhaseKernel_P23
int i_topk_1 = x1 - 1; // topk of this token
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
int cumsum = i_show_0 + i_show_1;
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
__syncthreads();
if(lane_id == WarpSize - 1)
if(lane_id == get_warp_size() - 1)
{
s[4 + wave_id] = cumsum;
}
__syncthreads();
// reduce cross wave
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
IndexType prev = s[4 + i_w];
prev = wave_id > i_w ? prev : 0; // mask out
cumsum += prev;

View File

@@ -215,7 +215,7 @@ struct BlockUniversalGemmAsBsCr
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
ALdsTile a_warp_tile_;
ALdsTile b_warp_tile_;
BLdsTile b_warp_tile_;
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>

View File

@@ -59,14 +59,23 @@ struct GemmHostArgs
const void* a_ptr;
const void* b_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
void* e_ptr;
union
{
void* e_ptr;
void* c_ptr;
};
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
const std::array<index_t, NumDTensor> stride_Ds;
index_t stride_E;
union
{
index_t stride_E;
index_t stride_C;
};
index_t k_batch;
};

View File

@@ -172,7 +172,7 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
#if defined(__gfx950__)
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
#else
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
@@ -282,4 +282,19 @@ using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
2,
swizzle_factor>>;
// int8
// Warp-GEMM aliases for the int8 attribute structs (WarpGemmAttributeMfmaImpl_i32_*_i8),
// all instantiated with WGAttrCtlEnum::Default_.
// Each shape comes in two flavours:
//   - plain:        attribute wrapped in WarpGemmAtrributeMfma
//   - _CTransposed: attribute wrapped in WarpGemmAtrributeMfmaTransposedCDistribution
// 32x32x16 shape
using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
// 16x16x32 shape
using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
} // namespace ck_tile

View File

@@ -1578,8 +1578,8 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
else
{
#if defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_i32_32x32x8i8(
#if defined(__gfx94__) or defined(__gfx95__)
c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for<0, 8, 1>{}([&](auto k) {
@@ -1609,6 +1609,183 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
}
};
// Warp-level MFMA attribute: int8 x int8 inputs accumulated into int32 on a
// 16 x 16 x 32 (M x N x K) tile. Ctrl_ selects how the instruction is issued
// (see DISPATCH_MFMA_CTRL_ below).
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
// element types of the A / B operands and of the accumulator
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int32_t;
// operand vectors: 8 int8 values for A and for B, 4 int32 accumulator values
using AVecType = ext_vector_t<ADataType, 8>;
using BVecType = ext_vector_t<BDataType, 8>;
using CVecType = ext_vector_t<CDataType, 4>;
// tile extents
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 32;
// layout constants consumed by the warp-gemm wrappers; note that
// kABKLane * kABKPerLane == kK and kABKPerLane matches the A/B vector length
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
// c_vec += a_vec * b_vec
// Accumulating form: c_vec is both input and output.
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// NOTE(review): DISPATCH_MFMA_CTRL_ is defined outside this block; it presumably
// handles the non-default Ctrl modes and leaves the trailing `else` for the
// default path - confirm against the macro definition.
DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x32_i8", Ctrl)
else
{
#if defined(__gfx94__) or defined(__gfx95__)
// the 8 int8 values of each operand are reinterpreted as a single `long`
// (assumes `long` is 8 bytes on the device target)
c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#else
// other targets: no computation is performed, the arguments are only
// marked as used and c_vec is left unchanged
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
// Non-accumulating form: starts from a zero accumulator and forwards to the
// accumulating overload above.
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
CVecType c_vec{0};
operator()(c_vec, a_vec, b_vec);
return c_vec;
}
};
// Warp-level MFMA attribute: int8 x int8 inputs accumulated into int32 on a
// 16 x 16 x 64 (M x N x K) tile. Ctrl_ selects how the instruction is issued
// (see DISPATCH_MFMA_CTRL_ below).
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_i32_16x16x64_i8
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;

    // element types of the A / B operands and of the accumulator
    using ADataType = int8_t;
    using BDataType = int8_t;
    using CDataType = int32_t;

    // operand vectors: 16 int8 values (16 bytes) for A and for B, 4 int32 accumulator values
    using AVecType = ext_vector_t<ADataType, 16>;
    using BVecType = ext_vector_t<BDataType, 16>;
    using CVecType = ext_vector_t<CDataType, 4>;

    // tile extents
    static constexpr index_t kM = 16;
    static constexpr index_t kN = 16;
    static constexpr index_t kK = 64;

    // layout constants consumed by the warp-gemm wrappers; note that
    // kABKLane * kABKPerLane == kK and kABKPerLane matches the A/B vector length
    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;

    static constexpr index_t kAMLane     = 16;
    static constexpr index_t kBNLane     = 16;
    static constexpr index_t kABKLane    = 4;
    static constexpr index_t kABKPerLane = 16;

    static constexpr index_t kCMLane     = 4;
    static constexpr index_t kCNLane     = 16;
    static constexpr index_t kCM0PerLane = 1;
    static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs

    // c_vec += a_vec * b_vec
    // Accumulating form: c_vec is both input and output.
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        // NOTE(review): DISPATCH_MFMA_CTRL_ is defined outside this block; it presumably
        // handles the non-default Ctrl modes and leaves the trailing `else` for the
        // default path - confirm against the macro definition.
        DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x64_i8", Ctrl)
        else
        {
#if defined(__gfx95__)
            // The A/B operands are 16 bytes each, so they cannot be reinterpreted as a
            // `long`; the builtin takes them as 4 x int32 vectors of the same size.
            using PackedABVecType = ext_vector_t<int32_t, 4>;
            c_vec                 = __builtin_amdgcn_mfma_i32_16x16x64_i8(
                bit_cast<PackedABVecType>(a_vec), bit_cast<PackedABVecType>(b_vec), c_vec, 0, 0, 0);
#else
            // other targets: no computation is performed, the arguments are only
            // marked as used and c_vec is left unchanged
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
    // Non-accumulating form: starts from a zero accumulator and forwards to the
    // accumulating overload above.
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
        CVecType c_vec{0};
        operator()(c_vec, a_vec, b_vec);
        return c_vec;
    }
};
// Warp-level MFMA attribute: int8 x int8 inputs accumulated into int32 on a
// 32 x 32 x 32 (M x N x K) tile. Ctrl_ selects how the instruction is issued
// (see DISPATCH_MFMA_CTRL_ below).
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;

    // element types of the A / B operands and of the accumulator
    using ADataType = int8_t;
    using BDataType = int8_t;
    using CDataType = int32_t;

    // operand vectors: 16 int8 values (16 bytes) for A and for B, 16 int32 accumulator values
    using AVecType = ext_vector_t<ADataType, 16>;
    using BVecType = ext_vector_t<BDataType, 16>;
    using CVecType = ext_vector_t<CDataType, 16>;

    // tile extents
    static constexpr index_t kM = 32;
    static constexpr index_t kN = 32;
    static constexpr index_t kK = 32;

    // layout constants consumed by the warp-gemm wrappers; note that
    // kABKLane * kABKPerLane == kK and kABKPerLane matches the A/B vector length
    static constexpr index_t kAMBlock = 1;
    static constexpr index_t kBNBlock = 1;

    static constexpr index_t kAMLane     = 32;
    static constexpr index_t kBNLane     = 32;
    static constexpr index_t kABKLane    = 2;
    static constexpr index_t kABKPerLane = 16;

    static constexpr index_t kCMLane     = 2;
    static constexpr index_t kCNLane     = 32;
    static constexpr index_t kCM0PerLane = 4;
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    // Accumulating form: c_vec is both input and output.
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        // NOTE(review): DISPATCH_MFMA_CTRL_ is defined outside this block; it presumably
        // handles the non-default Ctrl modes and leaves the trailing `else` for the
        // default path - confirm against the macro definition.
        DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x32_i8", Ctrl)
        else
        {
#if defined(__gfx95__)
            // Both operands are 16 bytes each. Reinterpret them the same way, as
            // 4 x int32 vectors of identical size, instead of passing one raw and
            // squeezing the other through a `long`.
            using PackedABVecType = ext_vector_t<int32_t, 4>;
            c_vec                 = __builtin_amdgcn_mfma_i32_32x32x32_i8(
                bit_cast<PackedABVecType>(a_vec), bit_cast<PackedABVecType>(b_vec), c_vec, 0, 0, 0);
#else
            // other targets: no computation is performed, the arguments are only
            // marked as used and c_vec is left unchanged
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
    // Non-accumulating form: starts from a zero accumulator and forwards to the
    // accumulating overload above.
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
        CVecType c_vec{0};
        operator()(c_vec, a_vec, b_vec);
        return c_vec;
    }
};
#undef DISPATCH_MFMA_
} // namespace ck_tile

View File

@@ -11,7 +11,7 @@ namespace ck_tile {
namespace impl {
template <typename AType,
typename BType,
typename CType,
typename AccType,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
@@ -22,6 +22,7 @@ struct WarpGemmMfmaDispatcher;
// clang-format off
// fp16
// Each row is a full specialization that maps one argument combination to a concrete
// warp-GEMM type through its nested `Type` alias. Arguments omitted at the end of a row
// fall back to the primary template's defaults (declared outside this view).
// Rows here: half x half with float accumulation; 32x32x8 in both TransposeC settings
// and 32x32x16 with TransposeC = false.
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
@@ -37,10 +38,12 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
// fp16 row with the eighth argument (SwizzleA, per the column legend below) set to true.
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
// fp16 2:4 structural sparsity
// These rows spell out all nine arguments; the final `true` is UseStructuredSparsity and
// selects the WarpGemmSmfmac* types instead of the dense WarpGemmMfma* ones.
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
// bf16
// Full specializations mapping bf16 x bf16 with float accumulation to concrete warp-GEMM
// types via the nested `Type` alias: 32x32x8 in both TransposeC settings and 32x32x16
// with TransposeC = false. Trailing arguments not listed use the primary template's
// defaults (declared outside this view).
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
@@ -56,6 +59,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
// bf16 row with the eighth argument (SwizzleA, per the column legend below) set to true.
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
// fp8
// Full specializations for fp8 x fp8 with float accumulation, all with TransposeC = false:
// 32x32x16, 32x32x32 and 16x16x32. Each exposes the selected warp-GEMM type as `Type`.
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
@@ -81,12 +85,19 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float,
// 8-bit float rows with a bf8 A operand, 32x32x64, TransposeC = false (B is fp8, then bf8).
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; };
// int8
// Full specializations for int8 x int8 with int32 accumulation: 32x32x16 and 16x16x32,
// each in both TransposeC settings (the `true` rows select the *_CTransposed types).
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
// clang-format on
} // namespace impl
template <typename AType,
typename BType,
typename CType,
typename AccType,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
@@ -95,7 +106,7 @@ template <typename AType,
bool UseStructuredSparsity = false>
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
BType,
CType,
AccType,
MPerWave,
NPerWave,
KPerWave,

View File

@@ -250,7 +250,7 @@ struct BlockNormReduceCrossWarpSync
// | w0 | w1 | w2 | w3 | -----> | w0123 |
//
// -> also store data from every wave into LDS
constexpr index_t num_warps = BlockShape::BlockSize / WarpSize;
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
return num_warps * 4 * thread_buf_size * sizeof(float);
}
@@ -276,7 +276,7 @@ struct BlockNormReduceCrossWarpSync
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
constexpr index_t num_warps = BlockShape::BlockSize / WarpSize;
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
const index_t smem_offset = warp_id;
// skip if nothing to do