mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
16x16 function merged to moe
This commit is contained in:
@@ -197,96 +197,28 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
|
||||
|
||||
constexpr auto num_pk_fma_per_kscaleblock = MPerXDL == 16 ? 2 : 8;
|
||||
constexpr auto num_mfma_per_kscaleblock =
|
||||
MPerXDL == 16 ? KScaleBlock / 32 : KScaleBlock / 16;
|
||||
#if 0
|
||||
// B global
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
/* Decide whether to issue v_pk_fma */
|
||||
if constexpr((i + 1) % num_mfma_per_kscaleblock == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A global
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr((num_buffer_load_inst_b + 2 * i + 1) % num_mfma_per_kscaleblock == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr((num_buffer_load_inst_b + 2 * i + 2) % num_mfma_per_kscaleblock == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A local
|
||||
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr((num_buffer_load_inst_b + 2 * num_buffer_load_inst_a + i + 1) %
|
||||
num_mfma_per_kscaleblock ==
|
||||
0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
|
||||
});
|
||||
#elif 1 // v_mul occurred too early causing vmcnt stall
|
||||
// B global
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
/* Decide whether to issue v_pk_fma */
|
||||
if constexpr((i + 1) % num_mfma_per_kscaleblock == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
|
||||
}
|
||||
});
|
||||
|
||||
// A global
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
if constexpr((num_buffer_load_inst_b + 2 * i + 1) % num_mfma_per_kscaleblock == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
if constexpr((num_buffer_load_inst_b + 2 * i + 2) % num_mfma_per_kscaleblock == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
|
||||
}
|
||||
});
|
||||
|
||||
// A local
|
||||
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
|
||||
if constexpr((num_buffer_load_inst_b + 2 * num_buffer_load_inst_a + i + 1) %
|
||||
num_mfma_per_kscaleblock ==
|
||||
0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x800, num_pk_fma_per_kscaleblock, 0); // PK_FMA
|
||||
}
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
@@ -426,12 +358,6 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
});
|
||||
});
|
||||
|
||||
// printf("blockIdx.y = %d, blockIdx.x = %d, threadIdx.x = %d, c_scale_thread_buf = <%f>\n",
|
||||
// blockIdx.y,
|
||||
// blockIdx.x,
|
||||
// threadIdx.x,
|
||||
// c_scale_thread_buf[Number<0>{}]);
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
@@ -583,6 +509,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
@@ -600,19 +542,6 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
@@ -638,8 +567,6 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
b_scale_thread_copy_step);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -401,16 +401,16 @@ struct DeviceMoeGemmBlockScale
|
||||
};
|
||||
#endif
|
||||
|
||||
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
|
||||
4 * (1 + GridwiseGemm::NWave);
|
||||
constexpr auto estimated_reg_b =
|
||||
NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2);
|
||||
constexpr auto estimated_reg_c =
|
||||
MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4;
|
||||
constexpr auto estimated_reg_total =
|
||||
estimated_reg_a + estimated_reg_b + estimated_reg_c;
|
||||
// constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
|
||||
// 4 * (1 + GridwiseGemm::NWave);
|
||||
// constexpr auto estimated_reg_b =
|
||||
// NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2);
|
||||
// constexpr auto estimated_reg_c =
|
||||
// MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4;
|
||||
// constexpr auto estimated_reg_total =
|
||||
// estimated_reg_a + estimated_reg_b + estimated_reg_c;
|
||||
|
||||
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
|
||||
constexpr index_t minimum_occupancy = 2;
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
@@ -704,7 +704,7 @@ struct DeviceMoeGemmBlockScale
|
||||
index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
index_t KBatch,
|
||||
// index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
@@ -727,7 +727,7 @@ struct DeviceMoeGemmBlockScale
|
||||
StrideC,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
KBatch,
|
||||
1, //KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
@@ -749,7 +749,9 @@ struct DeviceMoeGemmBlockScale
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}};
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceMoeGEmm"
|
||||
|
||||
@@ -187,9 +187,10 @@ struct GridwiseMoeGemmBlockScale
|
||||
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
|
||||
static constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
|
||||
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
static constexpr index_t KLane =
|
||||
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
|
||||
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
|
||||
static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
|
||||
static constexpr index_t NLane = NPerXdl;
|
||||
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
|
||||
// static constexpr index_t NumTokens = 1;
|
||||
@@ -249,7 +250,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
}
|
||||
__host__ __device__ static auto CalculateBK0Shuffled(index_t K)
|
||||
{
|
||||
return math::integer_divide_ceil(K, KLane * KPack);
|
||||
return math::integer_divide_ceil(K, KLane * KPack / KGroup);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K)
|
||||
@@ -391,7 +392,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
|
||||
{
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
|
||||
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
|
||||
@@ -1334,7 +1335,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -1946,7 +1947,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -2076,6 +2077,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
@@ -2371,6 +2373,45 @@ struct GridwiseMoeGemmBlockScale
|
||||
I0,
|
||||
cde_lds_and_global_step);
|
||||
}
|
||||
|
||||
// // print C
|
||||
// printf("tid: %d, blkid: %d, "
|
||||
// "c_thread_buf = <%1.f, %1.f, %1.f>\n "
|
||||
// // "%1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f,"
|
||||
// // "%1.f, %1.f, %1.f, %1.f, %1.f, %1.f\n"
|
||||
// , get_thread_local_1d_id(), block_m_id,
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<0>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<1>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<2>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<3>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<4>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<5>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<6>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<7>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<8>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<9>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<10>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<11>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<12>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<13>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<14>{}],
|
||||
// c_thread_buf.GetVectorTypeReference(Number<0>{}) .template
|
||||
// AsType<AccDataType>()[Number<3>{}]);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user