[CK_TILE] support split-k a16w4 gemm1 (#3389)

* initial version to support moe gemm1 split-k

* add missing args

* fix build warning

* update reference

* disable bias and weight for split-k

* remove debug log

* fix format

* fix div by zero errors

* fix cmake config

* update

* resolve conflicts

* remove useless changes

* reformat

* fix

* remove useless changes

* fix ci

---------

Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: root <root@smci355-ccs-aus-m01-25.cs-aus.dcgpu>
Author:    yadaish
Date:      2025-12-29 23:05:35 +08:00
Committer: GitHub
Parent:    a0acc83a72
Commit:    dae85ead64
11 changed files with 136 additions and 78 deletions


@@ -28,17 +28,18 @@ struct FlatmmProblem
     index_t stride_C;
 };
-template <int SharedGranularityMN, int SharedGranularityK = 0>
+template <int SharedGranularityMN, int SharedGranularityK = 0, typename ScaleType_ = float>
 struct FlatmmScalePointer
 {
+    using ScaleType = ScaleType_;
     static constexpr int GranularityMN = SharedGranularityMN;
     static constexpr int GranularityK = SharedGranularityK;
-    const float* ptr;
+    const ScaleType* ptr;
     CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
-    CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
-    CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
+    CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_) {}
+    CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_)
         : ptr(ptr_)
     {
     }
@@ -57,23 +58,24 @@ struct FlatmmScalePointer
         return ret;
     }
-    CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
+    CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete;
 };
-template <int SharedGranularityMN>
-struct FlatmmScalePointer<SharedGranularityMN, 0>
+template <int SharedGranularityMN, typename ScaleType_>
+struct FlatmmScalePointer<SharedGranularityMN, 0, ScaleType_>
 {
+    using ScaleType = ScaleType_;
     static constexpr int GranularityMN = SharedGranularityMN;
     static constexpr int GranularityK = 0;
     static_assert(GranularityMN != 0);
-    const float* ptr;
+    const ScaleType* ptr;
     index_t length;
     CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
-    CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
-    CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_)
+    CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {}
+    CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, index_t length_)
         : ptr(ptr_), length(length_)
     {
     }
@@ -94,7 +96,7 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
         return ret;
     }
-    CK_TILE_HOST_DEVICE float operator[](index_t i) const
+    CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const
     {
         // with additional oob check
         if constexpr(GranularityMN == 1)
@@ -105,23 +107,24 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
 };
 // shared granularityMN = -1 means no scale
-template <>
-struct FlatmmScalePointer<-1, 0>
+template <typename ScaleType_>
+struct FlatmmScalePointer<-1, 0, ScaleType_>
 {
+    using ScaleType = ScaleType_;
     static constexpr int GranularityMN = -1;
     static constexpr int GranularityK = 0;
-    const float* ptr = nullptr;
+    const ScaleType* ptr = nullptr;
     CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
-    CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
-    CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {}
+    CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*) {}
+    CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*, index_t) {}
     CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
     {
         return FlatmmScalePointer{};
     }
-    CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
+    CK_TILE_HOST_DEVICE constexpr ScaleType operator[](index_t) const
    {
         return 1; // alway return 1, it doesn't change the result
     }
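
The scale element type is now a template parameter instead of being hard-coded to float. Below is a minimal standalone sketch of that pattern, not the library code: ScalePointerSketch is a hypothetical name, the CK_TILE macros and the granularity/oob logic are omitted, and uint8_t stands in for e8m0_t block scales.

#include <cstdint>
#include <cstdio>

template <int GranularityMN, int GranularityK = 0, typename ScaleType_ = float>
struct ScalePointerSketch
{
    using ScaleType = ScaleType_;
    const ScaleType* ptr = nullptr;

    // one scale per GranularityMN rows/columns (simplified, no bounds check)
    ScaleType operator[](int i) const { return ptr[i / GranularityMN]; }
};

int main()
{
    float fp32_scales[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    uint8_t mx_scales[4] = {127, 128, 129, 130}; // stand-in for e8m0_t exponents

    ScalePointerSketch<1> a{fp32_scales};           // ScaleType defaults to float, as before
    ScalePointerSketch<1, 0, uint8_t> b{mx_scales}; // e.g. e8m0_t scales on the a16w4 path

    std::printf("%f %u\n", a[2], static_cast<unsigned>(b[2]));
    return 0;
}

Keeping float as the default preserves the existing instantiations while letting the a16w4 pipeline carry its quantized scale type through the same pointer wrapper.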


@@ -132,6 +132,7 @@ enum class MoeFlatmmKind
     kFFN_gemm1_gate_only,
     kFFN_gemm1_gate_up,
     kFFN_gemm2,
+    kFFN_gemm1_split_k,
 };
 namespace moe {
@@ -222,8 +223,10 @@ struct MoeFlatmmKernel
     static_assert(DsLayout::size() == DsDataType::size(),
                   "The size of DsLayout and DsDataType should be the same");
-    static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
-    static constexpr bool IsGateUp    = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
+    static constexpr bool IsInputGemm   = kind != MoeFlatmmKind::kFFN_gemm2;
+    static constexpr bool IsGateUp      = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
+    static constexpr bool IsGemm1SplitK = kind == MoeFlatmmKind::kFFN_gemm1_split_k;
     static constexpr bool IsBShuffled = true;
     // static constexpr index_t kBlockSize = EpiloguePipeline::kBlockSize;
     static constexpr index_t kMPerBlock = EpiloguePipeline::kMPerBlock;
@@ -395,15 +398,6 @@ struct MoeFlatmmKernel
             a_k_split_offset = k_id * KRead * kargs.stride_A;
         }
-        if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
-        {
-            b_k_split_offset = k_id * KRead * kargs.stride_B;
-        }
-        else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
-        {
-            b_k_split_offset = k_id * KRead;
-        }
         if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
         {
             splitted_k = KRead;
@@ -412,6 +406,22 @@ struct MoeFlatmmKernel
         {
             splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
         }
+        if constexpr(IsBShuffled)
+        {
+            b_k_split_offset = k_id * splitted_k * NPerXdl;
+        }
+        else
+        {
+            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
+            {
+                b_k_split_offset = k_id * KRead * kargs.stride_B;
+            }
+            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
+            {
+                b_k_split_offset = k_id * KRead;
+            }
+        }
     }
     index_t a_k_split_offset;
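
With pre-shuffled B (IsBShuffled), the per-batch offset into the flat B tensor is derived from the batch's own splitted_k and NPerXdl rather than from stride_B, so it is now computed after splitted_k. A standalone sketch of that selection follows; the function and its boolean arguments are illustrative, not kernel API, and the real code dispatches the layouts at compile time.

#include <cstdint>

// Illustrative only: mirrors the offset selection in the hunk above.
int64_t b_split_offset(int32_t k_id, int32_t k_batch, int32_t K, int32_t KRead,
                       int32_t NPerXdl, int64_t stride_B,
                       bool b_is_shuffled, bool b_is_row_major)
{
    // every K batch but the last reads KRead; the last one takes the remainder
    const int32_t splitted_k = (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);

    if(b_is_shuffled)
        return int64_t(k_id) * splitted_k * NPerXdl; // flat, pre-shuffled layout
    if(b_is_row_major)
        return int64_t(k_id) * KRead * stride_B;     // row-major: skip whole K rows
    return int64_t(k_id) * KRead;                    // column-major: K is contiguous
}
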
@@ -573,15 +583,16 @@ struct MoeFlatmmKernel
         return DTesnorIsValid;
     }
-    template <memory_operation_enum DstInMemOp = IsInputGemm ? memory_operation_enum::set
-                                                              : memory_operation_enum::atomic_add,
+    template <memory_operation_enum DstInMemOp = (IsInputGemm && !IsGemm1SplitK)
+                                                     ? memory_operation_enum::set
+                                                     : memory_operation_enum::atomic_add,
              typename KernelArgs>
     CK_TILE_DEVICE static auto
     MakeGemmTensorViews(const ADataType* a_ptr,
                         const BDataType* b_flat_ptr,
                         EDataType* e_ptr,
                         [[maybe_unused]] const AccDataType* exp_weight_ptr,
-                        const int expert_id,
+                        [[maybe_unused]] const int expert_id,
                         const KernelArgs& kargs,
                         const SplitKBatchOffset& splitk_batch_offset)
     {
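
The default destination memory op now falls back to atomic_add on the split-k gemm1 path, because several K batches write partial sums into the same output tile and must accumulate rather than overwrite. A compile-time sketch of that selection (the names mirror the kernel, but this is not the library code):

enum class memory_operation_enum { set, atomic_add };

template <bool IsInputGemm, bool IsGemm1SplitK>
constexpr memory_operation_enum default_dst_mem_op()
{
    // split-k gemm1 accumulates partials from every K batch into the output;
    // the gemm2 path already defaulted to atomic_add before this change
    return (IsInputGemm && !IsGemm1SplitK) ? memory_operation_enum::set
                                           : memory_operation_enum::atomic_add;
}

static_assert(default_dst_mem_op<true, false>() == memory_operation_enum::set, "gemm1, no split-k");
static_assert(default_dst_mem_op<true, true>() == memory_operation_enum::atomic_add, "gemm1 split-k");
static_assert(default_dst_mem_op<false, false>() == memory_operation_enum::atomic_add, "gemm2");
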
@@ -742,13 +753,13 @@ struct MoeFlatmmKernel
             {
                 index_t scale_k =
                     BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
+                const auto scale_k_offset =
+                    (splitk_batch_offset.b_k_split_offset / BGranularityK) * K_Pack;
                 index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
                 index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
-                using ScaleType = e8m0_t;
                 return make_naive_tensor_view<address_space_enum::global>(
-                    reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
+                    scale_n.ptr + expert_id * kargs.N * scale_k + scale_k_offset,
                     make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
                     make_tuple(FlatScaleK, 1),
                     number<8>{},
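
Each split-K batch also has to start reading the flat weight-scale tensor at its own K position, so the expert offset and the new scale_k_offset are combined when building the view. A small illustrative helper follows; the name scale_base is hypothetical, uint8_t stands in for e8m0_t, and BGranularityK and K_Pack are compile-time constants in the kernel.

#include <cstdint>

// Hypothetical helper mirroring the pointer arithmetic in the hunk above.
const uint8_t* scale_base(const uint8_t* scale_ptr, // e8m0_t block scales in the kernel
                          int expert_id, int64_t N, int64_t scale_k,
                          int64_t b_k_split_offset, int BGranularityK, int K_Pack)
{
    // skip the scales that belong to earlier K batches
    // (one scale covers BGranularityK elements of K, packed in groups of K_Pack)
    const int64_t scale_k_offset = (b_k_split_offset / BGranularityK) * K_Pack;
    return scale_ptr + int64_t(expert_id) * N * scale_k + scale_k_offset;
}
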
@@ -1386,11 +1397,16 @@ struct MoeFlatmmKernel
                 if constexpr(!BMXFP4_Pipeline)
                     lds_tile[lds_stage].get_thread_buffer()[idx] *=
                         epi_scale_m[idx] * epi_scale_n[idx];
-                if constexpr(EnableBias)
-                    lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
-                if constexpr(!IsInputGemm)
-                    lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
-                else // for mlp1 gate-only
+                if(kind !=
+                   MoeFlatmmKind::kFFN_gemm1_split_k) // disable weight and bias for split-k
+                {
+                    if constexpr(EnableBias)
+                        lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
+                    if constexpr(!IsInputGemm)
+                        lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
+                }
+                if constexpr(kind ==
+                             MoeFlatmmKind::kFFN_gemm1_gate_only) // for mlp1 gate-only
                     lds_tile[lds_stage].get_thread_buffer()[idx] =
                         ActivationOp{}(lds_tile[lds_stage].get_thread_buffer()[idx]);
                 });
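
Because split-k batches accumulate into the output via atomic_add, bias and the expert routing weight cannot be applied in this per-batch epilogue (they would otherwise be applied once per K batch), and the gate-only activation likewise cannot run on partial sums; hence the gate on kind above. A scalar sketch of the gated epilogue, with illustrative names and the MXFP4 scale special-casing omitted:

// Scalar sketch of the gated epilogue above; acc is one accumulator element.
float epilogue_element(float acc, float scale_m, float scale_n,
                       float bias, float exp_weight,
                       bool enable_bias, bool is_gemm2, bool is_gemm1_split_k,
                       bool is_gemm1_gate_only, float (*activation)(float))
{
    acc *= scale_m * scale_n;     // per-row / per-column dequant scales
    if(!is_gemm1_split_k)         // partial sums stay unbiased and unweighted
    {
        if(enable_bias)
            acc += bias;
        if(is_gemm2)              // gemm2 applies the expert routing weight
            acc *= exp_weight;
    }
    if(is_gemm1_gate_only)        // gate-only mlp1 activates in the epilogue
        acc = activation(acc);
    return acc;
}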