diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 86d14dc375..0ef51a5552 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -144,3 +144,7 @@ #ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM #define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 #endif + +#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST +#define CK_TILE_USE_SUBDWORD_TILE_CAST 0 +#endif diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 4d631f90ac..90ad94b12b 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -156,6 +156,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) #endif } +#if CK_TILE_USE_SUBDWORD_TILE_CAST // this function assume either src or dst (or both) date type is under 1 dword // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy) template @@ -192,18 +193,18 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors) } o_bulk; // TODO: should use below function, but somehow will result in spill (same as c-forloop) - // static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib){ - // o_bulk.data[ib.value] = - // static_cast(in_dstr_tensors.get_thread_buffer().template - // get_as()[number{}]); - // }); + static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) { + o_bulk.data[ib.value] = static_cast( + in_dstr_tensors.get_thread_buffer() + .template get_as()[number{}]); + }); // TODO: fixme, should use above! - static_assert(sizeof(i_type) / sizeof(o_type) == 2); - o_bulk.data[0] = static_cast( - in_dstr_tensors.get_thread_buffer().template get_as()[number<2 * i + 0>{}]); - o_bulk.data[1] = static_cast( - in_dstr_tensors.get_thread_buffer().template get_as()[number<2 * i + 1>{}]); + // static_assert(sizeof(i_type) / sizeof(o_type) == 2); + // o_bulk.data[0] = static_cast( + // in_dstr_tensors.get_thread_buffer().template get_as()[number<2 * i + 0>{}]); + // o_bulk.data[1] = static_cast( + // in_dstr_tensors.get_thread_buffer().template get_as()[number<2 * i + 1>{}]); out_dstr_tensor.get_thread_buffer().template set_as(i, o_bulk.bulk); }); @@ -217,6 +218,7 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors) return out_dstr_tensor; } +#endif } // namespace impl template @@ -229,10 +231,12 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) { return impl::cast_tile_pk_fp8x4(src_tensor); } +#if CK_TILE_USE_SUBDWORD_TILE_CAST else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4) { return impl::cast_tile_opt_subdword(src_tensor); } +#endif else return tile_elementwise_in(type_convert, src_tensor); } diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index cd96671456..0eaddb9947 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -534,9 +534,7 @@ struct tile_window_with_static_distribution { using Traits = load_store_traits; - // using vector_type_t = typename Traits::vector_type_t; - // using vector_t = typename Traits::vector_t; - using vector_t = thread_buffer; + using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; constexpr auto tile_dstr = TileDstr{}; @@ -554,10 +552,7 @@ struct tile_window_with_static_distribution // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - // TODO: below code may result in spill(?) -#if 0 // read from distributed tensor - // vector_type_t vec; vector_t vec_value; static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { constexpr auto idx_ys = generate_array( @@ -572,22 +567,11 @@ struct tile_window_with_static_distribution dstr_tensor.get_thread_buffer().template at(); }); - // const vector_t vec_value = vec.template get_as().template at<0>(); - // write into bottom tensor get_bottom_tensor_view() .template set_vectorized_elements_raw( bottom_tensor_thread_coord, vec_value); -#else - (void)tile_dstr; - (void)idx_ys_start; - get_bottom_tensor_view() - .template set_vectorized_elements_raw( - bottom_tensor_thread_coord, - dstr_tensor.get_thread_buffer().template get_as( - number{})); -#endif // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 58510a5411..712c0ca2c9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -483,8 +483,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy{}), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{}, sequence<1>{}));