mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
moe fp8 blockscale: use non-temporal (nt) loads for B (#3524)
* use non-temporal (nt) loads for B on fp8 blockscale
* some improvements; tests need to be fixed
* update
* fix format
* revert useless change
* revert any change in amd_buffer_coherence
[ROCm/composable_kernel commit: 32408c8bc0]
This commit is contained in:
@@ -119,7 +119,7 @@ static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_an
|
||||
static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight.
|
||||
|
||||
#if 1
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t MPerBlock = 64;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
|
||||
@@ -156,7 +156,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
|
||||
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, true>;
|
||||
#else
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
@@ -171,7 +172,8 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
|
||||
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, false>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
@@ -182,12 +184,14 @@ int main(int argc, char* argv[])
|
||||
bool time_kernel = true;
|
||||
#if 1
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 6144;
|
||||
ck::index_t N = 1536;
|
||||
ck::index_t K = 4096;
|
||||
// ck::index_t N = 4096;
|
||||
// ck::index_t K = 6144;
|
||||
// ck::index_t N = 128;
|
||||
// ck::index_t K = 512;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t topk = 2;
|
||||
ck::index_t experts = 16;
|
||||
ck::index_t topk = 8;
|
||||
// ck::index_t sorted_tile_num = 515;
|
||||
// ck::index_t valid_tile_num = 512;
|
||||
// ck::index_t tokens = 208;
|
||||
@@ -196,9 +200,9 @@ int main(int argc, char* argv[])
|
||||
// ck::index_t sorted_tile_num = 259;
|
||||
// ck::index_t valid_tile_num = 256;
|
||||
// ck::index_t tokens = 4096;
|
||||
ck::index_t sorted_tile_num = 2;
|
||||
ck::index_t valid_tile_num = 2;
|
||||
ck::index_t tokens = 32;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 16;
|
||||
ck::index_t tokens = 4;
|
||||
#else
|
||||
// deepseek
|
||||
ck::index_t N = 2048;
|
||||
@@ -209,7 +213,7 @@ int main(int argc, char* argv[])
|
||||
ck::index_t sorted_tile_num = 261;
|
||||
ck::index_t valid_tile_num = 256;
|
||||
#endif
|
||||
ck::index_t KBatch = 6;
|
||||
ck::index_t KBatch = 1;
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
|
||||
@@ -80,7 +80,8 @@ template <typename ALayout,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ComputeTypeA,
|
||||
typename LDSTypeB = ComputeTypeB>
|
||||
typename LDSTypeB = ComputeTypeB,
|
||||
bool NonTemporalLoadB = false>
|
||||
struct DeviceMoeGemmBlockScale
|
||||
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
|
||||
BLayout,
|
||||
@@ -163,7 +164,8 @@ struct DeviceMoeGemmBlockScale
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
LDSTypeA,
|
||||
LDSTypeB>;
|
||||
LDSTypeB,
|
||||
NonTemporalLoadB>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
|
||||
@@ -173,7 +173,8 @@ template <typename ALayout,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ADataType,
|
||||
typename LDSTypeB = BDataType>
|
||||
typename LDSTypeB = BDataType,
|
||||
bool NonTemporalLoadB = false>
|
||||
struct GridwiseMoeGemmBlockScale
|
||||
{
|
||||
using AScaleType = float;
|
||||
@@ -1202,6 +1203,13 @@ struct GridwiseMoeGemmBlockScale
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
constexpr auto b_coherence_flag = NonTemporalLoadB
|
||||
? AmdBufferCoherenceEnum::WAVE_NT1
|
||||
: AmdBufferCoherenceEnum::DefaultCoherence;
|
||||
#else
|
||||
constexpr auto b_coherence_flag = AmdBufferCoherenceEnum::DefaultCoherence;
|
||||
#endif
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
@@ -1300,15 +1308,16 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid + expert_id * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_scale_grid + expert_id * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
|
||||
@@ -1465,9 +1474,11 @@ struct GridwiseMoeGemmBlockScale
|
||||
if constexpr(IsInputGemm && !IsSplitK)
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
const auto b_grid_buf_up =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_grid_up +
|
||||
expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
BDataType,
|
||||
BDataType,
|
||||
@@ -1485,9 +1496,10 @@ struct GridwiseMoeGemmBlockScale
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf_up =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
auto b_scale_thread_copy_up =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleType,
|
||||
BScaleType,
|
||||
@@ -1958,6 +1970,13 @@ struct GridwiseMoeGemmBlockScale
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
constexpr auto b_coherence_flag = NonTemporalLoadB
|
||||
? AmdBufferCoherenceEnum::WAVE_NT1
|
||||
: AmdBufferCoherenceEnum::DefaultCoherence;
|
||||
#else
|
||||
constexpr auto b_coherence_flag = AmdBufferCoherenceEnum::DefaultCoherence;
|
||||
#endif
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
@@ -2054,15 +2073,16 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_grid + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid + expert_id * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_scale_grid + expert_id * expert_scale_stride,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
|
||||
@@ -2227,9 +2247,11 @@ struct GridwiseMoeGemmBlockScale
|
||||
if constexpr(IsInputGemm && !IsSplitK)
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
const auto b_grid_buf_up =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_grid_up +
|
||||
expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
BDataType,
|
||||
BDataType,
|
||||
@@ -2247,9 +2269,10 @@ struct GridwiseMoeGemmBlockScale
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf_up =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, b_coherence_flag>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
auto b_scale_thread_copy_up =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleType,
|
||||
BScaleType,
|
||||
|
||||
Reference in New Issue
Block a user