mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Clean-up code
This commit is contained in:
@@ -332,15 +332,15 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
|
||||
|
||||
const index_t new_seqlen_q =
|
||||
integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) * FmhaPipeline::kM0;
|
||||
const index_t new_hdim_v =
|
||||
integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1) * FmhaPipeline::kN1;
|
||||
const index_t padded_max_seqlen_q =
|
||||
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
|
||||
const index_t padded_hdim_v =
|
||||
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
|
||||
|
||||
return transform_tensor_view(
|
||||
o_acc_dram_view,
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, new_seqlen_q)),
|
||||
make_pass_through_transform(new_hdim_v)),
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)),
|
||||
make_pass_through_transform(padded_hdim_v)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}();
|
||||
@@ -399,6 +399,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
smem_ptr,
|
||||
kargs.num_splits,
|
||||
kargs.seqlen_q,
|
||||
kargs.max_seqlen_q);
|
||||
}
|
||||
else
|
||||
@@ -408,6 +409,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
lse_dram_window,
|
||||
smem_ptr,
|
||||
kargs.num_splits,
|
||||
kargs.seqlen_q,
|
||||
kargs.max_seqlen_q);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -75,11 +75,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
return (kM0 * kMaxSplits * sizeof(LSEDataType));
|
||||
}
|
||||
|
||||
#define MARKER(msg) \
|
||||
__builtin_amdgcn_sched_barrier(0); \
|
||||
asm volatile("; [POYENC] " msg ::); \
|
||||
__builtin_amdgcn_sched_barrier(0)
|
||||
|
||||
template <typename LSEaccDramBlockWindowTmp,
|
||||
typename OaccDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
@@ -93,6 +88,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const OaccElementFunction& o_acc_element_func,
|
||||
void* smem_ptr,
|
||||
index_t num_splits,
|
||||
index_t real_seqlen_q,
|
||||
index_t max_seqlen_q) const
|
||||
{
|
||||
LSEDataType* lse_acc_lds_ptr =
|
||||
@@ -129,7 +125,14 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices);
|
||||
if(row < num_splits && col < real_seqlen_q)
|
||||
{
|
||||
lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices);
|
||||
}
|
||||
else
|
||||
{
|
||||
lse_acc_lds_ptr[row + col * kMaxSplits] = -numeric<LSEDataType>::infinity();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -138,7 +141,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
auto lse_accum_dist = Policy::template MakeLSEaccTDramTileDistribution<Problem>();
|
||||
auto lse_accum = make_static_distributed_tensor<LSEDataType>(lse_accum_dist);
|
||||
|
||||
// copy LDS to lse_accum
|
||||
// copy LDS to lse_accum (transpose)
|
||||
{
|
||||
using DataType = LSEDataType;
|
||||
using StaticTileDistribution = decltype(lse_accum_dist);
|
||||
@@ -162,14 +165,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
{
|
||||
lse_accum(distributed_indices) = -numeric<LSEDataType>::infinity();
|
||||
}
|
||||
|
||||
DEBUG_STMTS
|
||||
{
|
||||
printf("[POYENC][DEVICE] lse_accum[%2d,%2d]: %11.7f\n",
|
||||
row,
|
||||
col,
|
||||
lse_accum(distributed_indices));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -204,16 +199,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
static const auto get_validated_m = [](LSEDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
|
||||
: raw_m;
|
||||
};
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<LSEDataType>(
|
||||
@@ -240,15 +227,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
p_compute(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[col + row * kMaxSplits] -
|
||||
get_validated_m(lse_max(i_idx)));
|
||||
}
|
||||
#endif
|
||||
#if 0
|
||||
DEBUG_STMTS
|
||||
{
|
||||
printf("[POYENC][DEVICE] p_compute[%2d,%2d]: %11.7f\n",
|
||||
row,
|
||||
col,
|
||||
p_compute(i_j_idx));
|
||||
}
|
||||
#endif
|
||||
});
|
||||
});
|
||||
@@ -312,21 +290,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
});
|
||||
}
|
||||
|
||||
#if defined(PRINT_LSE_ACCUM)
|
||||
DEBUG_STMTS
|
||||
{
|
||||
for(index_t row = 0; row < kM0; ++row)
|
||||
{
|
||||
printf("[POYENC][DEVICE] lse_accum[%d] = ", row);
|
||||
for(index_t col = 0; col < num_splits; ++col)
|
||||
{
|
||||
printf("%11.7f", lse_acc_lds_ptr[col + row * kMaxSplits]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// write lse scales into LDS
|
||||
{
|
||||
constexpr auto out_spans =
|
||||
@@ -365,7 +328,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
static_assert(kBlockSize == 256);
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
|
||||
}
|
||||
|
||||
@@ -378,7 +340,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist); // Pcompute{j}
|
||||
clear_tile(o_acc);
|
||||
|
||||
// [POYENC] added
|
||||
const index_t new_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
|
||||
|
||||
for(index_t i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
auto o_tile = load_tile(o_acc_dram_window);
|
||||
@@ -394,14 +357,25 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits];
|
||||
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
|
||||
DEBUG_STMTS
|
||||
{
|
||||
printf("[POYENC][DEVICE] [%3d,%3d], o_acc(%11.7f) = lse_scale(%11.7f) "
|
||||
"* o_tile(%11.7f)\n",
|
||||
row,
|
||||
col,
|
||||
o_acc(distributed_indices),
|
||||
lse_scale,
|
||||
o_tile(distributed_indices));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
move_tile_window(o_acc_dram_window, {max_seqlen_q, 0});
|
||||
move_tile_window(o_acc_dram_window, {new_max_seqlen_q, 0});
|
||||
}
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
@@ -417,6 +391,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
LSEDramBlockWindow& lse_dram_block_window,
|
||||
void* smem_ptr,
|
||||
index_t num_splits,
|
||||
index_t real_seqlen_q,
|
||||
index_t max_seqlen_q) const
|
||||
{
|
||||
return operator()(lse_acc_dram_block_window,
|
||||
@@ -426,6 +401,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
identity{},
|
||||
smem_ptr,
|
||||
num_splits,
|
||||
real_seqlen_q,
|
||||
max_seqlen_q);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user