mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-22 08:07:38 +00:00
fix vec size error
This commit is contained in:
@@ -1440,16 +1440,16 @@ template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, bf8_t, Ctrl_>;
|
||||
|
||||
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_base
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = AType_;
|
||||
using BDataType = BType_;
|
||||
using ADataType = pk_fp4_t;
|
||||
using BDataType = pk_fp4_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 32>;
|
||||
using BVecType = ext_vector_t<BDataType, 32>;
|
||||
using AVecType = ext_vector_t<ADataType, 16>;
|
||||
using BVecType = ext_vector_t<BDataType, 16>;
|
||||
using CVecType = ext_vector_t<CDataType, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
@@ -1482,9 +1482,8 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_base
|
||||
//__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
|
||||
// opsel, scale_b)
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(std::is_same_v<ADataType, pk_fp4_t> && std::is_same_v<BDataType, pk_fp4_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 4, 4, opselA, a_scale, opselB, b_scale);
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 4, 4, opselA, a_scale, opselB, b_scale);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
@@ -1512,11 +1511,6 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_base
|
||||
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4 =
|
||||
WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_base<pk_fp4_t, pk_fp4_t, Ctrl_>;
|
||||
|
||||
|
||||
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user