remove wrong code in store_raw()

This commit is contained in:
carlushuang
2024-03-13 14:30:55 +00:00
parent 8103048b99
commit a59e655eb2
4 changed files with 22 additions and 29 deletions

View File

@@ -144,3 +144,7 @@
#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#endif
#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
#endif

View File

@@ -156,6 +156,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
#endif
}
#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assume either src or dst (or both) date type is under 1 dword
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
template <typename OutDataType, typename InTensor>
@@ -192,18 +193,18 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
} o_bulk;
// TODO: should use below function, but somehow will result in spill (same as c-forloop)
// static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib){
// o_bulk.data[ib.value] =
// static_cast<o_type>(in_dstr_tensors.get_thread_buffer().template
// get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
// });
static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) {
o_bulk.data[ib.value] = static_cast<o_type>(
in_dstr_tensors.get_thread_buffer()
.template get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
});
// TODO: fixme, should use above!
static_assert(sizeof(i_type) / sizeof(o_type) == 2);
o_bulk.data[0] = static_cast<o_type>(
in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
o_bulk.data[1] = static_cast<o_type>(
in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);
// static_assert(sizeof(i_type) / sizeof(o_type) == 2);
// o_bulk.data[0] = static_cast<o_type>(
// in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
// o_bulk.data[1] = static_cast<o_type>(
// in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);
out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
});
@@ -217,6 +218,7 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
return out_dstr_tensor;
}
#endif
} // namespace impl
template <typename DstType, typename SrcTensor>
@@ -229,10 +231,12 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
{
return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
}
#if CK_TILE_USE_SUBDWORD_TILE_CAST
else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
{
return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
}
#endif
else
return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>, src_tensor);
}

View File

@@ -534,9 +534,7 @@ struct tile_window_with_static_distribution
{
using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
// using vector_t = typename Traits::vector_t;
using vector_t = thread_buffer<DataType, Traits::ScalarPerVector>;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
@@ -554,10 +552,7 @@ struct tile_window_with_static_distribution
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// TODO: below code may result in spill(?)
#if 0
// read from distributed tensor
// vector_type_t vec;
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
@@ -572,22 +567,11 @@ struct tile_window_with_static_distribution
dstr_tensor.get_thread_buffer().template at<d>();
});
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// write into bottom tensor
get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, vec_value);
#else
(void)tile_dstr;
(void)idx_ys_start;
get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord,
dstr_tensor.get_thread_buffer().template get_as<vector_t>(
number<iCoord * NumAccessPerCoord + iCoordAccess>{}));
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{

View File

@@ -483,8 +483,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(
make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));