Remove more debug code in combine pipeline

This commit is contained in:
PoYen, Chen
2024-06-11 18:36:23 +00:00
parent 4f8cef36bc
commit f968a7e442

View File

@@ -99,15 +99,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto lse_acc = load_tile(lse_acc_dram_window); // [kMaxSplits, kM0]
#if !defined(TID)
#define TID 0
#endif
#if defined(ENABLE_DEBUG_STMTS)
#define DEBUG_STMTS if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID)
#else
#define DEBUG_STMTS if(false)
#endif
// copy lse_acc to LDS
{
using DataType = LSEDataType;
@@ -169,22 +160,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
});
}
#if defined(PRINT_LSE_ACCUM)
DEBUG_STMTS
{
printf("(%2d, %2d)\n", num_splits, kM0);
for(index_t row = 0; row < num_splits; ++row)
{
printf("[POYENC][DEVICE] lse_acc[%2d] = ", row);
for(index_t col = 0; col < kM0; ++col)
{
printf("%11.7f", lse_acc_lds_ptr[row + col * kMaxSplits]);
}
printf("\n");
}
}
#endif
// calculate row_max of lse_accum
const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
@@ -193,25 +168,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
#if defined(PRINT_LSE_MAX)
DEBUG_STMTS
{
constexpr auto out_spans =
static_distributed_tensor<LSEDataType, decltype(lse_max.get_tile_distribution())>::
get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
constexpr auto distributed_indices = make_tuple(idx0);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_max.get_tile_distribution(), distributed_indices);
const auto row = x_indices.at(number<0>{});
printf(
"[POYENC][DEVICE] lse_max[%2d]: %11.7f\n", row, lse_max(distributed_indices));
});
}
#endif
static const auto get_validated_m = [](LSEDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
@@ -234,16 +190,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
#if 0
// from dist tensor
p_compute(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx)));
#else
if (col < num_splits) {
if(col < num_splits)
{
// from shared memory
p_compute(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[col + row * kMaxSplits] -
get_validated_m(lse_max(i_idx)));
}
#endif
});
});
}
@@ -253,25 +205,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
p_compute, sequence<1>{}, f_sum, type_convert<LSEDataType>(0));
block_tile_reduce_sync(lse_sum, f_sum, bool_constant<false>{});
#if defined(PRINT_LSE_SUM)
DEBUG_STMTS
{
constexpr auto out_spans =
static_distributed_tensor<LSEDataType, decltype(lse_sum.get_tile_distribution())>::
get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
constexpr auto distributed_indices = make_tuple(idx0);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_sum.get_tile_distribution(), distributed_indices);
const auto row = x_indices.at(number<0>{});
printf(
"[POYENC][DEVICE] lse_sum[%2d]: %11.7f\n", row, lse_sum(distributed_indices));
});
}
#endif
decltype(lse_max) lse_logsum;
{
constexpr auto out_spans = static_distributed_tensor<
@@ -290,19 +223,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_logsum(distributed_indices) = ck_tile::log(lse_sum(distributed_indices)) +
get_validated_m(lse_max(distributed_indices));
}
#if defined(PRINT_LSE_LOGSUM)
DEBUG_STMTS
{
const auto x_indices = get_x_indices_from_distributed_indices(
lse_logsum.get_tile_distribution(), distributed_indices);
const auto row = x_indices.at(number<0>{});
printf("[POYENC][DEVICE] lse_logsum[%d]: %11.7f\n",
row,
lse_logsum(distributed_indices));
}
#endif
});
}
@@ -327,21 +247,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
}
block_sync_lds();
#if defined(PRINT_LSE_SCALE)
DEBUG_STMTS
{
for(index_t row = 0; row < 32; ++row)
{
printf("[POYENC][DEVICE] lse_scale[%2d] = ", row);
for(index_t col = 0; col < num_splits; ++col)
{
printf("%11.7f", lse_acc_lds_ptr[col + row * kMaxSplits]);
}
printf("\n");
}
}
#endif
if constexpr(kStoreLSE)
{
constexpr auto out_spans = static_distributed_tensor<
@@ -388,19 +293,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits];
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
#if 0
DEBUG_STMTS
{
const auto col = x_indices.at(number<1>{});
printf("[POYENC][DEVICE] [%3d,%3d], o_acc(%11.7f) = lse_scale(%11.7f) "
"* o_tile(%11.7f)\n",
row,
col,
o_acc(distributed_indices),
lse_scale,
o_tile(distributed_indices));
}
#endif
});
});
}