Clean-up code

This commit is contained in:
PoYen, Chen
2024-06-03 08:52:21 +00:00
parent eac0f3cc47
commit 5a6b8d8606
2 changed files with 36 additions and 58 deletions

View File

@@ -332,15 +332,15 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
const index_t new_seqlen_q =
integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) * FmhaPipeline::kM0;
const index_t new_hdim_v =
integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1) * FmhaPipeline::kN1;
const index_t padded_max_seqlen_q =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
const index_t padded_hdim_v =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
return transform_tensor_view(
o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, new_seqlen_q)),
make_pass_through_transform(new_hdim_v)),
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)),
make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}();
@@ -399,6 +399,7 @@ struct FmhaFwdSplitKVCombineKernel
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
smem_ptr,
kargs.num_splits,
kargs.seqlen_q,
kargs.max_seqlen_q);
}
else
@@ -408,6 +409,7 @@ struct FmhaFwdSplitKVCombineKernel
lse_dram_window,
smem_ptr,
kargs.num_splits,
kargs.seqlen_q,
kargs.max_seqlen_q);
}
}();

View File

@@ -75,11 +75,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
return (kM0 * kMaxSplits * sizeof(LSEDataType));
}
#define MARKER(msg) \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("; [POYENC] " msg ::); \
__builtin_amdgcn_sched_barrier(0)
template <typename LSEaccDramBlockWindowTmp,
typename OaccDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
@@ -93,6 +88,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const OaccElementFunction& o_acc_element_func,
void* smem_ptr,
index_t num_splits,
index_t real_seqlen_q,
index_t max_seqlen_q) const
{
LSEDataType* lse_acc_lds_ptr =
@@ -129,7 +125,14 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices);
if(row < num_splits && col < real_seqlen_q)
{
lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices);
}
else
{
lse_acc_lds_ptr[row + col * kMaxSplits] = -numeric<LSEDataType>::infinity();
}
});
});
}
@@ -138,7 +141,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto lse_accum_dist = Policy::template MakeLSEaccTDramTileDistribution<Problem>();
auto lse_accum = make_static_distributed_tensor<LSEDataType>(lse_accum_dist);
// copy LDS to lse_accum
// copy LDS to lse_accum (transpose)
{
using DataType = LSEDataType;
using StaticTileDistribution = decltype(lse_accum_dist);
@@ -162,14 +165,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{
lse_accum(distributed_indices) = -numeric<LSEDataType>::infinity();
}
DEBUG_STMTS
{
printf("[POYENC][DEVICE] lse_accum[%2d,%2d]: %11.7f\n",
row,
col,
lse_accum(distributed_indices));
}
});
});
}
@@ -204,16 +199,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
static const auto get_validated_m = [](LSEDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
: raw_m;
};
auto p_compute = make_static_distributed_tensor<LSEDataType>(
@@ -240,15 +227,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
p_compute(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[col + row * kMaxSplits] -
get_validated_m(lse_max(i_idx)));
}
#endif
#if 0
DEBUG_STMTS
{
printf("[POYENC][DEVICE] p_compute[%2d,%2d]: %11.7f\n",
row,
col,
p_compute(i_j_idx));
}
#endif
});
});
@@ -312,21 +290,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
});
}
#if defined(PRINT_LSE_ACCUM)
DEBUG_STMTS
{
for(index_t row = 0; row < kM0; ++row)
{
printf("[POYENC][DEVICE] lse_accum[%d] = ", row);
for(index_t col = 0; col < num_splits; ++col)
{
printf("%11.7f", lse_acc_lds_ptr[col + row * kMaxSplits]);
}
printf("\n");
}
}
#endif
// write lse scales into LDS
{
constexpr auto out_spans =
@@ -365,7 +328,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
if constexpr(kStoreLSE)
{
static_assert(kBlockSize == 256);
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
}
@@ -378,7 +340,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist); // Pcompute{j}
clear_tile(o_acc);
// [POYENC] added
const index_t new_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
for(index_t i_split = 0; i_split < num_splits; ++i_split)
{
auto o_tile = load_tile(o_acc_dram_window);
@@ -394,14 +357,25 @@ struct BlockFmhaFwdSplitKVCombinePipeline
get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices);
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits];
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
DEBUG_STMTS
{
printf("[POYENC][DEVICE] [%3d,%3d], o_acc(%11.7f) = lse_scale(%11.7f) "
"* o_tile(%11.7f)\n",
row,
col,
o_acc(distributed_indices),
lse_scale,
o_tile(distributed_indices));
}
});
});
}
move_tile_window(o_acc_dram_window, {max_seqlen_q, 0});
move_tile_window(o_acc_dram_window, {new_max_seqlen_q, 0});
}
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
@@ -417,6 +391,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
LSEDramBlockWindow& lse_dram_block_window,
void* smem_ptr,
index_t num_splits,
index_t real_seqlen_q,
index_t max_seqlen_q) const
{
return operator()(lse_acc_dram_block_window,
@@ -426,6 +401,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity{},
smem_ptr,
num_splits,
real_seqlen_q,
max_seqlen_q);
}
};