Fix wrong layout of LSE/LSEacc/Oacc

This commit is contained in:
PoYen, Chen
2024-06-04 23:38:06 +00:00
parent 064afc69d9
commit 18a7223b96
5 changed files with 221 additions and 56 deletions

View File

@@ -33,6 +33,12 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
return os << "]";
}
#if defined(ENABLE_APP_DEBUG_STMTS)
#define APP_DEBUG_STMTS if(true)
#else
#define APP_DEBUG_STMTS if(false)
#endif
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -382,9 +388,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits
? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
APP_DEBUG_STMTS {
std::cout << "lse_acc_host shape: " << num_splits << ", " << batch << ", "
<< nhead << ", " << max_seqlen_q << std::endl;
}
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 5>{num_splits,
shape_batch,
@@ -392,7 +401,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
shape_seqlen_q,
hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
APP_DEBUG_STMTS{
std::cout << "o_acc_host shape: " << num_splits << ", " << shape_batch << ", "
<< nhead << ", " << shape_seqlen_q << ", " << hdim_v << std::endl;
}
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
@@ -660,6 +672,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_buf.FromDevice(o_host.data());
lse_buf.FromDevice(lse_host.data());
lse_acc_buf.FromDevice(lse_acc_host.data());
randval_buf.FromDevice(randval_host.data());
float p_undrop = 1.0 - p_drop;
uint8_t p_undrop_in_uint8_t =
@@ -668,11 +681,121 @@ bool run(const ck_tile::ArgParser& arg_parser)
bool pass = true;
APP_DEBUG_STMTS
{
printf("\n");
printf("[POYENC][HOST] lse shape: %d, %d, %d, %d\n",
num_splits,
shape_batch,
nhead,
shape_seqlen_q);
}
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
APP_DEBUG_STMTS { printf("[POYENC][HOST] wb: %d\n", wb); }
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
APP_DEBUG_STMTS
{
// lse_acc_host shape: num_splits, shape_batch, nhead, shape_seqlen_q
for(int i_split = 0; i_split < num_splits; ++i_split)
{
printf("[POYENC][HOST] i_split: %d\n", i_split);
printf("[POYENC][HOST] lse_acc_host(%2d,%2d, 0) = ", i_split, wb);
for(int row = 0; row < real_seqlen_q; ++row)
{
printf("%11.7f", lse_acc_host(i_split, wb, 1, row));
}
printf("\n");
}
}
APP_DEBUG_STMTS
{
ck_tile::HostTensor<LSEDataType> lse_max({real_seqlen_q});
for(int row = 0; row < real_seqlen_q; ++row)
{
lse_max(row) = -ck_tile::numeric<LSEDataType>::infinity();
for(int i_split = 0; i_split < num_splits; ++i_split)
{
if(lse_max(row) < lse_acc_host(i_split, wb, 1, row))
{
lse_max(row) = lse_acc_host(i_split, wb, 1, row);
}
}
}
printf("[POYENC][HOST] lse_max: ");
for(int row = 0; row < real_seqlen_q; ++row)
{
printf("%11.7f", lse_max(row));
}
printf("\n");
static const auto get_validated_m = [](LSEDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
return raw_m == -ck_tile::numeric<LSEDataType>::infinity()
? ck_tile::type_convert<LSEDataType>(0.f)
: raw_m;
};
ck_tile::HostTensor<LSEDataType> lse_sum({shape_seqlen_q});
for(int row = 0; row < real_seqlen_q; ++row)
{
lse_sum(row) = 0;
for(int i_split = 0; i_split < num_splits; ++i_split)
{
lse_sum(row) += ck_tile::exp(lse_acc_host(i_split, wb, 1, row) -
get_validated_m(lse_max(row)));
}
}
printf("[POYENC][HOST] lse_sum: ");
for(int row = 0; row < real_seqlen_q; ++row)
{
printf("%11.7f", lse_sum(row));
}
printf("\n");
ck_tile::HostTensor<LSEDataType> lse_logsum({real_seqlen_q});
for(int row = 0; row < real_seqlen_q; ++row)
{
if(lse_sum(row) == 0.f || lse_sum(row) != lse_sum(row))
{
lse_logsum(row) = ck_tile::numeric<LSEDataType>::infinity();
}
else
{
lse_logsum(row) = ck_tile::log(lse_sum(row)) + get_validated_m(lse_max(row));
}
}
for(int row = 0; row < real_seqlen_q; ++row)
{
if(lse_logsum(row) == ck_tile::numeric<LSEDataType>::infinity())
{
lse_logsum(row) = -ck_tile::numeric<LSEDataType>::infinity();
}
}
// lse_host shape: [batch, nhead, max_seqlen_q]
printf("[POYENC][DEVICE] lse_host: ");
for(int row = 0; row < real_seqlen_q; ++row)
{
printf("%11.7f", lse_host(wb, 1, row));
}
printf("\n");
printf("[POYENC][HOST] lse_logsum: ");
for(int row = 0; row < real_seqlen_q; ++row)
{
printf("%11.7f", lse_logsum(row));
}
printf("\n");
}
// adjust matrix index according to the mode
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
@@ -867,8 +990,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method);
bool cur_pass = ck_tile::check_err(
bool cur_pass = true;
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",
rtol,
atol,
/* allow_infinity_ref = */ true);
// if (cur_pass) std::cout << "LSE pass" << std::endl;
pass &= cur_pass;
if(!cur_pass)
{
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
<< "\tseqlen_k: " << real_seqlen_k << std::endl
<< "\tseqstart_q: " << seqstart_q_host << std::endl
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
}
}
cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
// if (cur_pass) std::cout << "OUT pass" << std::endl;
pass &= cur_pass;
if(!cur_pass)
{
@@ -880,32 +1028,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
break;
}
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
bool lse_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",
rtol,
atol,
/* allow_infinity_ref = */ true);
pass &= lse_pass;
if(!cur_pass)
{
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
<< "\tseqlen_k: " << real_seqlen_k << std::endl
<< "\tseqstart_q: " << seqstart_q_host << std::endl
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
break;
}
}
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;

