debug block v3

This commit is contained in:
coderfeli
2025-02-20 07:05:48 +00:00
parent d5e395918e
commit b6fbe4f7fe
3 changed files with 40 additions and 10 deletions

View File

@@ -174,7 +174,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
4, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
2, 1, S<1, 16, 1, 16>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, Nswizzle, true, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
@@ -197,8 +197,6 @@ int main(int argc, char* argv[])
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 13;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t tokens = 544;
ck::index_t topk = 2;
@@ -217,6 +215,17 @@ int main(int argc, char* argv[])
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else if(argc == 9) {
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
sorted_tile_num = std::stoi(argv[7]);
valid_tile_num = std::stoi(argv[8]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
@@ -227,6 +236,8 @@ int main(int argc, char* argv[])
exit(0);
}
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
if (tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
@@ -287,8 +298,8 @@ int main(int argc, char* argv[])
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
@@ -296,12 +307,21 @@ int main(int argc, char* argv[])
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
// d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
// d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
// b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());

View File

@@ -2088,8 +2088,8 @@ struct GridwiseMoeGemm
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
// if(threadIdx.x % 16 == 0 && blockIdx.x == 0)
// printf("init off bid %d tid %d m %d off %d wei %f\n", blockIdx.x, threadIdx.x, m0(), token_offset, weight);
if(threadIdx.x % 8 == 0 && blockIdx.x == 0)
printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight);
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
});

View File

@@ -447,8 +447,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
dst_offset,
is_dst_valid,
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
// if(threadIdx.x==0 && blockIdx.x==0) {
// static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
// if(threadIdx.x%8 ==0 && blockIdx.x==0) {
// static_for<0, 1, 1>{}([&](auto idx) {
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
// using print_vec_t = typename vector_type<DstData, 1>::type;
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset, is_dst_valid,
@@ -683,8 +683,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
auto adjusted_step_idx_scatter = [&]()
{
Index step_;
static_for<0, nDim, 1>{}([&](auto i) {
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number<i>{}];
});
return step_;
}
();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter);
move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
}