fix format error

This commit is contained in:
Gino Lu
2025-09-01 01:23:39 -05:00
parent b422e41e08
commit d2892925e5
2 changed files with 76 additions and 9 deletions

View File

@@ -275,7 +275,7 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIter
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp4 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp4<WGAttrCtlEnum::Default_>,
WarpGemmAttributeMfma<WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl<

View File

@@ -1393,9 +1393,6 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, pk_fp4_t> && std::is_same_v<BDataType, pk_fp4_t>)
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, c_vec, 4, 0, 0, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
@@ -1419,9 +1416,6 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, pk_fp4_t> && std::is_same_v<BDataType, pk_fp4_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, CVecType{0.f}, 4, 0, 0, 0, 0, 0));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
@@ -1446,9 +1440,82 @@ template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 =
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, bf8_t, Ctrl_>;
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_base
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = AType_;
using BDataType = BType_;
using CDataType = float;
using AVecType = ext_vector_t<ADataType, 32>;
using BVecType = ext_vector_t<BDataType, 32>;
using CVecType = ext_vector_t<CDataType, 4>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 128;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 32;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <index_t opselA, index_t opselB, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const int32_t& a_scale,
const BVecType& b_vec,
const int32_t& b_scale,
bool_constant<post_nop_> = {}) const
{
//__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
// opsel, scale_b)
#if defined(__gfx950__)
if constexpr(std::is_same_v<ADataType, pk_fp4_t> && std::is_same_v<BDataType, pk_fp4_t>)
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, c_vec, 4, 4, opselA, a_scale, opselB, b_scale);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
// c_vec = a_vec * b_vec
template <index_t opselA, index_t opselB>
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
const int32_t& a_scale,
const BVecType& b_vec,
const int32_t& b_scale) const
{
#if defined(__gfx950__)
if constexpr(std::is_same_v<ADataType, pk_fp4_t> && std::is_same_v<BDataType, pk_fp4_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
a_vec, b_vec, CVecType{0.f}, 4, 4, opselA, a_scale, opselB, b_scale));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
return CVecType{0.f};
#endif
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp4 =
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<pk_fp4_t, pk_fp4_t, Ctrl_>;
using WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4 =
WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_base<pk_fp4_t, pk_fp4_t, Ctrl_>;
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base