Fix compilation errors

This commit is contained in:
PoYen, Chen
2024-07-10 10:53:58 +00:00
parent 03b6d99be0
commit 8c733fb3be
3 changed files with 13 additions and 10 deletions

View File

@@ -30,12 +30,12 @@ generate_cos_sin(ck_tile::index_t seqlen,
std::generate(begin(angle), end(angle), std::bind(generator, std::ref(random_engine)));
ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
std::transform(begin(angle), end(angle), [](float origin_value) {
std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::cos(origin_value));
});
ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
std::transform(begin(angle), end(angle), [](float origin_value) {
std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::sin(origin_value));
});
@@ -59,7 +59,7 @@ index_cos_sin(const ck_tile::HostTensor<DataType>& cos,
assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
assert(seqlen_offset + seqlen <= cos.get_length(0));
assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));
const ck_tile::index_t num_rows = seqlen;
const ck_tile::index_t num_cols = cos.get_length(1);

View File

@@ -15,7 +15,7 @@ namespace detail {
}
template <typename DataType, typename ComputeDataType>
template <typename ComputeDataType, typename DataType>
CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>& input_bhsd,
const HostTensor<DataType>& cos_sd,
const HostTensor<DataType>& sin_sd,
@@ -27,7 +27,7 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>
cos_sd.get_length(1) == sin_sd.get_length(1));
const index_t rotary_dim = cos_sd.get_length(1) * 2;
assert(rotary_dim <= input_bhsd.get_length(3));
assert(static_cast<std::size_t>(rotary_dim) <= input_bhsd.get_length(3));
output_bhsd.ForEach([&](auto& self, auto i) {
const index_t i_d = i[3];
@@ -39,10 +39,10 @@ CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor<DataType>
const index_t i_s = i[2];
const ComputeDataType cos =
(interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % rotary_dim));
const ComputeDataType sin =
(interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % rotary_dim));
const ComputeDataType cos = type_convert<ComputeDataType>(
interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % rotary_dim));
const ComputeDataType sin = type_convert<ComputeDataType>(
interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % rotary_dim));
const ComputeDataType half_rotated_input = [&] {
const index_t i_b = i[0];
const index_t i_h = i[1];

View File

@@ -20,9 +20,12 @@ struct FmhaFwdAppendKVTilePartitioner
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_knew,
ck_tile::index_t /*hdim_v*/)
ck_tile::index_t hdim_v)
{
assert(ck_tile::integer_divide_ceil(hdim_v, kTileSizeD) == 1);
#ifdef NDEBUG
ignore = hdim_v;
#endif
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_knew, kTileSizeSk), nhead, batch_size);