mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Reduce redundant space in bias tensor (#2024)
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
[ROCm/composable_kernel commit: 8a20b62e91]
This commit is contained in:
@@ -620,7 +620,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<BiasDataType> bias_host(
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
|
||||
? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
|
||||
ck_tile::HostTensor<SaccDataType> alibi_slope_host(
|
||||
@@ -884,7 +884,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else
|
||||
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
|
||||
}();
|
||||
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k);
|
||||
const ck_tile::index_t stride_randval = (max_seqlen_k);
|
||||
const ck_tile::index_t stride_o_acc = (hdim_v);
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
@@ -909,7 +909,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
(i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
@@ -925,7 +925,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
|
||||
: (nhead_k * hdim_v * shape_seqlen_k));
|
||||
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
|
||||
@@ -1381,9 +1381,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
// clang-format off
|
||||
if(i_perm)
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); });
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
|
||||
else
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); });
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
|
||||
// clang-format on
|
||||
|
||||
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
|
||||
|
||||
Reference in New Issue
Block a user