mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
grouped topk debug
This commit is contained in:
@@ -441,8 +441,8 @@ struct BlockGemmSoftmaxGroupedTopkPipelineAGmemBGmemCReg
|
||||
auto p_compute =
|
||||
make_static_distributed_tensor<ComputeDataType>(c_block_tile.get_tile_distribution());
|
||||
|
||||
auto debug_block_tile =
|
||||
make_static_distributed_tensor<WeightType>(p_compute.get_tile_distribution());
|
||||
// auto debug_block_tile =
|
||||
// make_static_distributed_tensor<WeightType>(p_compute.get_tile_distribution());
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
|
||||
@@ -476,88 +476,136 @@ struct BlockGemmSoftmaxGroupedTopkPipelineAGmemBGmemCReg
|
||||
|
||||
// apply topk for softmax output
|
||||
auto x_tmp = p_compute;
|
||||
|
||||
// // // initialize x_tmp for topk debug
|
||||
// // printf("===============on device debug input=====================\n");
|
||||
// // std::mt19937 rng(123);
|
||||
// // std::uniform_int_distribution<int> dist_debug_input(1, 100);
|
||||
|
||||
// constexpr auto x_tmp_spans = decltype(x_tmp)::get_distributed_spans();
|
||||
// sweep_tile_span(x_tmp_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(x_tmp_spans[number<1>{}], [&](auto idx1) {
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// x_tmp.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
// auto row_id = tile_idx.at(number<0>{});
|
||||
// auto col_id = tile_idx.at(number<1>{});
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// x_tmp(i_j_idx) = sin(float(row_id + col_id)) * 100;
|
||||
// // x_tmp(i_j_idx) = sin(float(row_id)) * cos(float(col_id));
|
||||
// // x_tmp(i_j_idx) = float(dist_debug_input(rng));
|
||||
// });
|
||||
// });
|
||||
|
||||
// argmax for topk
|
||||
const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
|
||||
return e0.value > e1.value ? e0 : e1;
|
||||
};
|
||||
|
||||
for(index_t i_k = 0; i_k < topk; i_k++)
|
||||
{
|
||||
constexpr auto p_compute_spans = decltype(p_compute)::get_distributed_spans();
|
||||
auto packet = [&]() {
|
||||
auto tmp = make_static_distributed_tensor<ArgmaxPacket>(p_compute.get_tile_distribution());
|
||||
sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
tmp.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
ArgmaxPacket t;
|
||||
t.value = x_tmp(i_j_idx); // !!! we reference p_compute here
|
||||
t.arg = tile_idx.at(number<1>{});
|
||||
tmp(i_j_idx) = t;
|
||||
});
|
||||
// calculate group score, need to creat group scores tensor
|
||||
int num_expert_group = 16;
|
||||
// int topk_group = 2;
|
||||
int expert_per_group = kNPerBlock / num_expert_group;
|
||||
constexpr auto p_compute_spans = decltype(p_compute)::get_distributed_spans();
|
||||
auto group_scores = x_tmp;
|
||||
// init group_scores to inf
|
||||
sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
group_scores(i_j_idx) = -numeric<WeightType>::infinity();
|
||||
});
|
||||
});
|
||||
for (index_t n_group = 0; n_group < num_expert_group; n_group++) {
|
||||
// get group value matrix (masked other groups)
|
||||
auto group_tmp = x_tmp;
|
||||
sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
group_tmp.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
auto col_id = tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
group_tmp(i_j_idx) = ((col_id >= (n_group * expert_per_group)) && (col_id < ((n_group + 1) * expert_per_group))) ? x_tmp(i_j_idx) : -numeric<WeightType>::infinity();
|
||||
});
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
auto argmax_init = ArgmaxPacket{-numeric<WeightType>::infinity(), 0};
|
||||
auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
|
||||
|
||||
block_tile_reduce_xor_sync(r, f_argmax);
|
||||
|
||||
// constexpr auto value_spans = decltype(value_block_tile)::get_distributed_spans();
|
||||
});
|
||||
// get one column for group scores = rowmax(group_tmp)
|
||||
auto group_scores_col = block_tile_reduce<ComputeDataType>(
|
||||
group_tmp, sequence<1>{}, f_max, std::numeric_limits<ComputeDataType>::lowest());
|
||||
|
||||
block_tile_reduce_sync(group_scores_col, f_max);
|
||||
// get all group scores
|
||||
sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
p_compute.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
// auto row_id = tile_idx.at(number<0>{});
|
||||
auto col_id = tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
ArgmaxPacket tmp = r(i_idx);
|
||||
// debug_block_tile(i_j_idx) = (col_id == i_k) ? tmp.value: debug_block_tile(i_j_idx);
|
||||
debug_block_tile(i_j_idx) = (col_id == i_k) ? tmp.arg: debug_block_tile(i_j_idx);
|
||||
// value_block_tile(i_j_idx) = tmp.value;
|
||||
// index_block_tile(i_j_idx) = tmp.arg;
|
||||
});
|
||||
});
|
||||
|
||||
// update value
|
||||
sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
p_compute.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
auto col_id = tile_idx.at(number<1>{});
|
||||
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
x_tmp(i_j_idx) = (col_id == r(i_j_idx).arg) ? -numeric<WeightType>::infinity()
|
||||
: x_tmp(i_j_idx);
|
||||
group_scores(i_j_idx) = (col_id == n_group) ? group_scores_col(i_idx): group_scores(i_j_idx);
|
||||
});
|
||||
});
|
||||
}
|
||||
return debug_block_tile;
|
||||
|
||||
// // another scheme to reshape x_tmp from 2d to 3d
|
||||
// const auto x_tmp_3d = x_tmp.block_tile_reshape((kM, num_expert_group, kN/num_expert_group));
|
||||
// auto group_scores = block_tile_reduce<ComputeDataType>(
|
||||
// x_tmp_3d, sequence<2>{}, f_max, std::numeric_limits<ComputeDataType>::lowest());
|
||||
// block_tile_reduce_sync(group_scores, f_max);
|
||||
|
||||
// // Step2: select group values and group_indices
|
||||
// // argmax for topk
|
||||
// const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
|
||||
// return e0.value > e1.value ? e0 : e1;
|
||||
// };
|
||||
// auto group_packet = topk(group_scores, topk_group)
|
||||
|
||||
// // Step3: mask score matrix
|
||||
// sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) {
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// p_compute.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
// auto col_id = tile_idx.at(number<1>{});
|
||||
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
// x_tmp(i_j_idx) = (col_id != group_packet(i_j_idx).arg) ? -numeric<WeightType>::infinity()
|
||||
// : x_tmp(i_j_idx);
|
||||
// });
|
||||
// });
|
||||
|
||||
// // Step4: select topk values from masked scores
|
||||
// for(index_t i_k = 0; i_k < topk; i_k++)
|
||||
// {
|
||||
// constexpr auto p_compute_spans = decltype(p_compute)::get_distributed_spans();
|
||||
// auto packet = [&]() {
|
||||
// auto tmp = make_static_distributed_tensor<ArgmaxPacket>(p_compute.get_tile_distribution());
|
||||
// sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) {
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// tmp.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// ArgmaxPacket t;
|
||||
// t.value = x_tmp(i_j_idx); // !!! we reference p_compute here
|
||||
// t.arg = tile_idx.at(number<1>{});
|
||||
// tmp(i_j_idx) = t;
|
||||
// });
|
||||
// });
|
||||
// return tmp;
|
||||
// }();
|
||||
|
||||
// auto argmax_init = ArgmaxPacket{-numeric<WeightType>::infinity(), 0};
|
||||
// auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
|
||||
|
||||
// block_tile_reduce_xor_sync(r, f_argmax);
|
||||
|
||||
// // constexpr auto value_spans = decltype(value_block_tile)::get_distributed_spans();
|
||||
|
||||
// sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) {
|
||||
// constexpr auto i_idx = make_tuple(idx0);
|
||||
// sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) {
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// p_compute.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
// // auto row_id = tile_idx.at(number<0>{});
|
||||
// auto col_id = tile_idx.at(number<1>{});
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// ArgmaxPacket tmp = r(i_idx);
|
||||
// debug_block_tile(i_j_idx) = (col_id == i_k) ? tmp.value: debug_block_tile(i_j_idx);
|
||||
// // debug_block_tile(i_j_idx) = (col_id == i_k) ? tmp.arg: debug_block_tile(i_j_idx);
|
||||
// // value_block_tile(i_j_idx) = tmp.value;
|
||||
// // index_block_tile(i_j_idx) = tmp.arg;
|
||||
// });
|
||||
// });
|
||||
|
||||
// // update value
|
||||
// sweep_tile_span(p_compute_spans[number<0>{}], [&](auto idx0) {
|
||||
// constexpr auto i_idx = make_tuple(idx0);
|
||||
// sweep_tile_span(p_compute_spans[number<1>{}], [&](auto idx1) {
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// p_compute.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
// auto col_id = tile_idx.at(number<1>{});
|
||||
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
// x_tmp(i_j_idx) = (col_id == r(i_idx).arg) ? -numeric<WeightType>::infinity()
|
||||
// : x_tmp(i_j_idx);
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
// return debug_block_tile;
|
||||
return group_scores;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -286,9 +286,9 @@ int main(int argc, char* argv[])
|
||||
|
||||
// reference_topk(debug_host_input, value_ref, index_ref, topk);
|
||||
// debug_ref = reference_basic_gemm<ADataType, ADataType, AccDataType>(a_host, b_host);
|
||||
// debug_ref = reference_basic_gemm_softmax<ADataType, ADataType, AccDataType>(a_host, b_host);
|
||||
reference_basic_gemm_softmax_grouped_topk<ADataType, ADataType, AccDataType, WeightType, IndexType>(
|
||||
a_host, b_host, value_ref, index_ref, topk);
|
||||
debug_ref = reference_basic_gemm_softmax<ADataType, ADataType, AccDataType>(a_host, b_host);
|
||||
// reference_basic_gemm_softmax_grouped_topk<ADataType, ADataType, AccDataType, WeightType, IndexType>(
|
||||
// a_host, b_host, value_ref, index_ref, topk);
|
||||
debug_buf.FromDevice(debug_host_dev.mData.data());
|
||||
value_buf.FromDevice(value_host_dev.mData.data());
|
||||
index_buf.FromDevice(index_host_dev.mData.data());
|
||||
@@ -306,14 +306,14 @@ int main(int argc, char* argv[])
|
||||
for(int i_t = 0; i_t < tokens; i_t++)
|
||||
{
|
||||
auto s_begin = std::vector<size_t>{static_cast<size_t>(i_t), static_cast<size_t>(0)};
|
||||
auto s_end =
|
||||
std::vector<size_t>{static_cast<size_t>(i_t + 1), static_cast<size_t>(topk)};
|
||||
// auto s_end =
|
||||
// std::vector<size_t>{static_cast<size_t>(i_t + 1), static_cast<size_t>(N)};
|
||||
// std::vector<size_t>{static_cast<size_t>(i_t + 1), static_cast<size_t>(topk)};
|
||||
auto s_end =
|
||||
std::vector<size_t>{static_cast<size_t>(i_t + 1), static_cast<size_t>(N)};
|
||||
auto s_debug_host = debug_host_dev.slice(s_begin, s_end);
|
||||
// auto s_debug_ref = debug_ref.slice(s_begin, s_end);
|
||||
auto s_debug_ref = debug_ref.slice(s_begin, s_end);
|
||||
// auto s_debug_ref = value_ref.slice(s_begin, s_end);
|
||||
auto s_debug_ref = index_ref.slice(s_begin, s_end);
|
||||
// auto s_debug_ref = index_ref.slice(s_begin, s_end);
|
||||
rtn &= ck_tile::check_err(s_debug_host,
|
||||
s_debug_ref,
|
||||
std::string("[") + std::to_string(i_t) +
|
||||
|
||||
@@ -137,6 +137,7 @@ struct GridGemm
|
||||
|
||||
// block_gemm_pipeline(a_block_window, b_block_window, debug_block_tile, value_block_tile, index_block_tile, K / kKPerBlock, p_smem_char);
|
||||
const auto debug_block_tile = block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char);
|
||||
// block_gemm_pipeline(a_block_window, b_block_window, debug_block_tile, K / kKPerBlock, p_smem_char);
|
||||
|
||||
// cast DataType and apply CElementFunction
|
||||
const auto debug_cast_block_tile = tile_elementwise_in(
|
||||
|
||||
Reference in New Issue
Block a user