mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
remove wrong code in store_raw()
This commit is contained in:
@@ -144,3 +144,7 @@
|
||||
#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST
|
||||
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
|
||||
#endif
|
||||
|
||||
@@ -156,6 +156,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
|
||||
#endif
|
||||
}
|
||||
|
||||
#if CK_TILE_USE_SUBDWORD_TILE_CAST
|
||||
// this function assumes that either the src or dst (or both) data type is smaller than 1 dword
|
||||
// we pack subdword values into 1 dword to avoid the compiler's default subdword behavior (which is buggy)
|
||||
template <typename OutDataType, typename InTensor>
|
||||
@@ -192,18 +193,18 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
|
||||
} o_bulk;
|
||||
|
||||
// TODO: should use the function below, but it somehow results in spill (same as a C for-loop)
|
||||
// static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib){
|
||||
// o_bulk.data[ib.value] =
|
||||
// static_cast<o_type>(in_dstr_tensors.get_thread_buffer().template
|
||||
// get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
|
||||
// });
|
||||
static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) {
|
||||
o_bulk.data[ib.value] = static_cast<o_type>(
|
||||
in_dstr_tensors.get_thread_buffer()
|
||||
.template get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
|
||||
});
|
||||
|
||||
// TODO: fixme, should use above!
|
||||
static_assert(sizeof(i_type) / sizeof(o_type) == 2);
|
||||
o_bulk.data[0] = static_cast<o_type>(
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
|
||||
o_bulk.data[1] = static_cast<o_type>(
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);
|
||||
// static_assert(sizeof(i_type) / sizeof(o_type) == 2);
|
||||
// o_bulk.data[0] = static_cast<o_type>(
|
||||
// in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
|
||||
// o_bulk.data[1] = static_cast<o_type>(
|
||||
// in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);
|
||||
|
||||
out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
|
||||
});
|
||||
@@ -217,6 +218,7 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
|
||||
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
#endif
|
||||
} // namespace impl
|
||||
|
||||
template <typename DstType, typename SrcTensor>
|
||||
@@ -229,10 +231,12 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
|
||||
{
|
||||
return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
|
||||
}
|
||||
#if CK_TILE_USE_SUBDWORD_TILE_CAST
|
||||
else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
|
||||
{
|
||||
return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
|
||||
}
|
||||
#endif
|
||||
else
|
||||
return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>, src_tensor);
|
||||
}
|
||||
|
||||
@@ -534,9 +534,7 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
// using vector_t = typename Traits::vector_t;
|
||||
using vector_t = thread_buffer<DataType, Traits::ScalarPerVector>;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
@@ -554,10 +552,7 @@ struct tile_window_with_static_distribution
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// TODO: below code may result in spill(?)
|
||||
#if 0
|
||||
// read from distributed tensor
|
||||
// vector_type_t vec;
|
||||
vector_t vec_value;
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
@@ -572,22 +567,11 @@ struct tile_window_with_static_distribution
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view()
|
||||
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
|
||||
bottom_tensor_thread_coord, vec_value);
|
||||
#else
|
||||
(void)tile_dstr;
|
||||
(void)idx_ys_start;
|
||||
|
||||
get_bottom_tensor_view()
|
||||
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
|
||||
bottom_tensor_thread_coord,
|
||||
dstr_tensor.get_thread_buffer().template get_as<vector_t>(
|
||||
number<iCoord * NumAccessPerCoord + iCoordAccess>{}));
|
||||
#endif
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
|
||||
@@ -483,8 +483,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kNPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kNPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user