mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
[CK_TILE] support split-k a16w4 gemm1 (#3389)
* initial version to support moe gemm1 split-k
* add missing args
* fix build warning
* update reference
* for split-k disable bias and weight
* remove debug log
* fix format
* fix div by zero errors
* fix cmake config
* update
* resolve conflicts
* remove useless changes
* reformat
* fix
* remove useless changes
* fix ci
---------
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: root <root@smci355-ccs-aus-m01-25.cs-aus.dcgpu>
[ROCm/composable_kernel commit: dae85ead64]
This commit is contained in:
@@ -28,17 +28,18 @@ struct FlatmmProblem
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
template <int SharedGranularityMN, int SharedGranularityK = 0>
|
||||
template <int SharedGranularityMN, int SharedGranularityK = 0, typename ScaleType_ = float>
|
||||
struct FlatmmScalePointer
|
||||
{
|
||||
using ScaleType = ScaleType_;
|
||||
static constexpr int GranularityMN = SharedGranularityMN;
|
||||
static constexpr int GranularityK = SharedGranularityK;
|
||||
|
||||
const float* ptr;
|
||||
const ScaleType* ptr;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_)
|
||||
: ptr(ptr_)
|
||||
{
|
||||
}
|
||||
@@ -57,23 +58,24 @@ struct FlatmmScalePointer
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
|
||||
CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete;
|
||||
};
|
||||
|
||||
template <int SharedGranularityMN>
|
||||
struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
template <int SharedGranularityMN, typename ScaleType_>
|
||||
struct FlatmmScalePointer<SharedGranularityMN, 0, ScaleType_>
|
||||
{
|
||||
using ScaleType = ScaleType_;
|
||||
static constexpr int GranularityMN = SharedGranularityMN;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
static_assert(GranularityMN != 0);
|
||||
|
||||
const float* ptr;
|
||||
const ScaleType* ptr;
|
||||
index_t length;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_)
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, index_t length_)
|
||||
: ptr(ptr_), length(length_)
|
||||
{
|
||||
}
|
||||
@@ -94,7 +96,7 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const
|
||||
CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const
|
||||
{
|
||||
// with additional oob check
|
||||
if constexpr(GranularityMN == 1)
|
||||
@@ -105,23 +107,24 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
};
|
||||
|
||||
// shared granularityMN = -1 means no scale
|
||||
template <>
|
||||
struct FlatmmScalePointer<-1, 0>
|
||||
template <typename ScaleType_>
|
||||
struct FlatmmScalePointer<-1, 0, ScaleType_>
|
||||
{
|
||||
using ScaleType = ScaleType_;
|
||||
static constexpr int GranularityMN = -1;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
const float* ptr = nullptr;
|
||||
const ScaleType* ptr = nullptr;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*, index_t) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
|
||||
{
|
||||
return FlatmmScalePointer{};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
|
||||
CK_TILE_HOST_DEVICE constexpr ScaleType operator[](index_t) const
|
||||
{
|
||||
return 1; // always return 1, it doesn't change the result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user