mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
updated
This commit is contained in:
@@ -149,19 +149,19 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
MPerBlock, 128, 128,
|
||||
32, 128, 128,
|
||||
16, 16,
|
||||
32, 32,
|
||||
2, 2,
|
||||
1, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 1, S<1, 8, 1, 32>, S<2, 1, 1, 1>,
|
||||
1, 1, S<1, 8, 1, 32>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
|
||||
#endif
|
||||
// clang-format on
|
||||
@@ -180,11 +180,11 @@ int main(int argc, char* argv[])
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 19;
|
||||
ck::index_t valid_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 2;
|
||||
ck::index_t sorted_tile_num = valid_tile_num + 3;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t tokens = 1;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
@@ -232,8 +232,9 @@ int main(int argc, char* argv[])
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
|
||||
|
||||
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
|
||||
int eids[] = {0, 1, 3, 3, 3};
|
||||
// int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
|
||||
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
@@ -269,7 +270,7 @@ int main(int argc, char* argv[])
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B1DataType> b1_e_n_k(HostTensorDescriptor(
|
||||
{experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N},
|
||||
{(Scale_Stride_B * Scale_Stride_BN), Scale_Stride_BN, 1}));
|
||||
{(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
@@ -322,12 +323,12 @@ int main(int argc, char* argv[])
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
// a0_t_k_k.savetxt("a.txt");
|
||||
// expert_ids.savetxt("expert_ids.txt", "int");
|
||||
// sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
a0_t_k_k.savetxt("a.txt");
|
||||
expert_ids.savetxt("expert_ids.txt", "int");
|
||||
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
// d0_t_n.savetxt("d0_t_n.txt", "int");
|
||||
// d1_e_n.savetxt("d1_e_n.txt", "int");
|
||||
// d2_e_n.savetxt("d2_e_n.txt", "int");
|
||||
d2_e_n.savetxt("d2_e_n.txt", "int");
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
@@ -381,6 +382,71 @@ int main(int argc, char* argv[])
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
#if 1
|
||||
// printf the input tensor
|
||||
// printf a tensor
|
||||
printf("a0_t_k_k: \n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int tk = 0; tk < topk; ++tk)
|
||||
{
|
||||
printf("topk: %d: ", tk);
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(a0_t_k_k(t, tk, k)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// printf a scale tensor
|
||||
printf("a1_t_k_k: \n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int tk = 0; tk < topk; ++tk)
|
||||
{
|
||||
printf("topk: %d: ", tk);
|
||||
for(int k = 0; k < (K + Scale_Block_K - 1) / Scale_Block_K; ++k)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(a1_t_k_k(t, tk, k)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// printf b tensor
|
||||
// printf("b0_e_n_k: \n");
|
||||
// for (int e=0; e < experts; ++e)
|
||||
// {
|
||||
// for (int k=0; k < K; ++k)
|
||||
// {
|
||||
// printf("expert: %d: ", e);
|
||||
// for (int n=0; n < N; ++n)
|
||||
// {
|
||||
// printf("%f ", ck::type_convert<float>(b0_e_n_k(e, k, n)));
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
// }
|
||||
|
||||
// printf b scale tensor
|
||||
printf("b1_e_n_k: \n");
|
||||
for(int e = 0; e < experts; ++e)
|
||||
{
|
||||
for(int k = 0; k < (K + Scale_Block_K - 1) / Scale_Block_K; ++k)
|
||||
{
|
||||
printf("expert: %d: ", e);
|
||||
for(int n = 0; n < (N + Scale_Block_N - 1) / Scale_Block_N; ++n)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(b1_e_n_k(e, k, n)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
// not result correct here because output buf not setzero
|
||||
|
||||
@@ -1107,7 +1107,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
}
|
||||
|
||||
// check gridwise gemm pipeline
|
||||
#if 1
|
||||
#if 0
|
||||
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
|
||||
|
||||
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
|
||||
@@ -1373,14 +1373,15 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
// get each thread's offset in the scale tensor
|
||||
// A scale
|
||||
const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockK + a_thread_offset;
|
||||
const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
|
||||
|
||||
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[token_scale_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
const index_t fused_token =
|
||||
p_sorted_token_ids[token_scale_pos + m0 * MPerXdl + a_thread_offset];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
if constexpr(!IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
@@ -1389,6 +1390,10 @@ struct GridwiseMoeGemmBlockScale
|
||||
token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK);
|
||||
});
|
||||
|
||||
// printf("blkid: %d, tid:%d, a_thread_offset: %d, scale_gather_offsets: %d\n", block_m_id,
|
||||
// threadIdx.x, a_thread_offset,
|
||||
// scale_gather_offsets(Number<0>{}));
|
||||
|
||||
auto a_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2_gather<AScaleType,
|
||||
AScaleType,
|
||||
|
||||
@@ -580,6 +580,11 @@ struct ThreadwiseTensorSliceTransfer_v2_gather
|
||||
});
|
||||
});
|
||||
|
||||
// printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n",
|
||||
// blockIdx.y,
|
||||
// threadIdx.x,
|
||||
// dst_buf(Number<0>{}));
|
||||
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user