This commit is contained in:
mtgu0705
2025-04-21 10:51:37 +08:00
parent 066f209640
commit 2f6529dcc2
3 changed files with 93 additions and 17 deletions

View File

@@ -149,19 +149,19 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
#else
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 32;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
MPerBlock, 128, 128,
32, 128, 128,
16, 16,
32, 32,
2, 2,
1, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 1, S<1, 8, 1, 32>, S<2, 1, 1, 1>,
1, 1, S<1, 8, 1, 32>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
#endif
// clang-format on
@@ -180,11 +180,11 @@ int main(int argc, char* argv[])
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 19;
ck::index_t valid_tile_num = 16;
ck::index_t valid_tile_num = 2;
ck::index_t sorted_tile_num = valid_tile_num + 3;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t tokens = 832;
ck::index_t tokens = 1;
ck::index_t topk = 2;
if(argc == 1)
@@ -232,8 +232,9 @@ int main(int argc, char* argv[])
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
int eids[] = {0, 1, 3, 3, 3};
// int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
@@ -269,7 +270,7 @@ int main(int argc, char* argv[])
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<B1DataType> b1_e_n_k(HostTensorDescriptor(
{experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N},
{(Scale_Stride_B * Scale_Stride_BN), Scale_Stride_BN, 1}));
{(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
@@ -322,12 +323,12 @@ int main(int argc, char* argv[])
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
// a0_t_k_k.savetxt("a.txt");
// expert_ids.savetxt("expert_ids.txt", "int");
// sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
a0_t_k_k.savetxt("a.txt");
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
// d0_t_n.savetxt("d0_t_n.txt", "int");
// d1_e_n.savetxt("d1_e_n.txt", "int");
// d2_e_n.savetxt("d2_e_n.txt", "int");
d2_e_n.savetxt("d2_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
@@ -381,6 +382,71 @@ int main(int argc, char* argv[])
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
#if 1
// printf the input tensor
// printf a tensor
printf("a0_t_k_k: \n");
for(int t = 0; t < tokens; ++t)
{
for(int tk = 0; tk < topk; ++tk)
{
printf("topk: %d: ", tk);
for(int k = 0; k < K; ++k)
{
printf("%f ", ck::type_convert<float>(a0_t_k_k(t, tk, k)));
}
printf("\n");
}
}
// printf a scale tensor
printf("a1_t_k_k: \n");
for(int t = 0; t < tokens; ++t)
{
for(int tk = 0; tk < topk; ++tk)
{
printf("topk: %d: ", tk);
for(int k = 0; k < (K + Scale_Block_K - 1) / Scale_Block_K; ++k)
{
printf("%f ", ck::type_convert<float>(a1_t_k_k(t, tk, k)));
}
printf("\n");
}
}
// printf b tensor
// printf("b0_e_n_k: \n");
// for (int e=0; e < experts; ++e)
// {
// for (int k=0; k < K; ++k)
// {
// printf("expert: %d: ", e);
// for (int n=0; n < N; ++n)
// {
// printf("%f ", ck::type_convert<float>(b0_e_n_k(e, k, n)));
// }
// printf("\n");
// }
// }
// printf b scale tensor
printf("b1_e_n_k: \n");
for(int e = 0; e < experts; ++e)
{
for(int k = 0; k < (K + Scale_Block_K - 1) / Scale_Block_K; ++k)
{
printf("expert: %d: ", e);
for(int n = 0; n < (N + Scale_Block_N - 1) / Scale_Block_N; ++n)
{
printf("%f ", ck::type_convert<float>(b1_e_n_k(e, k, n)));
}
printf("\n");
}
}
#endif
if(time_kernel)
{
// not result correct here because output buf not setzero

View File

@@ -1107,7 +1107,7 @@ struct GridwiseMoeGemmBlockScale
}
// check gridwise gemm pipeline
#if 1
#if 0
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
@@ -1373,14 +1373,15 @@ struct GridwiseMoeGemmBlockScale
// get each thread's offset in the scale tensor
// A scale
const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockK + a_thread_offset;
const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_scale_pos + m0];
index_t token_offset = fused_token & 0xffffff;
const index_t fused_token =
p_sorted_token_ids[token_scale_pos + m0 * MPerXdl + a_thread_offset];
index_t token_offset = fused_token & 0xffffff;
if constexpr(!IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
@@ -1389,6 +1390,10 @@ struct GridwiseMoeGemmBlockScale
token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK);
});
// printf("blkid: %d, tid:%d, a_thread_offset: %d, scale_gather_offsets: %d\n", block_m_id,
// threadIdx.x, a_thread_offset,
// scale_gather_offsets(Number<0>{}));
auto a_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2_gather<AScaleType,
AScaleType,

View File

@@ -580,6 +580,11 @@ struct ThreadwiseTensorSliceTransfer_v2_gather
});
});
// printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n",
// blockIdx.y,
// threadIdx.x,
// dst_buf(Number<0>{}));
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{