mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] support split-k a16w4 gemm1 (#3389)
* initial version to support moe gemm1 split-k * add missing args * fix build warning * update reference * for split-k disable bias and weight * remove debug log * fix format * fix div by zero errors * fix cmake config * update * resolve conflicts * remove useless changes * reformat * fix * remove useless changes * fix ci --------- Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: root <root@smci355-ccs-aus-m01-25.cs-aus.dcgpu>
This commit is contained in:
@@ -28,17 +28,18 @@ struct FlatmmProblem
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
template <int SharedGranularityMN, int SharedGranularityK = 0>
|
||||
template <int SharedGranularityMN, int SharedGranularityK = 0, typename ScaleType_ = float>
|
||||
struct FlatmmScalePointer
|
||||
{
|
||||
using ScaleType = ScaleType_;
|
||||
static constexpr int GranularityMN = SharedGranularityMN;
|
||||
static constexpr int GranularityK = SharedGranularityK;
|
||||
|
||||
const float* ptr;
|
||||
const ScaleType* ptr;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_)
|
||||
: ptr(ptr_)
|
||||
{
|
||||
}
|
||||
@@ -57,23 +58,24 @@ struct FlatmmScalePointer
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
|
||||
CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete;
|
||||
};
|
||||
|
||||
template <int SharedGranularityMN>
|
||||
struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
template <int SharedGranularityMN, typename ScaleType_>
|
||||
struct FlatmmScalePointer<SharedGranularityMN, 0, ScaleType_>
|
||||
{
|
||||
using ScaleType = ScaleType_;
|
||||
static constexpr int GranularityMN = SharedGranularityMN;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
static_assert(GranularityMN != 0);
|
||||
|
||||
const float* ptr;
|
||||
const ScaleType* ptr;
|
||||
index_t length;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_)
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, index_t length_)
|
||||
: ptr(ptr_), length(length_)
|
||||
{
|
||||
}
|
||||
@@ -94,7 +96,7 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const
|
||||
CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const
|
||||
{
|
||||
// with additional oob check
|
||||
if constexpr(GranularityMN == 1)
|
||||
@@ -105,23 +107,24 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
};
|
||||
|
||||
// shared granularityMN = -1 means no scale
|
||||
template <>
|
||||
struct FlatmmScalePointer<-1, 0>
|
||||
template <typename ScaleType_>
|
||||
struct FlatmmScalePointer<-1, 0, ScaleType_>
|
||||
{
|
||||
using ScaleType = ScaleType_;
|
||||
static constexpr int GranularityMN = -1;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
const float* ptr = nullptr;
|
||||
const ScaleType* ptr = nullptr;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*, index_t) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
|
||||
{
|
||||
return FlatmmScalePointer{};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
|
||||
CK_TILE_HOST_DEVICE constexpr ScaleType operator[](index_t) const
|
||||
{
|
||||
return 1; // alway return 1, it doesn't change the result
|
||||
}
|
||||
|
||||
@@ -132,6 +132,7 @@ enum class MoeFlatmmKind
|
||||
kFFN_gemm1_gate_only,
|
||||
kFFN_gemm1_gate_up,
|
||||
kFFN_gemm2,
|
||||
kFFN_gemm1_split_k,
|
||||
};
|
||||
|
||||
namespace moe {
|
||||
@@ -222,8 +223,10 @@ struct MoeFlatmmKernel
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
|
||||
static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
|
||||
static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
|
||||
static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
|
||||
static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
|
||||
static constexpr bool IsGemm1SplitK = kind == MoeFlatmmKind::kFFN_gemm1_split_k;
|
||||
static constexpr bool IsBShuffled = true;
|
||||
|
||||
// static constexpr index_t kBlockSize = EpiloguePipeline::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = EpiloguePipeline::kMPerBlock;
|
||||
@@ -395,15 +398,6 @@ struct MoeFlatmmKernel
|
||||
a_k_split_offset = k_id * KRead * kargs.stride_A;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead * kargs.stride_B;
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead;
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
{
|
||||
splitted_k = KRead;
|
||||
@@ -412,6 +406,22 @@ struct MoeFlatmmKernel
|
||||
{
|
||||
splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
|
||||
}
|
||||
|
||||
if constexpr(IsBShuffled)
|
||||
{
|
||||
b_k_split_offset = k_id * splitted_k * NPerXdl;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead * kargs.stride_B;
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
@@ -573,15 +583,16 @@ struct MoeFlatmmKernel
|
||||
return DTesnorIsValid;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = IsInputGemm ? memory_operation_enum::set
|
||||
: memory_operation_enum::atomic_add,
|
||||
template <memory_operation_enum DstInMemOp = (IsInputGemm && !IsGemm1SplitK)
|
||||
? memory_operation_enum::set
|
||||
: memory_operation_enum::atomic_add,
|
||||
typename KernelArgs>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
EDataType* e_ptr,
|
||||
[[maybe_unused]] const AccDataType* exp_weight_ptr,
|
||||
const int expert_id,
|
||||
[[maybe_unused]] const int expert_id,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
@@ -742,13 +753,13 @@ struct MoeFlatmmKernel
|
||||
{
|
||||
index_t scale_k =
|
||||
BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
|
||||
const auto scale_k_offset =
|
||||
(splitk_batch_offset.b_k_split_offset / BGranularityK) * K_Pack;
|
||||
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
using ScaleType = e8m0_t;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
|
||||
scale_n.ptr + expert_id * kargs.N * scale_k + scale_k_offset,
|
||||
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
@@ -1386,11 +1397,16 @@ struct MoeFlatmmKernel
|
||||
if constexpr(!BMXFP4_Pipeline)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] *=
|
||||
epi_scale_m[idx] * epi_scale_n[idx];
|
||||
if constexpr(EnableBias)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
|
||||
else // for mlp1 gate-only
|
||||
if(kind !=
|
||||
MoeFlatmmKind::kFFN_gemm1_split_k) // disable weight and bias for split-k
|
||||
{
|
||||
if constexpr(EnableBias)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
|
||||
}
|
||||
if constexpr(kind ==
|
||||
MoeFlatmmKind::kFFN_gemm1_gate_only) // for mlp1 gate-only
|
||||
lds_tile[lds_stage].get_thread_buffer()[idx] =
|
||||
ActivationOp{}(lds_tile[lds_stage].get_thread_buffer()[idx]);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user