diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 1248c1c957..9b4dc30eb2 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -39,7 +39,9 @@ template + bool TilesPacked_ = false, + typename CComputeDataType_ = AccDataType_, + typename CShuffleDataType_ = ODataType_> struct CShuffleEpilogueProblem { using AsDataType = remove_cvref_t; @@ -52,6 +54,8 @@ struct CShuffleEpilogueProblem using CDElementwise = remove_cvref_t; using AComputeDataType = remove_cvref_t; using BComputeDataType = remove_cvref_t; + using CComputeDataType = remove_cvref_t; + using CShuffleDataType = remove_cvref_t; static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); static constexpr index_t kMPerBlock = kM_; static constexpr index_t kNPerBlock = kN_; @@ -84,6 +88,8 @@ struct CShuffleEpilogue using DsLayout = remove_cvref_t; using AComputeDataType = remove_cvref_t; using BComputeDataType = remove_cvref_t; + using CComputeDataType = remove_cvref_t; + using CShuffleDataType = remove_cvref_t; static constexpr bool ADataTypeIsTuple = is_detected::value; static constexpr bool BDataTypeIsTuple = is_detected::value; @@ -236,7 +242,7 @@ struct CShuffleEpilogue constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile; constexpr auto shuffle_tile = - m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer + m_val * n_val * sizeof(CShuffleDataType) > get_smem_capacity() || DoubleSmemBuffer ? std::make_tuple(1, 1) : std::make_tuple(m_shuffle_tile, n_shuffle_tile); @@ -317,7 +323,7 @@ struct CShuffleEpilogue template CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() { - constexpr auto DataTypeSize = sizeof(ODataType); + constexpr auto DataTypeSize = sizeof(CShuffleDataType); constexpr index_t VectorLen = GetVectorSizeC(); constexpr index_t banks = get_n_lds_banks(); @@ -550,7 +556,7 @@ struct CShuffleEpilogue CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); - return lds_block_desc.get_element_space_size() * sizeof(ODataType); + return lds_block_desc.get_element_space_size() * sizeof(CShuffleDataType); } /// Number of block_sync_lds() calls in operator(). @@ -581,13 +587,14 @@ struct CShuffleEpilogue } template - CK_TILE_DEVICE void + CK_TILE_DEVICE auto scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window) { // Check if scales are EmptyScale first (no scaling needed) if constexpr(std::is_same_v && std::is_same_v) { // No scaling needed - this is a no-op + return lds_tile; } // Check if scales are scalar AccDataType else if constexpr(std::is_same_v && @@ -598,6 +605,7 @@ struct CShuffleEpilogue const AccDataType scale_n = scale_n_window; tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; }, lds_tile); + return lds_tile; } // Otherwise, assume they are tile windows that can be loaded else @@ -606,9 +614,15 @@ struct CShuffleEpilogue const auto scale_m_tile = load_tile(scale_m_window); const auto scale_n_tile = load_tile(scale_n_window); - // Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n - tile_elementwise_inout( - element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile); + auto scaled_lds_tile = + make_static_distributed_tensor(lds_tile.get_tile_distribution()); + + // Compute element-wise product i.e. scaled_lds_tile = lds_tile * scale_m * scale_n + tile_elementwise_inout(element_wise::MultiDMultiply{}, + scaled_lds_tile, + lds_tile, + scale_m_tile, + scale_n_tile); // Move scale windows constexpr index_t num_access = SFC::get_num_of_access(); @@ -619,6 +633,8 @@ struct CShuffleEpilogue move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})}); move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})}); } + + return scaled_lds_tile; } } @@ -641,14 +657,6 @@ struct CShuffleEpilogue c_warp_y_lengths)); } - template - CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window) - { - const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); - - store_tile(in_lds_window, c_warptile_in_tensor_casted); - } - template CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor) { @@ -734,7 +742,7 @@ struct CShuffleEpilogue constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); auto o_lds_block = make_tensor_view( - static_cast(p_smem), lds_block_desc); + static_cast(p_smem), lds_block_desc); auto in_lds_window = make_tile_window( o_lds_block, @@ -810,18 +818,16 @@ struct CShuffleEpilogue block_sync_lds(); slice_acc_tile(o_acc_tile, lds_tile); - if constexpr(has_scales) - { + const auto scaled_lds_tile = scale_tile(lds_tile, scale_m_window, scale_n_window); - } - cast_lds_tile(lds_tile, in_lds_window); + store_tile(in_lds_window, cast_tile(scaled_lds_tile)); block_sync_lds(); auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); apply_d_tensors(d_dram_windows, c_out_tensor); - store_to_dram(out_dram_window, c_out_tensor); + store_to_dram(out_dram_window, cast_tile(c_out_tensor)); move_windows(out_dram_window, d_dram_windows); }); }