mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 16:26:10 +00:00
Fix wrong layout of LSE/LSEacc/Oacc
This commit is contained in:
@@ -33,6 +33,12 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
#if defined(ENABLE_APP_DEBUG_STMTS)
|
||||
#define APP_DEBUG_STMTS if(true)
|
||||
#else
|
||||
#define APP_DEBUG_STMTS if(false)
|
||||
#endif
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -382,9 +388,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
||||
1 < num_splits
|
||||
? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
|
||||
? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
|
||||
APP_DEBUG_STMTS {
|
||||
std::cout << "lse_acc_host shape: " << num_splits << ", " << batch << ", "
|
||||
<< nhead << ", " << max_seqlen_q << std::endl;
|
||||
}
|
||||
ck_tile::HostTensor<OaccDataType> o_acc_host(
|
||||
1 < num_splits ? std::array<ck_tile::index_t, 5>{num_splits,
|
||||
shape_batch,
|
||||
@@ -392,7 +401,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
shape_seqlen_q,
|
||||
hdim_v}
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
APP_DEBUG_STMTS{
|
||||
std::cout << "o_acc_host shape: " << num_splits << ", " << shape_batch << ", "
|
||||
<< nhead << ", " << shape_seqlen_q << ", " << hdim_v << std::endl;
|
||||
}
|
||||
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
|
||||
@@ -660,6 +672,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
o_buf.FromDevice(o_host.data());
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
lse_acc_buf.FromDevice(lse_acc_host.data());
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
uint8_t p_undrop_in_uint8_t =
|
||||
@@ -668,11 +681,121 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
bool pass = true;
|
||||
|
||||
APP_DEBUG_STMTS
|
||||
{
|
||||
printf("\n");
|
||||
printf("[POYENC][HOST] lse shape: %d, %d, %d, %d\n",
|
||||
num_splits,
|
||||
shape_batch,
|
||||
nhead,
|
||||
shape_seqlen_q);
|
||||
}
|
||||
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
APP_DEBUG_STMTS { printf("[POYENC][HOST] wb: %d\n", wb); }
|
||||
|
||||
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
APP_DEBUG_STMTS
|
||||
{
|
||||
// lse_acc_host shape: num_splits, shape_batch, nhead, shape_seqlen_q
|
||||
for(int i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
printf("[POYENC][HOST] i_split: %d\n", i_split);
|
||||
printf("[POYENC][HOST] lse_acc_host(%2d,%2d, 0) = ", i_split, wb);
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
printf("%11.7f", lse_acc_host(i_split, wb, 1, row));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
APP_DEBUG_STMTS
|
||||
{
|
||||
ck_tile::HostTensor<LSEDataType> lse_max({real_seqlen_q});
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
lse_max(row) = -ck_tile::numeric<LSEDataType>::infinity();
|
||||
for(int i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
if(lse_max(row) < lse_acc_host(i_split, wb, 1, row))
|
||||
{
|
||||
lse_max(row) = lse_acc_host(i_split, wb, 1, row);
|
||||
}
|
||||
}
|
||||
}
|
||||
printf("[POYENC][HOST] lse_max: ");
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
printf("%11.7f", lse_max(row));
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
static const auto get_validated_m = [](LSEDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
return raw_m == -ck_tile::numeric<LSEDataType>::infinity()
|
||||
? ck_tile::type_convert<LSEDataType>(0.f)
|
||||
: raw_m;
|
||||
};
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_sum({shape_seqlen_q});
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
lse_sum(row) = 0;
|
||||
for(int i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
lse_sum(row) += ck_tile::exp(lse_acc_host(i_split, wb, 1, row) -
|
||||
get_validated_m(lse_max(row)));
|
||||
}
|
||||
}
|
||||
printf("[POYENC][HOST] lse_sum: ");
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
printf("%11.7f", lse_sum(row));
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_logsum({real_seqlen_q});
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
if(lse_sum(row) == 0.f || lse_sum(row) != lse_sum(row))
|
||||
{
|
||||
lse_logsum(row) = ck_tile::numeric<LSEDataType>::infinity();
|
||||
}
|
||||
else
|
||||
{
|
||||
lse_logsum(row) = ck_tile::log(lse_sum(row)) + get_validated_m(lse_max(row));
|
||||
}
|
||||
}
|
||||
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
if(lse_logsum(row) == ck_tile::numeric<LSEDataType>::infinity())
|
||||
{
|
||||
lse_logsum(row) = -ck_tile::numeric<LSEDataType>::infinity();
|
||||
}
|
||||
}
|
||||
|
||||
// lse_host shape: [batch, nhead, max_seqlen_q]
|
||||
printf("[POYENC][DEVICE] lse_host: ");
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
printf("%11.7f", lse_host(wb, 1, row));
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
printf("[POYENC][HOST] lse_logsum: ");
|
||||
for(int row = 0; row < real_seqlen_q; ++row)
|
||||
{
|
||||
printf("%11.7f", lse_logsum(row));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
@@ -867,8 +990,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>(init_method);
|
||||
bool cur_pass = ck_tile::check_err(
|
||||
bool cur_pass = true;
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
lse_host_result.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
|
||||
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
// if (cur_pass) std::cout << "LSE pass" << std::endl;
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
||||
}
|
||||
}
|
||||
cur_pass = ck_tile::check_err(
|
||||
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
// if (cur_pass) std::cout << "OUT pass" << std::endl;
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
@@ -880,32 +1028,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
lse_host_result.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
|
||||
|
||||
bool lse_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
|
||||
pass &= lse_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
|
||||
@@ -103,6 +103,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
{
|
||||
void* lse_ptr = nullptr;
|
||||
ck_tile::index_t nhead_stride_lse = 0;
|
||||
ck_tile::index_t batch_stride_lse = 0;
|
||||
};
|
||||
|
||||
struct Fp8StaticQuantKargs
|
||||
@@ -110,14 +111,9 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
float scale_o;
|
||||
};
|
||||
|
||||
struct BatchModeLSEKargs : CommonLSEKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_lse = 0;
|
||||
};
|
||||
|
||||
struct BatchModeKargs
|
||||
: CommonKargs,
|
||||
std::conditional_t<kStoreLSE, BatchModeLSEKargs, EmptyKargs<0>>,
|
||||
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_o;
|
||||
@@ -226,10 +222,12 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
__host__ static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
|
||||
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -245,9 +243,11 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
|
||||
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
@@ -259,11 +259,12 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_lse_acc = query_start;
|
||||
batch_offset_o_acc = query_start * kargs.hdim_v;
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
|
||||
batch_offset_o_acc = static_cast<long_index_t>(i_batch) *
|
||||
(kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v);
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = query_start;
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
|
||||
}
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
|
||||
@@ -432,7 +433,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
auto o_dram_window =
|
||||
make_tile_window(o_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
{i_m0, 0});
|
||||
{i_m0, i_n1});
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile);
|
||||
}
|
||||
|
||||
@@ -16,19 +16,36 @@ struct FmhaFwdSplitKVCombineTilePartitioner
|
||||
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN0;
|
||||
// constexpr static ck_tile::index_t kBlockM = kN1 % 128 == 0 ? 4 : (kN1 % 64 == 0 ? 8 : 16);
|
||||
|
||||
__host__ static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t max_seqlen_q)
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
{
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0), nhead, batch_size);
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t /*hdim_v*/)
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
|
||||
{
|
||||
const index_t i_tile_m = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_nhead, i_batch);
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -506,8 +506,10 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
batch_offset_randval = query_start * kargs.stride_randval;
|
||||
}
|
||||
batch_offset_lse_acc = query_start;
|
||||
batch_offset_o_acc = query_start * kargs.hdim_v;
|
||||
batch_offset_lse_acc =
|
||||
static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
|
||||
batch_offset_o_acc = static_cast<long_index_t>(i_batch) *
|
||||
(kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v);
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
|
||||
@@ -103,8 +103,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
|
||||
auto lse_acc = load_tile(lse_acc_dram_window); // [kMaxSplits, kM0]
|
||||
|
||||
#if !defined(TID)
|
||||
#define TID 0
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_DEBUG_STMTS)
|
||||
#define DEBUG_STMTS if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID)
|
||||
#define DEBUG_STMTS if(blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0 && threadIdx.x == TID)
|
||||
#else
|
||||
#define DEBUG_STMTS if(false)
|
||||
#endif
|
||||
@@ -169,6 +173,22 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
});
|
||||
}
|
||||
|
||||
#if defined(PRINT_LSE_ACCUM)
|
||||
DEBUG_STMTS
|
||||
{
|
||||
printf("\n");
|
||||
for(index_t row = 0; row < num_splits; ++row)
|
||||
{
|
||||
printf("[POYENC][DEVICE] lse_acc[%2d] = ", row);
|
||||
for(index_t col = 0; col < real_seqlen_q; ++col)
|
||||
{
|
||||
printf("%11.7f", lse_acc_lds_ptr[row + col * kMaxSplits]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// calculate row_max of lse_accum
|
||||
const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
@@ -369,12 +389,14 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
|
||||
LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits];
|
||||
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
|
||||
#if 0
|
||||
DEBUG_STMTS
|
||||
{
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
printf("[POYENC][DEVICE] [%3d,%3d], o_acc(%11.7f) = lse_scale(%11.7f) "
|
||||
"* o_tile(%11.7f)\n",
|
||||
row,
|
||||
@@ -383,6 +405,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
lse_scale,
|
||||
o_tile(distributed_indices));
|
||||
}
|
||||
#endif
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user