add printf info

This commit is contained in:
mtgu0705
2025-10-31 05:01:10 -05:00
parent 3d0f3abf65
commit 6769664197
2 changed files with 51 additions and 15 deletions

View File

@@ -177,7 +177,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t MPerBlock = 64;
static constexpr bool MulRoutedWeight = true;
// clang-format off
@@ -185,14 +185,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 64,
MPerBlock, 128, KPerBlock,
ScaleBlockSize, 128,
MPerBlock, 64, KPerBlock,
16, 16,
16, 16,
2, 8,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 8, 1, 8>, S<2, 1, 1, 1>,
4, 2,
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 8, 1, 16>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
@@ -210,10 +210,10 @@ int main(int argc, char* argv[])
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 256;
ck::index_t tokens = 208;
ck::index_t topk = 8;
ck::index_t K = 256;
ck::index_t experts = 8;
ck::index_t tokens = 4;
ck::index_t topk = 2;
if(argc == 1)
{
@@ -414,6 +414,24 @@ int main(int argc, char* argv[])
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize());
#if 1
printf("a0_t_k_k:\n");
for(int t=0;t<tokens;t++)
{
for(int tk=0;tk<topk;tk++)
{
for(int k = 0; k < K;)
{
printf("0x%08x ", *(reinterpret_cast<uint32_t*>(&a0_t_k_k(t,tk,k))));
k += 8;
}
printf("\n");
}
printf("\n");
}
#endif
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{

View File

@@ -626,7 +626,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
});
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
@@ -772,11 +774,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
block_sync_lds();
// constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0;
});
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
@@ -797,7 +800,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
});
});
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
constexpr auto im_major = m0 / MXdlPack;
@@ -943,6 +945,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
#if 1
printf("blkIdx: %u, blkIdy: %u, tidx: %u, im_minor: %d, in_minor: "
"%d, ik_minor: %d, a_thread_vec=<0x%08x, 0x%08x, 0x%08x, "
"0x%08x>\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
im_minor,
in_minor,
ik_minor,
*(reinterpret_cast<const uint32_t*>(&(a_thread_vec.template AsType<f4x8_t>()[Number<0>{}]))),
*(reinterpret_cast<const uint32_t*>(&(a_thread_vec.template AsType<f4x8_t>()[Number<1>{}]))),
*(reinterpret_cast<const uint32_t*>(&(a_thread_vec.template AsType<f4x8_t>()[Number<2>{}]))),
*(reinterpret_cast<const uint32_t*>(&(a_thread_vec.template AsType<f4x8_t>()[Number<3>{}]))));
#endif
});
});
});