mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
debug block v3
This commit is contained in:
@@ -174,7 +174,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
4, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
2, 1, S<1, 16, 1, 16>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, Nswizzle, true, A0DataType>;
|
||||
// kernel 2: 128->32x128x128
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
|
||||
@@ -197,8 +197,6 @@ int main(int argc, char* argv[])
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
ck::index_t tokens = 544;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
@@ -217,6 +215,17 @@ int main(int argc, char* argv[])
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else if(argc == 9) {
|
||||
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
sorted_tile_num = std::stoi(argv[7]);
|
||||
valid_tile_num = std::stoi(argv[8]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
@@ -227,6 +236,8 @@ int main(int argc, char* argv[])
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
if (tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
@@ -287,8 +298,8 @@ int main(int argc, char* argv[])
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
@@ -296,12 +307,21 @@ int main(int argc, char* argv[])
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
}
|
||||
// d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
// d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
// b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
|
||||
|
||||
@@ -2088,8 +2088,8 @@ struct GridwiseMoeGemm
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
|
||||
// if(threadIdx.x % 16 == 0 && blockIdx.x == 0)
|
||||
// printf("init off bid %d tid %d m %d off %d wei %f\n", blockIdx.x, threadIdx.x, m0(), token_offset, weight);
|
||||
if(threadIdx.x % 8 == 0 && blockIdx.x == 0)
|
||||
printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight);
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
});
|
||||
|
||||
@@ -447,8 +447,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
dst_offset,
|
||||
is_dst_valid,
|
||||
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
// if(threadIdx.x==0 && blockIdx.x==0) {
|
||||
// static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
|
||||
// if(threadIdx.x%8 ==0 && blockIdx.x==0) {
|
||||
// static_for<0, 1, 1>{}([&](auto idx) {
|
||||
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
|
||||
// using print_vec_t = typename vector_type<DstData, 1>::type;
|
||||
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset, is_dst_valid,
|
||||
@@ -683,8 +683,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// Immediately-invoked lambda: builds the step index that is actually applied to the
// destination coordinate. It is a component-wise copy of adjusted_step_idx, except that
// the component at ScatterDim is replaced by 0 when OutputScatter is true.
// NOTE(review): presumably the scatter dimension is positioned through separately
// computed scatter offsets, so it must not also be advanced here — confirm against caller.
auto adjusted_step_idx_scatter = [&]()
|
||||
{
|
||||
Index step_;
|
||||
// Visits each dimension index i as a compile-time value (i.value / Number<i> are used below).
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
// Zero the step only for the scatter dimension and only when scattering is enabled;
// every other component keeps its value from adjusted_step_idx.
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number<i>{}];
|
||||
});
|
||||
|
||||
return step_;
|
||||
}
|
||||
();
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter);
|
||||
|
||||
move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user