[CK_TILE] support split-k a16w4 gemm1 (#3389)

* initial version to support moe gemm1 split-k

* add missing args

* fix build warning

* update reference

* for split-k disable bias and weight

* remove debug log

* fix format

* fix div by zero errors

* fix cmake config

* update

* resolve conflicts

* remove useless changes

* reformat

* fix

* remove useless changes

* fix ci

---------

Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: root <root@smci355-ccs-aus-m01-25.cs-aus.dcgpu>

[ROCm/composable_kernel commit: dae85ead64]
This commit is contained in:
yadaish
2025-12-29 23:05:35 +08:00
committed by GitHub
parent 89e943a9f3
commit a57f8d8b67
11 changed files with 136 additions and 78 deletions

View File

@@ -28,17 +28,18 @@ struct FlatmmProblem
index_t stride_C;
};
template <int SharedGranularityMN, int SharedGranularityK = 0>
template <int SharedGranularityMN, int SharedGranularityK = 0, typename ScaleType_ = float>
struct FlatmmScalePointer
{
using ScaleType = ScaleType_;
static constexpr int GranularityMN = SharedGranularityMN;
static constexpr int GranularityK = SharedGranularityK;
const float* ptr;
const ScaleType* ptr;
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_)
: ptr(ptr_)
{
}
@@ -57,23 +58,24 @@ struct FlatmmScalePointer
return ret;
}
CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete;
};
template <int SharedGranularityMN>
struct FlatmmScalePointer<SharedGranularityMN, 0>
template <int SharedGranularityMN, typename ScaleType_>
struct FlatmmScalePointer<SharedGranularityMN, 0, ScaleType_>
{
using ScaleType = ScaleType_;
static constexpr int GranularityMN = SharedGranularityMN;
static constexpr int GranularityK = 0;
static_assert(GranularityMN != 0);
const float* ptr;
const ScaleType* ptr;
index_t length;
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_)
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, index_t length_)
: ptr(ptr_), length(length_)
{
}
@@ -94,7 +96,7 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
return ret;
}
CK_TILE_HOST_DEVICE float operator[](index_t i) const
CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const
{
// with additional oob check
if constexpr(GranularityMN == 1)
@@ -105,23 +107,24 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
};
// shared granularityMN = -1 means no scale
template <>
struct FlatmmScalePointer<-1, 0>
template <typename ScaleType_>
struct FlatmmScalePointer<-1, 0, ScaleType_>
{
using ScaleType = ScaleType_;
static constexpr int GranularityMN = -1;
static constexpr int GranularityK = 0;
const float* ptr = nullptr;
const ScaleType* ptr = nullptr;
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*, index_t) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
{
return FlatmmScalePointer{};
}
CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
CK_TILE_HOST_DEVICE constexpr ScaleType operator[](index_t) const
{
return 1; // alway return 1, it doesn't change the result
}