mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Remove more debug code in combine pipeline
This commit is contained in:
@@ -99,15 +99,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
|
||||
auto lse_acc = load_tile(lse_acc_dram_window); // [kMaxSplits, kM0]
|
||||
|
||||
#if !defined(TID)
|
||||
#define TID 0
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_DEBUG_STMTS)
|
||||
#define DEBUG_STMTS if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID)
|
||||
#else
|
||||
#define DEBUG_STMTS if(false)
|
||||
#endif
|
||||
// copy lse_acc to LDS
|
||||
{
|
||||
using DataType = LSEDataType;
|
||||
@@ -169,22 +160,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
});
|
||||
}
|
||||
|
||||
#if defined(PRINT_LSE_ACCUM)
|
||||
DEBUG_STMTS
|
||||
{
|
||||
printf("(%2d, %2d)\n", num_splits, kM0);
|
||||
for(index_t row = 0; row < num_splits; ++row)
|
||||
{
|
||||
printf("[POYENC][DEVICE] lse_acc[%2d] = ", row);
|
||||
for(index_t col = 0; col < kM0; ++col)
|
||||
{
|
||||
printf("%11.7f", lse_acc_lds_ptr[row + col * kMaxSplits]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// calculate row_max of lse_accum
|
||||
const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
@@ -193,25 +168,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
|
||||
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
|
||||
|
||||
#if defined(PRINT_LSE_MAX)
|
||||
DEBUG_STMTS
|
||||
{
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<LSEDataType, decltype(lse_max.get_tile_distribution())>::
|
||||
get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_max.get_tile_distribution(), distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
printf(
|
||||
"[POYENC][DEVICE] lse_max[%2d]: %11.7f\n", row, lse_max(distributed_indices));
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
static const auto get_validated_m = [](LSEDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
@@ -234,16 +190,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
#if 0
|
||||
// from dist tensor
|
||||
p_compute(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx)));
|
||||
#else
|
||||
if (col < num_splits) {
|
||||
if(col < num_splits)
|
||||
{
|
||||
// from shared memory
|
||||
p_compute(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[col + row * kMaxSplits] -
|
||||
get_validated_m(lse_max(i_idx)));
|
||||
}
|
||||
#endif
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -253,25 +205,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
p_compute, sequence<1>{}, f_sum, type_convert<LSEDataType>(0));
|
||||
block_tile_reduce_sync(lse_sum, f_sum, bool_constant<false>{});
|
||||
|
||||
#if defined(PRINT_LSE_SUM)
|
||||
DEBUG_STMTS
|
||||
{
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<LSEDataType, decltype(lse_sum.get_tile_distribution())>::
|
||||
get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_sum.get_tile_distribution(), distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
printf(
|
||||
"[POYENC][DEVICE] lse_sum[%2d]: %11.7f\n", row, lse_sum(distributed_indices));
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
decltype(lse_max) lse_logsum;
|
||||
{
|
||||
constexpr auto out_spans = static_distributed_tensor<
|
||||
@@ -290,19 +223,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
lse_logsum(distributed_indices) = ck_tile::log(lse_sum(distributed_indices)) +
|
||||
get_validated_m(lse_max(distributed_indices));
|
||||
}
|
||||
|
||||
#if defined(PRINT_LSE_LOGSUM)
|
||||
DEBUG_STMTS
|
||||
{
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_logsum.get_tile_distribution(), distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
printf("[POYENC][DEVICE] lse_logsum[%d]: %11.7f\n",
|
||||
row,
|
||||
lse_logsum(distributed_indices));
|
||||
}
|
||||
#endif
|
||||
});
|
||||
}
|
||||
|
||||
@@ -327,21 +247,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
#if defined(PRINT_LSE_SCALE)
|
||||
DEBUG_STMTS
|
||||
{
|
||||
for(index_t row = 0; row < 32; ++row)
|
||||
{
|
||||
printf("[POYENC][DEVICE] lse_scale[%2d] = ", row);
|
||||
for(index_t col = 0; col < num_splits; ++col)
|
||||
{
|
||||
printf("%11.7f", lse_acc_lds_ptr[col + row * kMaxSplits]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
constexpr auto out_spans = static_distributed_tensor<
|
||||
@@ -388,19 +293,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
|
||||
LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits];
|
||||
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
|
||||
#if 0
|
||||
DEBUG_STMTS
|
||||
{
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
printf("[POYENC][DEVICE] [%3d,%3d], o_acc(%11.7f) = lse_scale(%11.7f) "
|
||||
"* o_tile(%11.7f)\n",
|
||||
row,
|
||||
col,
|
||||
o_acc(distributed_indices),
|
||||
lse_scale,
|
||||
o_tile(distributed_indices));
|
||||
}
|
||||
#endif
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user