lalala-sh
2025-07-23 15:01:53 +08:00
parent 46a538e39e
commit 7e1bd4b839
3 changed files with 285 additions and 29 deletions


@@ -11,12 +11,97 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
template <index_t NumDTensor = 0>
struct FlatmmHostArgs
struct FlatmmProblem
{
CK_TILE_HOST FlatmmHostArgs() = default;
CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_,
CK_TILE_HOST FlatmmProblem() = default;
CK_TILE_HOST FlatmmProblem(
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
{
}
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
template <int SharedGranularity>
struct FlatmmScalePointer
{
static constexpr int granularity = SharedGranularity;
union
{
const float* ptr;
float scalar; // if shared granularity is 0, all rows/columns use the same scale value
};
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE FlatmmScalePointer(float scalar_) : scalar(scalar_) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
{
FlatmmScalePointer ret;
if constexpr(granularity == 0)
{
ret.scalar = scalar;
}
else if constexpr(granularity == 1)
{
ret.ptr = ptr + offset;
}
else
{
ret.ptr = ptr + offset / granularity;
}
return ret;
}
CK_TILE_HOST_DEVICE float operator[](index_t i) const
{
if constexpr(granularity == 0)
{
return scalar;
}
else if constexpr(granularity == 1)
{
return ptr[i];
}
else
{
return ptr[i / granularity];
}
}
};
// shared granularity = -1 means no scale
template <>
struct FlatmmScalePointer<-1>
{
static constexpr int granularity = -1;
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(float scalar_) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float* ptr_) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
{
return FlatmmScalePointer{};
}
CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
{
return 1; // always return 1, so it doesn't change the result
}
};
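Not part of the diff: a minimal sketch of how the three granularity modes resolve a lookup; the helper function and scale values below are illustrative only.
// Illustration only (hypothetical helper, not in the header): granularity semantics.
inline float demo_scale_lookup()
{
    static const float row_scales[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    FlatmmScalePointer<0> uniform{0.5f};       // every index returns the same scalar
    FlatmmScalePointer<1> per_row{row_scales}; // one scale per row
    FlatmmScalePointer<4> grouped{row_scales}; // one scale shared by every 4 rows
    FlatmmScalePointer<-1> disabled{};         // scaling disabled, always returns 1
    // operator+ already divides the offset by the granularity, so after
    // `grouped + 6` the next operator[](0) reads row_scales[6 / 4] == row_scales[1].
    return uniform[7] + per_row[3] + grouped[6] + disabled[42];
}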
template <index_t NumDTensor = 0>
struct BaseFlatmmHostArgs
{
CK_TILE_HOST BaseFlatmmHostArgs() = default;
CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_,
const void* b_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
@@ -66,7 +151,37 @@ struct FlatmmHostArgs
index_t k_batch;
};
template <index_t NumDTensor = 0>
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>, index_t NumDTensor = 0>
struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<NumDTensor>
{
CK_TILE_HOST ScaleFlatmmHostArgs() = default;
CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_,
const void* b_shuffle_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* c_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_C_,
ScaleM scale_m_ = nullptr,
ScaleN scale_n_ = nullptr)
: BaseFlatmmHostArgs<NumDTensor>(a_ptr_, b_shuffle_ptr_, ds_ptr_, c_ptr_, M_, N_, K_, stride_A_, stride_B_, stride_Ds_, stride_C_, k_batch_),
scale_m(scale_m_),
scale_n(scale_n_)
{
}
ScaleM scale_m = nullptr;
ScaleN scale_n = nullptr;
};
template <index_t NumDTensor = 0>
using FlatmmHostArgs = ScaleFlatmmHostArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, NumDTensor>;
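Not part of the diff: a hedged host-side sketch of filling the scaled host args for NumDTensor = 0; every pointer, stride value, and the 128-column scale group below are placeholder assumptions, not values taken from the commit.
// Sketch only: all device pointers, strides, and the column-group size are hypothetical.
inline auto demo_make_scaled_host_args(const void* a_dev,
                                       const void* b_shuffle_dev,
                                       void* c_dev,
                                       const float* scale_m_dev, // one scale per row of C
                                       const float* scale_n_dev, // one scale per 128-column group
                                       index_t M, index_t N, index_t K,
                                       index_t stride_A, index_t stride_B, index_t stride_C)
{
    using ScaleRow = FlatmmScalePointer<1>;
    using ScaleCol = FlatmmScalePointer<128>;
    return ScaleFlatmmHostArgs<ScaleRow, ScaleCol, 0>(a_dev,
                                                      b_shuffle_dev,
                                                      /*ds_ptr*/ {},
                                                      c_dev,
                                                      /*k_batch*/ 1,
                                                      M, N, K,
                                                      stride_A, stride_B,
                                                      /*stride_Ds*/ {},
                                                      stride_C,
                                                      ScaleRow{scale_m_dev},
                                                      ScaleCol{scale_n_dev});
}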
template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
struct FlatmmKernelArgs
{
const void* a_ptr;
@@ -82,6 +197,8 @@ struct FlatmmKernelArgs
std::array<index_t, NumDTensor> stride_Ds;
index_t stride_E;
index_t k_batch;
ScaleM scale_m_ptr = nullptr;
ScaleN scale_n_ptr = nullptr;
};
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
@@ -113,7 +230,7 @@ struct FlatmmKernel
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
// using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -129,21 +246,24 @@ struct FlatmmKernel
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr KernelArgs
MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>
MakeKernelArgs(const ScaleFlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
{
return KernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.ds_ptr,
hostArgs.e_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_Ds,
hostArgs.stride_E,
hostArgs.k_batch};
return {hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.ds_ptr,
hostArgs.e_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_Ds,
hostArgs.stride_E,
hostArgs.k_batch,
hostArgs.scale_m,
hostArgs.scale_n};
}
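Not part of the diff: with the templated MakeKernelArgs above, the kernel-argument type follows whatever scale types the host args carry; `Kernel` below is a placeholder for a concrete FlatmmKernel instantiation and the result types assume DsDataType::size() == 0.
// Sketch only: deduction of the kernel-arg type from the host args ('Kernel' is hypothetical).
// auto scaled = Kernel::MakeKernelArgs(
//     ScaleFlatmmHostArgs<FlatmmScalePointer<1>, FlatmmScalePointer<128>, 0>{/*...*/});
//     // -> FlatmmKernelArgs<FlatmmScalePointer<1>, FlatmmScalePointer<128>, 0>
// auto plain = Kernel::MakeKernelArgs(FlatmmHostArgs<>{/*...*/});
//     // -> FlatmmKernelArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, 0>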
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
@@ -157,8 +277,8 @@ struct FlatmmKernel
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
{
template <class KernelArgs>
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
@@ -196,6 +316,7 @@ struct FlatmmKernel
index_t splitted_k;
};
template <class KernelArgs>
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
{
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
@@ -341,7 +462,7 @@ struct FlatmmKernel
return DTesnorIsValid;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
@@ -559,14 +680,14 @@ struct FlatmmKernel
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
}
template <bool UseDefaultScheduler = true>
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
const KernelArgs& kargs,
const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
@@ -588,8 +709,18 @@ struct FlatmmKernel
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
// Run Epilogue Pipeline
if(UseDefaultScheduler || (get_warp_id() == 0))
if constexpr(ScaleM::granularity != -1 || ScaleN::granularity != -1)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
@@ -598,7 +729,9 @@ struct FlatmmKernel
}
}
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
template <class ScaleM, class ScaleN>
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
int partition_idx = blockIdx.x) const
{
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);