View File

@@ -103,6 +103,7 @@ struct FmhaFwdSplitKVCombineKernel
{
void* lse_ptr = nullptr;
ck_tile::index_t nhead_stride_lse = 0;
ck_tile::index_t batch_stride_lse = 0;
};
struct Fp8StaticQuantKargs
@@ -110,14 +111,9 @@ struct FmhaFwdSplitKVCombineKernel
float scale_o;
};
struct BatchModeLSEKargs : CommonLSEKargs
{
ck_tile::index_t batch_stride_lse = 0;
};
struct BatchModeKargs
: CommonKargs,
std::conditional_t<kStoreLSE, BatchModeLSEKargs, EmptyKargs<0>>,
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{
ck_tile::index_t batch_stride_o;
@@ -226,10 +222,12 @@ struct FmhaFwdSplitKVCombineKernel
return kargs;
}
__host__ static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
@@ -245,9 +243,11 @@ struct FmhaFwdSplitKVCombineKernel
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
@@ -259,11 +259,12 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.hdim_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
batch_offset_o_acc = static_cast<long_index_t>(i_batch) *
(kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v);
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
batch_offset_lse = static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
}
batch_offset_o = query_start * kargs.row_stride_o;
@@ -432,7 +433,7 @@ struct FmhaFwdSplitKVCombineKernel
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, 0});
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile);
}

View File

@@ -16,19 +16,36 @@ struct FmhaFwdSplitKVCombineTilePartitioner
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN0;
// constexpr static ck_tile::index_t kBlockM = kN1 % 128 == 0 ? 4 : (kN1 % 64 == 0 ? 8 : 16);
__host__ static constexpr auto
GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t max_seqlen_q)
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0), nhead, batch_size);
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1),
nhead_,
batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t /*hdim_v*/)
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
const index_t i_tile_m = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
return ck_tile::make_tuple(i_tile_m, i_nhead, i_batch);
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
};

View File

@@ -506,8 +506,10 @@ struct FmhaFwdSplitKVKernel
{
batch_offset_randval = query_start * kargs.stride_randval;
}
batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.hdim_v;
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * (kargs.nhead * kargs.max_seqlen_q);
batch_offset_o_acc = static_cast<long_index_t>(i_batch) *
(kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v);
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;

View File

@@ -103,8 +103,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto lse_acc = load_tile(lse_acc_dram_window); // [kMaxSplits, kM0]
#if !defined(TID)
#define TID 0
#endif
#if defined(ENABLE_DEBUG_STMTS)
#define DEBUG_STMTS if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID)
#define DEBUG_STMTS if(blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0 && threadIdx.x == TID)
#else
#define DEBUG_STMTS if(false)
#endif
@@ -169,6 +173,22 @@ struct BlockFmhaFwdSplitKVCombinePipeline
});
}
#if defined(PRINT_LSE_ACCUM)
DEBUG_STMTS
{
printf("\n");
for(index_t row = 0; row < num_splits; ++row)
{
printf("[POYENC][DEVICE] lse_acc[%2d] = ", row);
for(index_t col = 0; col < real_seqlen_q; ++col)
{
printf("%11.7f", lse_acc_lds_ptr[row + col * kMaxSplits]);
}
printf("\n");
}
}
#endif
// calculate row_max of lse_accum
const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
@@ -369,12 +389,14 @@ struct BlockFmhaFwdSplitKVCombinePipeline
get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices);
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits];
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
#if 0
DEBUG_STMTS
{
const auto col = x_indices.at(number<1>{});
printf("[POYENC][DEVICE] [%3d,%3d], o_acc(%11.7f) = lse_scale(%11.7f) "
"* o_tile(%11.7f)\n",
row,
@@ -383,6 +405,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_scale,
o_tile(distributed_indices));
}
#endif
});
});
}