mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
add printf info
This commit is contained in:
@@ -177,7 +177,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t MPerBlock = 64;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
@@ -185,14 +185,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 64,
|
||||
MPerBlock, 128, KPerBlock,
|
||||
ScaleBlockSize, 128,
|
||||
MPerBlock, 64, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
2, 8,
|
||||
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 8, 1, 8>, S<2, 1, 1, 1>,
|
||||
4, 2,
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 8, 1, 16>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
@@ -210,10 +210,10 @@ int main(int argc, char* argv[])
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 256;
|
||||
ck::index_t tokens = 208;
|
||||
ck::index_t topk = 8;
|
||||
ck::index_t K = 256;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 4;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
@@ -414,6 +414,24 @@ int main(int argc, char* argv[])
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize());
|
||||
|
||||
#if 1
|
||||
printf("a0_t_k_k:\n");
|
||||
|
||||
for(int t=0;t<tokens;t++)
|
||||
{
|
||||
for(int tk=0;tk<topk;tk++)
|
||||
{
|
||||
for(int k = 0; k < K;)
|
||||
{
|
||||
printf("0x%08x ", *(reinterpret_cast<uint32_t*>(&a0_t_k_k(t,tk,k))));
|
||||
k += 8;
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
|
||||
@@ -626,7 +626,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
|
||||
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
|
||||
});
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
@@ -772,11 +774,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
|
||||
block_sync_lds();
|
||||
|
||||
// constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0;
|
||||
});
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
@@ -797,7 +800,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
|
||||
});
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
constexpr auto im_major = m0 / MXdlPack;
|
||||
@@ -943,6 +945,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
|
||||
#if 1
|
||||
printf("blkIdx: %u, blkIdy: %u, tidx: %u, im_minor: %d, in_minor: "
|
||||
"%d, ik_minor: %d, a_thread_vec=<0x%08x, 0x%08x, 0x%08x, "
|
||||
"0x%08x>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
im_minor,
|
||||
in_minor,
|
||||
ik_minor,
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_thread_vec.template AsType<f4x8_t>()[Number<0>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_thread_vec.template AsType<f4x8_t>()[Number<1>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_thread_vec.template AsType<f4x8_t>()[Number<2>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(&(a_thread_vec.template AsType<f4x8_t>()[Number<3>{}]))));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user