fix performance bug of bpreshuffle f8 gemm

This commit is contained in:
aska-0096
2025-05-29 10:02:46 +00:00
parent c3d52993c4
commit d563dac424
2 changed files with 4 additions and 6 deletions

View File

@@ -139,13 +139,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// clang-format off
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
128, 128, 128,
256, 256, 128,
16, 16,
16, 16,
8, 2,
16, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
2, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// clang-format on

View File

@@ -168,9 +168,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
static constexpr auto is_scale_mfma = false;