moe fp8 blockscale use nt (#3524)

* nt on fp8 blockscale

* some improvements; tests need to be fixed

* update

* fix format

* revert useless change

* revert all changes in amd_buffer_coherence
This commit is contained in:
yadaish
2026-01-12 10:48:10 +08:00
committed by GitHub
parent 4216d43da8
commit 32408c8bc0
3 changed files with 63 additions and 34 deletions

View File

@@ -80,7 +80,8 @@ template <typename ALayout,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
- typename LDSTypeB = ComputeTypeB>
+ typename LDSTypeB = ComputeTypeB,
+ bool NonTemporalLoadB = false>
struct DeviceMoeGemmBlockScale
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
BLayout,
@@ -163,7 +164,8 @@ struct DeviceMoeGemmBlockScale
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
- LDSTypeB>;
+ LDSTypeB,
+ NonTemporalLoadB>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;

View File

@@ -173,7 +173,8 @@ template <typename ALayout,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType,
- typename LDSTypeB = BDataType>
+ typename LDSTypeB = BDataType,
+ bool NonTemporalLoadB = false>
struct GridwiseMoeGemmBlockScale
{
using AScaleType = float;
@@ -1202,6 +1203,13 @@ struct GridwiseMoeGemmBlockScale
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
+ #if defined(__gfx942__) || defined(__gfx950__)
+ constexpr auto b_coherence_flag = NonTemporalLoadB
+ ? AmdBufferCoherenceEnum::WAVE_NT1
+ : AmdBufferCoherenceEnum::DefaultCoherence;
+ #else
+ constexpr auto b_coherence_flag = AmdBufferCoherenceEnum::DefaultCoherence;
+ #endif
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
@@ -1300,15 +1308,16 @@ struct GridwiseMoeGemmBlockScale
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
- const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+ const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
- const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
- p_b_scale_grid + expert_id * expert_scale_stride,
- b_scale_grid_desc_bn_ak.GetElementSpaceSize());
+ const auto b_scale_grid_buf =
+ make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
+ p_b_scale_grid + expert_id * expert_scale_stride,
+ b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
@@ -1465,9 +1474,11 @@ struct GridwiseMoeGemmBlockScale
if constexpr(IsInputGemm && !IsSplitK)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
- const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
- p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
- b_grid_desc_bpreshuffled.GetElementSpaceSize());
+ const auto b_grid_buf_up =
+ make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
+ p_b_grid_up +
+ expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
+ b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
@@ -1485,9 +1496,10 @@ struct GridwiseMoeGemmBlockScale
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
- const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
- p_b_scale_grid_up + expert_id * expert_scale_stride,
- b_scale_grid_desc_bn_ak.GetElementSpaceSize());
+ const auto b_scale_grid_buf_up =
+ make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
+ p_b_scale_grid_up + expert_id * expert_scale_stride,
+ b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
@@ -1958,6 +1970,13 @@ struct GridwiseMoeGemmBlockScale
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
+ #if defined(__gfx942__) || defined(__gfx950__)
+ constexpr auto b_coherence_flag = NonTemporalLoadB
+ ? AmdBufferCoherenceEnum::WAVE_NT1
+ : AmdBufferCoherenceEnum::DefaultCoherence;
+ #else
+ constexpr auto b_coherence_flag = AmdBufferCoherenceEnum::DefaultCoherence;
+ #endif
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
@@ -2054,15 +2073,16 @@ struct GridwiseMoeGemmBlockScale
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
- const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+ const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
- const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
- p_b_scale_grid + expert_id * expert_scale_stride,
- b_scale_grid_desc_bn_ak.GetElementSpaceSize());
+ const auto b_scale_grid_buf =
+ make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
+ p_b_scale_grid + expert_id * expert_scale_stride,
+ b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
@@ -2227,9 +2247,11 @@ struct GridwiseMoeGemmBlockScale
if constexpr(IsInputGemm && !IsSplitK)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
- const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
- p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
- b_grid_desc_bpreshuffled.GetElementSpaceSize());
+ const auto b_grid_buf_up =
+ make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
+ p_b_grid_up +
+ expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
+ b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
@@ -2247,9 +2269,10 @@ struct GridwiseMoeGemmBlockScale
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
- const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
- p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
- b_scale_grid_desc_bn_ak.GetElementSpaceSize());
+ const auto b_scale_grid_buf_up =
+ make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
+ p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
+ b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,